Belle II Software development
Attention Class Reference

Public Member Functions

 __init__ (self, ch, which_conv=SNConv2d)
 Constructor.
 
 forward (self, x)
 Applies attention over the spatial positions of x and returns gamma * o + x.
 

Public Attributes

 ch = ch
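 Number of input channels (also the number of output channels).
 
 which_conv = which_conv
 Convolution layer constructor used to build theta, phi, g, and o.
 
 theta
 1x1 convolution for the query path (ch to ch // 8 channels).
 
 phi
 1x1 convolution for the key path (ch to ch // 8 channels).
 
 g
 1x1 convolution for the value path (ch to ch // 2 channels).
 
 o
 1x1 output convolution (ch // 2 back to ch channels).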
 
 gamma = P(torch.tensor(0.0), requires_grad=True)
 Learnable gain parameter.
 

Detailed Description

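Self-attention block over the spatial positions of a 2D feature map. The query (theta), key (phi), and value (g) paths are 1x1 convolutions; the key and value paths are additionally 2x2 max-pooled. The attended features pass through the output convolution o and are added back to the input, scaled by the learnable gain gamma.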

Definition at line 777 of file ieagan.py.
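The computation in forward() can be summarised in matrix form. This is a reading of the listed source rather than an official derivation; the symbols C, H, W, N, X, W_theta, W_phi, W_g, W_o and pool are notation introduced here for illustration, with pool standing for the 2x2 max-pool applied before the spatial dimensions are flattened.

```latex
\begin{aligned}
X &\in \mathbb{R}^{C \times N}, \qquad N = H W, \qquad C = \texttt{ch} \\
\theta &= W_\theta X \in \mathbb{R}^{(C/8) \times N} \\
\phi &= \operatorname{pool}(W_\phi X) \in \mathbb{R}^{(C/8) \times (N/4)} \\
g &= \operatorname{pool}(W_g X) \in \mathbb{R}^{(C/2) \times (N/4)} \\
\beta &= \operatorname{softmax}\!\left(\theta^{\top} \phi\right) \in \mathbb{R}^{N \times (N/4)}
  \quad \text{(softmax over the last dimension)} \\
o &= W_o \left( g\, \beta^{\top} \right) \in \mathbb{R}^{C \times N} \\
y &= \gamma\, o + X
\end{aligned}
```

Each row of beta holds the attention weights of one full-resolution position over the N/4 pooled positions, so every row sums to one.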

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
ch,
which_conv = SNConv2d )

Constructor.

Definition at line 781 of file ieagan.py.

    def __init__(self, ch, which_conv=SNConv2d):
        super(Attention, self).__init__()

        self.ch = ch

        self.which_conv = which_conv

        self.theta = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )

        self.phi = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )

        self.g = self.which_conv(
            self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
        )

        self.o = self.which_conv(
            self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
        )

        self.gamma = P(torch.tensor(0.0), requires_grad=True)
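As a rough illustration of the layer sizes the constructor sets up, the sketch below counts the weights of each layer for a given ch. It assumes that which_conv creates a bias-free 1x1 convolution whose weight has shape (out_channels, in_channels, 1, 1), as a standard 2D convolution does; any extra state kept by SNConv2d is not counted. The helper name attention_weight_counts is hypothetical and not part of ieagan.py.

```python
def attention_weight_counts(ch):
    """Count the weights of each layer created by the constructor.

    Assumption: every layer is a bias-free 1x1 convolution with a weight of
    shape (out_channels, in_channels, 1, 1).
    """
    layers = {
        "theta": (ch, ch // 8),  # (in_channels, out_channels)
        "phi": (ch, ch // 8),
        "g": (ch, ch // 2),
        "o": (ch // 2, ch),
    }
    counts = {name: c_in * c_out for name, (c_in, c_out) in layers.items()}
    counts["gamma"] = 1  # scalar gain
    return counts


counts = attention_weight_counts(64)
total = sum(counts.values())
print(counts, total)
```

Because of the integer divisions, ch presumably needs to be at least 8; for smaller values ch // 8 is 0 and the theta and phi paths would have no channels.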

Member Function Documentation

◆ forward()

forward ( self,
x )

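Forward pass. Computes attention over the spatial positions of x (shape: batch x ch x height x width) and returns gamma * o + x, which has the same shape as x.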

Definition at line 807 of file ieagan.py.

    def forward(self, x):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(
            torch.bmm(g, beta.transpose(1, 2)).view(
                -1, self.ch // 2, x.shape[2], x.shape[3]
            )
        )
        return self.gamma * o + x
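The reshapes in forward() are easier to follow with concrete sizes. The sketch below traces the tensor shapes step by step in plain Python, without running any convolutions. It assumes the input height and width are even, so that 2x2 max-pooling leaves exactly height * width // 4 positions, which is what the view calls expect. The function name attention_shapes and the dictionary keys are hypothetical labels chosen for this illustration.

```python
def attention_shapes(n, ch, h, w):
    """Return the tensor shape at each step of forward() for an input (n, ch, h, w).

    Assumption: h and w are even, so the 2x2 max-pool output has h * w // 4
    spatial positions, matching the view() calls in forward().
    """
    if h % 2 or w % 2:
        raise ValueError("this sketch assumes even height and width")
    hw = h * w
    return {
        "theta": (n, ch // 8, hw),           # query path, full resolution
        "phi": (n, ch // 8, hw // 4),        # key path, pooled
        "g": (n, ch // 2, hw // 4),          # value path, pooled
        "beta": (n, hw, hw // 4),            # softmax(theta^T phi) over last dim
        "g_beta_T": (n, ch // 2, hw),        # bmm(g, beta^T)
        "o_input": (n, ch // 2, h, w),       # reshaped before the output conv
        "output": (n, ch, h, w),             # gamma * o + x, same shape as x
    }


shapes = attention_shapes(4, 64, 16, 16)
for name, shape in shapes.items():
    print(name, shape)
```

With odd spatial sizes the pooled feature map would not have height * width // 4 positions, so the reshapes in the listed source would likely fail or regroup elements; the sketch rejects such inputs instead of guessing.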

Member Data Documentation

◆ ch

ch = ch

Number of input channels (also the number of output channels).

Definition at line 784 of file ieagan.py.

◆ g

g
Initial value:
= self.which_conv(
    self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
)

1x1 convolution for the value path; maps ch to ch // 2 channels.

Definition at line 796 of file ieagan.py.

◆ gamma

gamma = P(torch.tensor(0.0), requires_grad=True)

Learnable gain parameter.

Definition at line 804 of file ieagan.py.
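Since gamma starts at 0.0 and forward() returns gamma * o + x, the block initially passes its input through unchanged, and the attention branch only contributes once training moves gamma away from zero. The sketch below shows this gating on plain Python lists; the function gated_residual and the numeric values are made up for illustration and stand in for the tensor arithmetic.

```python
def gated_residual(x, o, gamma):
    """Compute gamma * o + x element-wise on flat lists of floats."""
    return [gamma * o_i + x_i for x_i, o_i in zip(x, o)]


x = [1.0, -2.0, 0.5]     # stand-in for the block input
o = [10.0, 20.0, 30.0]   # stand-in for the attention branch output (made-up values)

y_init = gated_residual(x, o, 0.0)   # gamma at its initial value
y_later = gated_residual(x, o, 0.5)  # gamma after some hypothetical training

print(y_init)
print(y_later)
```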

◆ o

o
Initial value:
= self.which_conv(
    self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
)

1x1 output convolution; maps the attended features from ch // 2 back to ch channels.

Definition at line 800 of file ieagan.py.

◆ phi

phi
Initial value:
= self.which_conv(
    self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
)

1x1 convolution for the key path; maps ch to ch // 8 channels.

Definition at line 792 of file ieagan.py.

◆ theta

theta
Initial value:
= self.which_conv(
    self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
)

1x1 convolution for the query path; maps ch to ch // 8 channels.

Definition at line 788 of file ieagan.py.

◆ which_conv

which_conv = which_conv

Convolution layer constructor used to build theta, phi, g, and o (default: SNConv2d).

Definition at line 786 of file ieagan.py.


The documentation for this class was generated from the following file: