Constructor.
473 ):
474 super(bn, self).__init__()
475
476 self.output_size = output_size
477
478 self.gain = P(torch.ones(output_size), requires_grad=True)
479
480 self.bias = P(torch.zeros(output_size), requires_grad=True)
481
482 self.eps = eps
483
484 self.momentum = momentum
485
486 self.cross_replica = cross_replica
487
488 self.mybn = mybn
489
490 if mybn:
491 self.bn = myBN(output_size, self.eps, self.momentum)
492
493 else:
494 self.register_buffer("stored_mean", torch.zeros(output_size))
495 self.register_buffer("stored_var", torch.ones(output_size))
496