Belle II Software development
Attention Class Reference

Public Member Functions

def __init__ (self, ch, which_conv=SNConv2d)
 Constructor.
 
def forward (self, x)
 Apply self-attention to x and add the result back to x as a gated residual.
 

Public Attributes

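 ch
 Number of input (and output) channels.
 
 which_conv
 Convolution layer class (or factory) used to build the projections; defaults to SNConv2d.
 
 theta
 1x1 convolution from ch to ch // 8 channels (query path).
 
 phi
 1x1 convolution from ch to ch // 8 channels, 2x2 max-pooled in forward (key path).
 
 g
 1x1 convolution from ch to ch // 2 channels, 2x2 max-pooled in forward (value path).
 
 o
 1x1 convolution mapping the attended features from ch // 2 back to ch channels.
 
 gamma
 Learnable scalar gain on the attention output; initialized to 0.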
 

Detailed Description

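Self-attention block for 2D feature maps. Every spatial position of the input attends to a 2x2 max-pooled grid of positions; the attended features are projected back to ch channels and added to the input, scaled by the learnable gain gamma.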
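The computation performed by forward can be written out for a single sample. The notation below is ours, not the source's: x is the input flattened to a C x N matrix with N = H * W, W_theta, W_phi, W_g and W_o are the weights of the 1x1 convolutions theta, phi, g and o viewed as matrices (a bias-free 1x1 convolution is a matrix product over the channel axis), and C is assumed divisible by 8 with H and W even. The 2x2 max pooling acts on the H x W grid before flattening.

```latex
% One sample; x \in \mathbb{R}^{C \times N}, N = HW
\theta = W_\theta x \in \mathbb{R}^{C/8 \times N}, \qquad
\phi = \operatorname{maxpool}_{2\times 2}(W_\phi x) \in \mathbb{R}^{C/8 \times N/4}, \qquad
g = \operatorname{maxpool}_{2\times 2}(W_g x) \in \mathbb{R}^{C/2 \times N/4}

\beta_{ij} = \frac{\exp\big((\theta^\top \phi)_{ij}\big)}
                  {\sum_{k=1}^{N/4} \exp\big((\theta^\top \phi)_{ik}\big)},
\qquad \beta \in \mathbb{R}^{N \times N/4}

y = \gamma \, W_o \big(g \, \beta^\top\big) + x \in \mathbb{R}^{C \times N}
```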

Definition at line 749 of file ieagan.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__(self, ch, which_conv=SNConv2d)

Constructor. Builds the four 1x1 projection convolutions with which_conv and initializes the gain gamma to 0.

Definition at line 753 of file ieagan.py.

def __init__(self, ch, which_conv=SNConv2d):
    super(Attention, self).__init__()

    # Number of input/output channels
    self.ch = ch

    # Convolution layer class used for all four projections
    self.which_conv = which_conv

    # Query projection: ch -> ch // 8
    self.theta = self.which_conv(
        self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
    )
    # Key projection: ch -> ch // 8
    self.phi = self.which_conv(
        self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
    )
    # Value projection: ch -> ch // 2
    self.g = self.which_conv(
        self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
    )
    # Output projection: ch // 2 -> ch
    self.o = self.which_conv(
        self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
    )

    # Learnable gain, initialized to 0 so the block starts as the identity
    self.gamma = P(torch.tensor(0.0), requires_grad=True)
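The constructor calls which_conv four times with fixed arguments, so the layer shapes can be checked without PyTorch. The sketch below is a minimal illustration under stated assumptions: RecordingConv and build_projections are hypothetical helpers (not part of ieagan.py) that record what each call would build and count the weights of a plain convolution. Any extra state a spectrally normalized layer such as SNConv2d keeps is not modeled.

```python
from dataclasses import dataclass


@dataclass
class RecordingConv:
    """Hypothetical stand-in for which_conv: stores its arguments only."""
    in_channels: int
    out_channels: int
    kernel_size: int = 1
    padding: int = 0
    bias: bool = True

    def num_weights(self) -> int:
        # A k x k convolution has out * in * k * k weights (+ out if bias=True).
        count = self.out_channels * self.in_channels * self.kernel_size ** 2
        return count + (self.out_channels if self.bias else 0)


def build_projections(ch, which_conv=RecordingConv):
    """Mirror the four which_conv(...) calls made in Attention.__init__."""
    return {
        "theta": which_conv(ch, ch // 8, kernel_size=1, padding=0, bias=False),
        "phi": which_conv(ch, ch // 8, kernel_size=1, padding=0, bias=False),
        "g": which_conv(ch, ch // 2, kernel_size=1, padding=0, bias=False),
        "o": which_conv(ch // 2, ch, kernel_size=1, padding=0, bias=False),
    }


if __name__ == "__main__":
    layers = build_projections(64)
    for name, conv in layers.items():
        print(name, conv.in_channels, "->", conv.out_channels, conv.num_weights())
    # The scalar gamma adds one more learnable value.
    print("total", sum(c.num_weights() for c in layers.values()) + 1)
```

For ch = 64 this reports 512 + 512 + 2048 + 2048 = 5120 convolution weights, plus one scalar for gamma.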

Member Function Documentation

◆ forward()

def forward(self, x)

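Applies self-attention to the feature map x of shape (B, ch, H, W) and returns gamma * o + x, which has the same shape as x. The reshapes assume the 2x2-pooled maps have exactly H * W // 4 positions, which holds whenever H and W are both even.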

Definition at line 779 of file ieagan.py.

def forward(self, x):
    # Apply convs: theta keeps full resolution, phi and g are 2x2 max-pooled
    theta = self.theta(x)
    phi = F.max_pool2d(self.phi(x), [2, 2])
    g = F.max_pool2d(self.g(x), [2, 2])
    # Perform reshapes: N = H * W query positions, N // 4 key/value positions
    theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
    phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
    g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
    # Matmul and softmax to get attention maps: beta has shape (B, N, N // 4)
    beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
    # Attention map times g path, reshaped to (B, ch // 2, H, W), projected to ch
    o = self.o(
        torch.bmm(g, beta.transpose(1, 2)).view(
            -1, self.ch // 2, x.shape[2], x.shape[3]
        )
    )
    # Gated residual connection
    return self.gamma * o + x
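The shape bookkeeping in forward can be traced with plain integers. attention_shapes is a hypothetical helper written for this page, not part of ieagan.py; it mirrors the view and bmm calls and rejects inputs for which the 2x2-pooled map does not have H * W // 4 positions, the case in which the reshapes above would not line up.

```python
def attention_shapes(batch, ch, height, width):
    """Trace tensor shapes through Attention.forward (hypothetical helper)."""
    n = height * width                       # query positions (theta)
    pooled = (height // 2) * (width // 2)    # positions left after 2x2 max pooling
    if pooled == 0 or pooled != n // 4:
        # forward reshapes phi and g to n // 4 positions; that must match the pooled map.
        raise ValueError("pooled map must have H * W // 4 positions (use even H, W)")
    m = n // 4                               # key/value positions (phi, g)
    return {
        "theta": (batch, ch // 8, n),
        "phi": (batch, ch // 8, m),
        "g": (batch, ch // 2, m),
        "beta": (batch, n, m),               # softmax taken over the last dim
        "g_beta": (batch, ch // 2, n),       # bmm(g, beta^T) before the view
        "out": (batch, ch, height, width),   # after self.o and the residual add
    }


if __name__ == "__main__":
    for name, shape in attention_shapes(2, 64, 8, 8).items():
        print(name, shape)
```

The output shape always equals the input shape, which is what makes the residual addition gamma * o + x valid.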

Member Data Documentation

◆ ch

ch

Number of input (and output) channels; it also sets the projection widths ch // 8 and ch // 2.

Definition at line 756 of file ieagan.py.

◆ g

g

1x1 convolution from ch to ch // 2 channels (value path); its output is 2x2 max-pooled in forward.

Definition at line 768 of file ieagan.py.
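The outputs of g (and phi) are reduced with F.max_pool2d(..., [2, 2]), so each channel's H x W map shrinks to (H // 2) x (W // 2) positions. The pure-Python sketch below illustrates that pooling on a single map; max_pool_2x2 is a hypothetical helper that mimics kernel size 2 with the stride defaulting to the kernel size and no padding, so trailing odd rows and columns are dropped.

```python
def max_pool_2x2(grid):
    """2x2 max pooling with stride 2 on one H x W map given as a list of rows."""
    h, w = len(grid), len(grid[0])
    # range(0, h - 1, 2) skips a trailing odd row (same for columns), i.e. floor(H / 2).
    return [
        [max(grid[i][j], grid[i][j + 1], grid[i + 1][j], grid[i + 1][j + 1])
         for j in range(0, w - 1, 2)]
        for i in range(0, h - 1, 2)
    ]


if __name__ == "__main__":
    grid = [[r * 4 + c for c in range(4)] for r in range(4)]
    print(max_pool_2x2(grid))  # 16 positions become 4
```

For a 5 x 5 map the result is 2 x 2 = 4 positions, whereas 5 * 5 // 4 = 6, which is why forward expects even spatial sizes.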

◆ gamma

gamma

Learnable scalar gain applied to the attention output; initialized to 0.0.

Definition at line 776 of file ieagan.py.
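Because gamma starts at 0.0, the last line of forward, gamma * o + x, initially returns x unchanged, and the attention path only contributes as gamma is learned. The sketch below shows that gated residual on flattened values; gated_residual is a hypothetical helper using plain lists in place of tensors.

```python
def gated_residual(gamma, o, x):
    """Elementwise gamma * o + x, as in the return statement of forward."""
    return [gamma * oi + xi for oi, xi in zip(o, x)]


if __name__ == "__main__":
    x = [1.0, 2.0]
    o = [5.0, -2.0]
    print(gated_residual(0.0, o, x))  # at initialization the input passes through
    print(gated_residual(0.5, o, x))
```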

◆ o

o

1x1 convolution mapping the attended features from ch // 2 back to ch channels.

Definition at line 772 of file ieagan.py.

◆ phi

phi

1x1 convolution from ch to ch // 8 channels (key path); its output is 2x2 max-pooled in forward.

Definition at line 764 of file ieagan.py.

◆ theta

theta

1x1 convolution from ch to ch // 8 channels (query path), applied at full resolution.

Definition at line 760 of file ieagan.py.
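theta and phi together produce the attention map beta = softmax(theta^T phi) over the last dimension, which then weights the g path via bmm(g, beta^T). The pure-Python sketch below spells out those two steps for one sample, with each tensor stored as a list of channels and each channel as a list of positions. attention_weights and attend are hypothetical names; the max subtraction is a standard trick for a numerically stable softmax and does not change the result.

```python
import math


def attention_weights(theta, phi):
    """beta[n][m] = softmax over m of sum_c theta[c][n] * phi[c][m]."""
    n_q, n_k = len(theta[0]), len(phi[0])
    beta = []
    for n in range(n_q):
        scores = [sum(t[n] * p[m] for t, p in zip(theta, phi)) for m in range(n_k)]
        top = max(scores)  # subtract the max before exponentiating
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        beta.append([e / total for e in exps])
    return beta


def attend(g, beta):
    """out[c][n] = sum_m g[c][m] * beta[n][m], i.e. bmm(g, beta^T) for one sample."""
    return [[sum(gc[m] * bn[m] for m in range(len(gc))) for bn in beta] for gc in g]


if __name__ == "__main__":
    theta = [[1.0, 0.0]]       # 1 channel, 2 query positions
    phi = [[0.0, 0.0, 0.0]]    # 1 channel, 3 key positions: equal scores
    g = [[3.0, 6.0, 9.0]]      # 1 value channel over the same 3 positions
    beta = attention_weights(theta, phi)
    print(beta)                # uniform weights of 1/3 per key
    print(attend(g, beta))     # each query receives the mean of g
```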

◆ which_conv

which_conv

Convolution layer class (or factory) used to build theta, phi, g and o; defaults to SNConv2d.

Definition at line 758 of file ieagan.py.
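which_conv only has to be a callable accepting (in_channels, out_channels, kernel_size=..., padding=..., bias=...). One common way to pass a preconfigured layer is functools.partial; since keyword arguments given at call time override those bound in the partial, the kernel_size=1, padding=0 and bias=False supplied by the constructor take precedence, while other bound options are kept. fake_conv below is a hypothetical stand-in that returns its configuration instead of building a layer; whether ieagan.py passes a partial is not stated here.

```python
import functools


def fake_conv(in_channels, out_channels, kernel_size=3, padding=1, bias=True, eps=1e-12):
    """Hypothetical conv factory: returns its configuration instead of a layer."""
    return {"in": in_channels, "out": out_channels, "kernel_size": kernel_size,
            "padding": padding, "bias": bias, "eps": eps}


# Pre-bind defaults, as a model might do for its 3x3 convolutions.
which_conv = functools.partial(fake_conv, kernel_size=3, padding=1, eps=1e-6)

# The call Attention.__init__ makes for self.theta when ch = 64:
theta_cfg = which_conv(64, 64 // 8, kernel_size=1, padding=0, bias=False)
print(theta_cfg)
# {'in': 64, 'out': 8, 'kernel_size': 1, 'padding': 0, 'bias': False, 'eps': 1e-06}
```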


The documentation for this class was generated from the following file: