Constructor.
def __init__(self, ch, which_conv=SNConv2d):
    """Create the 1x1 projection convolutions and the learnable gain.

    Four bias-free 1x1 convolutions are built with ``which_conv``:

    * ``theta`` and ``phi``: ``ch -> ch // 8`` channels.
    * ``g``: ``ch -> ch // 2`` channels.
    * ``o``: ``ch // 2 -> ch`` channels (maps the ``g`` width back to ``ch``).

    The theta/phi/g naming presumably follows the non-local / self-attention
    convention (query/key/value); the forward pass is not visible here, so
    confirm against it.

    Args:
        ch: Channel count of the input feature map. Reduced widths use
            integer division, so ``ch`` is presumably expected to be a
            multiple of 8 (TODO confirm).
        which_conv: Convolution constructor, invoked as
            ``which_conv(in_ch, out_ch, kernel_size=1, padding=0, bias=False)``.
            Defaults to ``SNConv2d``.
    """
    super(Attention, self).__init__()

    self.ch = ch
    self.which_conv = which_conv

    def conv1x1(in_ch, out_ch):
        # Every projection in this block is a 1x1 conv with no padding and
        # no bias; only the channel widths differ.
        return self.which_conv(
            in_ch, out_ch, kernel_size=1, padding=0, bias=False
        )

    reduced_ch = self.ch // 8
    half_ch = self.ch // 2

    # Keep this assignment order (theta, phi, g, o, then gamma). Assuming this
    # class is a torch nn.Module, attribute assignment order fixes the order
    # of parameters() and of state_dict() keys.
    self.theta = conv1x1(self.ch, reduced_ch)
    self.phi = conv1x1(self.ch, reduced_ch)
    self.g = conv1x1(self.ch, half_ch)
    self.o = conv1x1(half_ch, self.ch)

    # Learnable scalar gain, starting at 0.0. It is presumably used in
    # forward() to scale the attention branch, so the block starts out
    # contributing nothing (confirm against forward).
    self.gamma = P(torch.tensor(0.0), requires_grad=True)