Constructor.
492 ):
493 super(bn, self).__init__()
494
495 self.output_size = output_size
496
497 self.gain = P(torch.ones(output_size), requires_grad=True)
498
499 self.bias = P(torch.zeros(output_size), requires_grad=True)
500
501 self.eps = eps
502
503 self.momentum = momentum
504
505 self.cross_replica = cross_replica
506
507 self.mybn = mybn
508
509 if mybn:
510 self.bn = myBN(output_size, self.eps, self.momentum)
511
512 else:
513
514 self.stored_mean = torch.zeros(output_size)
515 self.register_buffer("stored_mean", torch.zeros(output_size))
516
517 self.stored_var = torch.ones(output_size)
518 self.register_buffer("stored_var", torch.ones(output_size))
519
520
521 self.training: bool
522