Constructor.
536 ):
537 super(ccbn, self).__init__()
538
539 self.output_size, self.input_size = output_size, input_size
540
541 self.gain = which_linear(input_size, output_size)
542
543 self.bias = which_linear(input_size, output_size)
544
545 self.eps = eps
546
547 self.momentum = momentum
548
549 self.cross_replica = cross_replica
550
551 self.mybn = mybn
552
553 self.norm_style = norm_style
554
555 if self.mybn:
556
557 self.bn = myBN(output_size, self.eps, self.momentum)
558 elif self.norm_style in ["bn", "in"]:
559 self.register_buffer("stored_mean", torch.zeros(output_size))
560 self.register_buffer("stored_var", torch.ones(output_size))
561