Belle II Software development
CBAM_attention Class Reference

Public Member Functions

 __init__ (self, channels, which_conv=SNConv2d, reduction=8, attention_kernel_size=3)
 Constructor.
 
 forward (self, x)
 Apply channel attention, then spatial attention, to the input feature map.
 

Public Attributes

 avg_pool = nn.AdaptiveAvgPool2d(1)
 global average pooling over the spatial dimensions (output size 1x1)
 
 max_pool = nn.AdaptiveMaxPool2d(1)
 global max pooling over the spatial dimensions (output size 1x1)
 
 fc1
 first 1x1 convolution of the channel-attention MLP (channels to channels // reduction)
 
 relu = nn.ReLU(inplace=True)
 ReLU activation between fc1 and fc2
 
 fc2
 second 1x1 convolution of the channel-attention MLP (channels // reduction back to channels)
 
 sigmoid_channel = nn.Sigmoid()
 sigmoid that turns the summed MLP outputs into channel-attention weights
 
 conv_after_concat
 convolution that maps the concatenated channel-wise mean and max maps (2 channels) to a single spatial-attention map
 
 sigmoid_spatial = nn.Sigmoid()
 sigmoid that turns the spatial-attention logits into weights
 

Detailed Description

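CBAM (Convolutional Block Attention Module) attention: a channel-attention step followed by a spatial-attention step, each of which rescales the input feature map by learned weights in (0, 1).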

Definition at line 711 of file ieagan.py.
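Read off the forward implementation, the module computes the following. The notation is ours, not the source's: σ is the sigmoid, ⊙ is element-wise multiplication with broadcasting, f1 and f2 are fc1 and fc2, g is conv_after_concat, and [·;·] is concatenation along the channel dimension. For an input F of shape B×C×H×W, the channel weights M_c have shape B×C×1×1 and the spatial weights M_s have shape B×1×H×W.

```latex
\begin{aligned}
M_c(F)  &= \sigma\Big(f_2\big(\mathrm{ReLU}(f_1(\mathrm{AvgPool}(F)))\big) + f_2\big(\mathrm{ReLU}(f_1(\mathrm{MaxPool}(F)))\big)\Big) \\
F'      &= M_c(F) \odot F \\
M_s(F') &= \sigma\big(g([\,\mathrm{mean}_c(F');\ \max_c(F')\,])\big) \\
F''     &= M_s(F') \odot F'
\end{aligned}
```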

Constructor & Destructor Documentation

◆ __init__()

__init__ ( self,
channels,
which_conv = SNConv2d,
reduction = 8,
attention_kernel_size = 3 )

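Constructor.

Parameters
channels	number of input (and output) channels
which_conv	convolution class used to build fc1, fc2 and conv_after_concat (default SNConv2d)
reduction	channel reduction ratio of the channel-attention MLP (hidden width channels // reduction)
attention_kernel_size	kernel size of conv_after_concat (padded by attention_kernel_size // 2)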

Definition at line 715 of file ieagan.py.

715     def __init__(
716         self,
717         channels,
718         which_conv=SNConv2d,
719         reduction=8,
720         attention_kernel_size=3,
721     ):
722         super(CBAM_attention, self).__init__()
723
724         self.avg_pool = nn.AdaptiveAvgPool2d(1)
725
726         self.max_pool = nn.AdaptiveMaxPool2d(1)
727
728         self.fc1 = which_conv(
729             channels, channels // reduction, kernel_size=1, padding=0
730         )
731
732         self.relu = nn.ReLU(inplace=True)
733
734         self.fc2 = which_conv(
735             channels // reduction, channels, kernel_size=1, padding=0
736         )
737
738         self.sigmoid_channel = nn.Sigmoid()
739
740         self.conv_after_concat = which_conv(
741             2,
742             1,
743             kernel_size=attention_kernel_size,
744             stride=1,
745             padding=attention_kernel_size // 2,
746         )
747
748         self.sigmoid_spatial = nn.Sigmoid()
749
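The constructor builds three convolutions whose sizes depend only on channels, reduction and attention_kernel_size. The sketch below lists their weight shapes, assuming which_conv follows the torch.nn.Conv2d weight layout (out_channels, in_channels, kH, kW), as the SNConv2d name suggests. The helper cbam_weight_shapes is hypothetical and not part of ieagan.py.

```python
# Hypothetical helper (not in ieagan.py): weight shapes of the convolutions
# created by CBAM_attention.__init__, assuming the Conv2d weight layout
# (out_channels, in_channels, kH, kW).
def cbam_weight_shapes(channels, reduction=8, attention_kernel_size=3):
    hidden = channels // reduction  # bottleneck width of the channel MLP
    if hidden == 0:
        # guard: a zero-width bottleneck would leave the MLP degenerate
        raise ValueError("channels // reduction must be at least 1")
    k = attention_kernel_size
    return {
        "fc1": (hidden, channels, 1, 1),    # channels -> hidden, 1x1
        "fc2": (channels, hidden, 1, 1),    # hidden -> channels, 1x1
        "conv_after_concat": (1, 2, k, k),  # [mean; max] -> one map, k x k
    }


print(cbam_weight_shapes(64))
# {'fc1': (8, 64, 1, 1), 'fc2': (64, 8, 1, 1), 'conv_after_concat': (1, 2, 3, 3)}
```

Because the hidden width uses floor division, channels should be at least reduction (8 by default) for the bottleneck to be non-empty.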

Member Function Documentation

◆ forward()

forward ( self,
x )

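Apply channel attention, then spatial attention, to the input feature map x and return a tensor of the same shape.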

Definition at line 751 of file ieagan.py.

751     def forward(self, x):
752         # Channel attention module
753         module_input = x
754         avg = self.avg_pool(x)
755         mx = self.max_pool(x)
756         avg = self.fc1(avg)
757         mx = self.fc1(mx)
758         avg = self.relu(avg)
759         mx = self.relu(mx)
760         avg = self.fc2(avg)
761         mx = self.fc2(mx)
762         x = avg + mx
763         x = self.sigmoid_channel(x)
764         # Spatial attention module
765         x = module_input * x
766         module_input = x
767         # b, c, h, w = x.size()
768         avg = torch.mean(x, 1, True)
769         mx, _ = torch.max(x, 1, True)
770         x = torch.cat((avg, mx), 1)
771         x = self.conv_after_concat(x)
772         x = self.sigmoid_spatial(x)
773         x = module_input * x
774         return x
775
776
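To make the data flow of forward concrete without PyTorch, here is a minimal NumPy sketch for a single sample (no batch dimension). It is an illustration under stated assumptions, not the module itself: the helper names sigmoid and cbam_forward are hypothetical, the fc1/fc2 biases and spectral normalization are omitted, and k is assumed odd.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cbam_forward(x, w1, w2, w_sp, b_sp=0.0):
    """NumPy sketch of CBAM_attention.forward for a single sample.

    x:    (C, H, W) feature map
    w1:   (C // r, C) fc1 weights (a 1x1 conv acts as a matrix over channels)
    w2:   (C, C // r) fc2 weights
    w_sp: (2, k, k) conv_after_concat weights, k odd
    Assumption: fc1/fc2 biases and spectral normalization are omitted.
    """
    # Channel attention: global avg/max pooling -> shared MLP -> sum -> sigmoid
    def mlp(v):
        return w2 @ np.maximum(w1 @ v, 0.0)  # fc2(ReLU(fc1(v)))

    avg = x.mean(axis=(1, 2))                 # (C,)
    mx = x.max(axis=(1, 2))                   # (C,)
    m_c = sigmoid(mlp(avg) + mlp(mx))         # (C,) channel weights
    x = x * m_c[:, None, None]                # rescale each channel

    # Spatial attention: channel-wise mean/max -> k x k conv -> sigmoid
    pooled = np.stack([x.mean(axis=0), x.max(axis=0)])  # (2, H, W)
    k = w_sp.shape[-1]
    p = k // 2
    padded = np.pad(pooled, ((0, 0), (p, p), (p, p)))   # zero padding
    h, w = x.shape[1:]
    logits = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            # cross-correlation, as computed by PyTorch's Conv2d
            logits[i, j] = np.sum(padded[:, i:i + k, j:j + k] * w_sp) + b_sp
    m_s = sigmoid(logits)                     # (H, W) spatial weights
    return x * m_s[None, :, :]                # rescale each position


rng = np.random.default_rng(0)
C, r, H, W = 16, 8, 5, 7
x = rng.standard_normal((C, H, W))
y = cbam_forward(x, rng.standard_normal((C // r, C)),
                 rng.standard_normal((C, C // r)),
                 rng.standard_normal((2, 3, 3)))
print(y.shape)  # (16, 5, 7)
```

Both attention maps pass through a sigmoid, so the output never exceeds the input in magnitude; with all-zero weights each map is 0.5 everywhere and the output is x / 4.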

Member Data Documentation

◆ avg_pool

avg_pool = nn.AdaptiveAvgPool2d(1)

global average pooling over the spatial dimensions (output size 1x1)

Definition at line 724 of file ieagan.py.

◆ conv_after_concat

conv_after_concat
Initial value:
= which_conv(
2,
1,
kernel_size=attention_kernel_size,
stride=1,
padding=attention_kernel_size // 2,
)

convolution that maps the concatenated channel-wise mean and max maps (2 channels) to a single spatial-attention map

Definition at line 740 of file ieagan.py.
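conv_after_concat uses stride=1 and padding=attention_kernel_size // 2. By the standard convolution output-size formula (dilation 1), this keeps H and W unchanged for odd kernel sizes, so the attention map lines up with the feature map it multiplies at the end of forward. An even kernel size would add one row and one column, and that final multiplication would generally fail to broadcast. The helper conv_out_size is hypothetical.

```python
# Output size of a convolution along one spatial axis (dilation 1):
#   out = (size + 2 * padding - kernel_size) // stride + 1
def conv_out_size(size, kernel_size, padding, stride=1):
    return (size + 2 * padding - kernel_size) // stride + 1


for k in (1, 3, 5, 7):
    # odd kernel with padding k // 2: spatial size is preserved
    assert conv_out_size(32, k, k // 2) == 32

print(conv_out_size(32, 4, 4 // 2))  # even kernel: 33, one larger than the input
```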

◆ fc1

fc1
Initial value:
= which_conv(
channels, channels // reduction, kernel_size=1, padding=0
)

first 1x1 convolution of the channel-attention MLP (channels to channels // reduction)

Definition at line 728 of file ieagan.py.

◆ fc2

fc2
Initial value:
= which_conv(
channels // reduction, channels, kernel_size=1, padding=0
)

second 1x1 convolution of the channel-attention MLP (channels // reduction back to channels)

Definition at line 734 of file ieagan.py.
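fc1 and fc2 are 1x1 convolutions applied to the C×1×1 output of global pooling, so together they act as a small fully connected network (C to C // reduction to C), which is presumably why they are named fc. The NumPy sketch below checks this equivalence for one sample; the helper conv1x1 is hypothetical, and spectral normalization is ignored.

```python
import numpy as np


def conv1x1(x, weight, bias=None):
    """Naive 1x1 convolution. x: (C_in, H, W), weight: (C_out, C_in, 1, 1)."""
    c_out = weight.shape[0]
    out = np.zeros((c_out,) + x.shape[1:])
    for o in range(c_out):
        for i in range(x.shape[0]):
            out[o] += weight[o, i, 0, 0] * x[i]  # weighted sum over input channels
        if bias is not None:
            out[o] += bias[o]
    return out


rng = np.random.default_rng(0)
C, r = 16, 8
pooled = rng.standard_normal((C, 1, 1))         # like the output of avg_pool / max_pool
w_fc1 = rng.standard_normal((C // r, C, 1, 1))  # fc1 weight shape
via_conv = conv1x1(pooled, w_fc1)
via_matmul = (w_fc1[:, :, 0, 0] @ pooled[:, 0, 0])[:, None, None]
print(np.allclose(via_conv, via_matmul))  # True
```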

◆ max_pool

max_pool = nn.AdaptiveMaxPool2d(1)

global max pooling over the spatial dimensions (output size 1x1)

Definition at line 726 of file ieagan.py.

◆ relu

relu = nn.ReLU(inplace=True)

ReLU activation between fc1 and fc2

Definition at line 732 of file ieagan.py.

◆ sigmoid_channel

sigmoid_channel = nn.Sigmoid()

sigmoid that turns the summed MLP outputs into channel-attention weights

Definition at line 738 of file ieagan.py.

◆ sigmoid_spatial

sigmoid_spatial = nn.Sigmoid()

sigmoid that turns the spatial-attention logits into weights

Definition at line 748 of file ieagan.py.


The documentation for this class was generated from the following file: