Belle II Software development
CBAM_attention Class Reference

Public Member Functions

def __init__ (self, channels, which_conv=SNConv2d, reduction=8, attention_kernel_size=3)
 Constructor.
 
def forward (self, x)
 Forward pass: applies channel attention, then spatial attention, to x.
 

Public Attributes

 avg_pool
 average pooling
 
 max_pool
 max pooling
 
 fc1
 first 1x1 convolution of the shared channel-attention bottleneck (channels to channels // reduction)
 
 relu
 ReLU activation between fc1 and fc2
 
 fc2
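 second 1x1 convolution of the shared channel-attention bottleneck (channels // reduction back to channels)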
 
 sigmoid_channel
 sigmoid producing the per-channel attention weights
 
 conv_after_concat
 convolution applied to the concatenated channel-wise mean and max maps (2 channels in, 1 out)
 
 sigmoid_spatial
 sigmoid producing the per-pixel (spatial) attention map
 

Detailed Description

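CBAM (Convolutional Block Attention Module) attention block: channel attention followed by spatial attention, each applied to the feature map by element-wise multiplication.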

Definition at line 683 of file ieagan.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  channels,
  which_conv = SNConv2d,
  reduction = 8,
  attention_kernel_size = 3 
)

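Constructor.

Parameters
  channels               number of channels of the input feature map
  which_conv             convolution class used to build fc1, fc2 and conv_after_concat
  reduction              channel reduction factor of the fc1/fc2 bottleneck
  attention_kernel_size  kernel size of conv_after_concat (padded with attention_kernel_size // 2)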

Definition at line 687 of file ieagan.py.

693     ):
694         super(CBAM_attention, self).__init__()
695
696         self.avg_pool = nn.AdaptiveAvgPool2d(1)
697
698         self.max_pool = nn.AdaptiveMaxPool2d(1)
699
700         self.fc1 = which_conv(
701             channels, channels // reduction, kernel_size=1, padding=0
702         )
703
704         self.relu = nn.ReLU(inplace=True)
705
706         self.fc2 = which_conv(
707             channels // reduction, channels, kernel_size=1, padding=0
708         )
709
710         self.sigmoid_channel = nn.Sigmoid()
711
712         self.conv_after_concat = which_conv(
713             2,
714             1,
715             kernel_size=attention_kernel_size,
716             stride=1,
717             padding=attention_kernel_size // 2,
718         )
719
720         self.sigmoid_spatial = nn.Sigmoid()
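The layer sizes built by the constructor follow from two integer divisions: channels // reduction for the bottleneck width and attention_kernel_size // 2 for the padding. The sketch below is an illustration only, not part of ieagan.py; the helper names conv_out_size and cbam_layer_shapes are hypothetical, and the output-size formula assumes a standard convolution with dilation 1.

```python
def conv_out_size(size, kernel_size, padding, stride=1):
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def cbam_layer_shapes(channels, reduction=8, attention_kernel_size=3):
    """(in, out) channel counts and padding implied by the constructor arguments."""
    return {
        "fc1": (channels, channels // reduction),
        "fc2": (channels // reduction, channels),
        "conv_after_concat": (2, 1),
        "padding": attention_kernel_size // 2,
    }


shapes = cbam_layer_shapes(channels=64)
print(shapes["fc1"], shapes["fc2"], shapes["padding"])

# Odd kernel sizes keep the spatial size, even ones grow it by one pixel.
print(conv_out_size(32, 3, 3 // 2), conv_out_size(32, 4, 4 // 2))
```

Two consequences are worth keeping in mind. With channels smaller than reduction the bottleneck width becomes 0, and with an even attention_kernel_size the spatial gate would be one pixel larger than the input along each axis, so the final multiplication in forward would not broadcast.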

Member Function Documentation

◆ forward()

def forward (   self,
  x 
)

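Forward pass. Rescales x by a per-channel gate (the shared fc1, ReLU, fc2 stack applied to the average-pooled and max-pooled features, summed, then passed through a sigmoid), then by a per-pixel gate (sigmoid of conv_after_concat applied to the channel-wise mean and max maps). With the default arguments the output has the same shape as x.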

Definition at line 723 of file ieagan.py.

723     def forward(self, x):
724         # Channel attention module
725         module_input = x
726         avg = self.avg_pool(x)
727         mx = self.max_pool(x)
728         avg = self.fc1(avg)
729         mx = self.fc1(mx)
730         avg = self.relu(avg)
731         mx = self.relu(mx)
732         avg = self.fc2(avg)
733         mx = self.fc2(mx)
734         x = avg + mx
735         x = self.sigmoid_channel(x)
736         # Spatial attention module
737         x = module_input * x
738         module_input = x
739         # b, c, h, w = x.size()
740         avg = torch.mean(x, 1, True)
741         mx, _ = torch.max(x, 1, True)
742         x = torch.cat((avg, mx), 1)
743         x = self.conv_after_concat(x)
744         x = self.sigmoid_spatial(x)
745         x = module_input * x
746         return x
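To make the two gating steps of forward concrete, here is a dependency-free sketch of the same arithmetic on nested Python lists for a single sample. It is an illustration under stated assumptions, not the Belle II implementation: the function names are hypothetical, the 1x1 convolutions are written as plain weight matrices without bias, and spectral normalisation (SNConv2d) is ignored.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def shared_mlp(vec, w1, w2):
    """fc1 -> ReLU -> fc2 on a per-channel vector (1x1 convs as matrices, no bias)."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, vec))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def channel_attention(x, w1, w2):
    """x is a list of C feature maps, each a list of H rows of W floats."""
    flat = [[v for row in ch for v in row] for ch in x]
    avg = shared_mlp([sum(f) / len(f) for f in flat], w1, w2)  # global average pool
    mx = shared_mlp([max(f) for f in flat], w1, w2)            # global max pool
    gate = [sigmoid(a + m) for a, m in zip(avg, mx)]
    return [[[g * v for v in row] for row in ch] for g, ch in zip(gate, x)]


def spatial_attention(x, kernel, bias=0.0):
    """kernel[0] filters the channel-wise mean map, kernel[1] the max map (k x k each)."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    k = len(kernel[0])
    pad = k // 2
    maps = [
        [[sum(x[ch][i][j] for ch in range(c)) / c for j in range(w)] for i in range(h)],
        [[max(x[ch][i][j] for ch in range(c)) for j in range(w)] for i in range(h)],
    ]
    gate = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = bias
            for m in range(2):
                for di in range(k):
                    for dj in range(k):
                        ii, jj = i + di - pad, j + dj - pad
                        if 0 <= ii < h and 0 <= jj < w:  # zero padding outside the map
                            acc += kernel[m][di][dj] * maps[m][ii][jj]
            gate[i][j] = sigmoid(acc)
    return [[[gate[i][j] * x[ch][i][j] for j in range(w)] for i in range(h)]
            for ch in range(c)]


x = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, -1.0], [2.0, 1.0]]]  # C=2, H=W=2
w1 = [[0.5, -0.5]]    # fc1: 2 -> 1 channel
w2 = [[1.0], [-1.0]]  # fc2: 1 -> 2 channels
kernel = [[[0.0] * 3 for _ in range(3)] for _ in range(2)]  # all-zero 3x3 filters
out = spatial_attention(channel_attention(x, w1, w2), kernel)
print(out)
```

With the all-zero spatial kernel chosen here the spatial gate is 0.5 at every pixel, so any difference between the two output channels comes from the channel gate alone. The output keeps the C x H x W shape of the input.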

Member Data Documentation

◆ avg_pool

avg_pool

average pooling

Definition at line 696 of file ieagan.py.

◆ conv_after_concat

conv_after_concat

convolution applied to the concatenated channel-wise mean and max maps (2 channels in, 1 out)

Definition at line 712 of file ieagan.py.

◆ fc1

fc1

first 1x1 convolution of the shared channel-attention bottleneck (channels to channels // reduction)

Definition at line 700 of file ieagan.py.

◆ fc2

fc2

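second 1x1 convolution of the shared channel-attention bottleneck (channels // reduction back to channels)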

Definition at line 706 of file ieagan.py.

◆ max_pool

max_pool

max pooling

Definition at line 698 of file ieagan.py.

◆ relu

relu

ReLU activation between fc1 and fc2

Definition at line 704 of file ieagan.py.

◆ sigmoid_channel

sigmoid_channel

sigmoid producing the per-channel attention weights

Definition at line 710 of file ieagan.py.

◆ sigmoid_spatial

sigmoid_spatial

sigmoid producing the per-pixel (spatial) attention map

Definition at line 720 of file ieagan.py.


The documentation for this class was generated from the following file: