Constructor.
def __init__(self, ch, which_conv=SNConv2d):
    """Create the four projection layers and the scalar ``gamma``.

    Every layer is built through ``which_conv`` with ``kernel_size=1``,
    ``padding=0`` and ``bias=False``:

    * ``theta`` and ``phi``: ``ch`` -> ``ch // 8``
    * ``g``: ``ch`` -> ``ch // 2``
    * ``o``: ``ch // 2`` -> ``ch``

    ``gamma`` is ``P(torch.tensor(0.0), requires_grad=True)``, i.e. it
    starts at zero.

    Args:
        ch: Channel count; stored as ``self.ch``. The reduced widths use
            floor division, so a value below 8 gives ``theta``/``phi`` a
            width of 0.
        which_conv: Layer factory; stored as ``self.which_conv`` and
            invoked as ``which_conv(a, b, kernel_size=1, padding=0,
            bias=False)``. The two positional arguments are presumably
            the input and output channel counts -- confirm against the
            factory in use. Defaults to ``SNConv2d``.
    """
    super(Attention, self).__init__()

    self.ch = ch
    self.which_conv = which_conv

    # Reduced widths shared by the projections below.
    eighth = ch // 8
    half = ch // 2

    def pointwise(n_in, n_out):
        # All four layers use the same 1x1, unpadded, bias-free settings.
        return which_conv(n_in, n_out, kernel_size=1, padding=0, bias=False)

    # NOTE(review): layers are constructed in this exact sequence; if the
    # enclosing class is an nn.Module, the sequence fixes parameter
    # registration order and the order random initialisation is drawn in.
    self.theta = pointwise(ch, eighth)
    self.phi = pointwise(ch, eighth)
    self.g = pointwise(ch, half)
    self.o = pointwise(half, ch)

    # Scalar starting at zero; P is presumably torch.nn.Parameter -- verify
    # against the file's imports.
    self.gamma = P(torch.tensor(0.0), requires_grad=True)
805