Belle II Software development
ResidualBlock Class Reference

Class for the residual block layer.


Public Member Functions

def __init__ (self, ninput, noutput, upsample=True)
 Constructor to create a new residual block layer.
 
def forward (self, x)
 Function to perform a forward pass.
 

Public Attributes

 upsample
 Whether to double the height and width of the input.
 
 conv
 Convolutional layer in the shortcut branch.
 
 norm1
 First batch normalization layer in the residual branch.
 
 conv1
 First convolutional layer in the residual branch.
 
 norm2
 Second batch normalization layer in the residual branch.
 
 conv2
 Second convolutional layer in the residual branch.
 

Detailed Description

Class for the residual block layer.

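Residual block layer. The output is the element-wise sum of two branches computed from the same input: a residual branch (batch normalization, ReLU, optional 2× nearest-neighbour upsampling, 3×3 convolution, batch normalization, ReLU, 3×3 convolution) and a shortcut branch (optional 2× upsampling followed, when needed, by a 1×1 convolution that matches the channel count).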

Definition at line 19 of file resnet.py.
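To try the block in isolation, it can be re-declared in a self-contained script. This is a sketch under assumptions: the base class is taken to be torch.nn.Module (the inheritance diagram is not reproduced here), `nn` and `F` are assumed to be torch.nn and torch.nn.functional, and the in-place `relu_()` calls of the original are written as out-of-place `F.relu` for readability.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Sketch of the residual block; the nn.Module base class is an assumption."""

    def __init__(self, ninput, noutput, upsample=True):
        super().__init__()
        self.upsample = upsample
        # shortcut branch: 1x1 projection only if the shape changes
        self.conv = None
        if upsample or (ninput != noutput):
            self.conv = nn.Conv2d(ninput, noutput, 1, 1, 0)
        # residual branch
        self.norm1 = nn.BatchNorm2d(ninput)
        self.conv1 = nn.Conv2d(ninput, noutput, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(noutput)
        self.conv2 = nn.Conv2d(noutput, noutput, 3, 1, 1)

    def forward(self, x):
        # residual branch: BN -> ReLU -> [2x upsample] -> conv -> BN -> ReLU -> conv
        h = F.relu(self.norm1(x))
        if self.upsample:
            h = F.interpolate(h, mode="nearest", scale_factor=2)
        h = self.conv2(F.relu(self.norm2(self.conv1(h))))
        # shortcut branch: [2x upsample] -> [1x1 conv]
        if self.upsample:
            x = F.interpolate(x, mode="nearest", scale_factor=2)
        if self.conv is not None:
            x = self.conv(x)
        return h + x


torch.manual_seed(0)
up = ResidualBlock(8, 16)  # upsample=True: channels 8 -> 16, size doubled
print(up(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 16, 8, 8])

same = ResidualBlock(8, 8, upsample=False)  # identity shortcut, no 1x1 conv
print(same.conv)  # None
print(same(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 4, 4])
```

With upsample=False and equal channel counts the shortcut is the identity, so the block computes x plus a learned correction at unchanged resolution.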

Constructor & Destructor Documentation

◆ __init__()

def __init__ (self, ninput, noutput, upsample=True)

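Constructor to create a new residual block layer.

Parameters
 ninput  number of input channels
 noutput  number of output channels
 upsample  if True, double the height and width of the input. A 1×1 convolution is added to the shortcut branch when upsample is True or ninput != noutput.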

Definition at line 24 of file resnet.py.

```python
24    def __init__(self, ninput, noutput, upsample=True):
25        super().__init__()
26        self.upsample = upsample
27        # shortcut branch
28        self.conv = None
29        if upsample or (ninput != noutput):
30            self.conv = nn.Conv2d(ninput, noutput, 1, 1, 0)
31        # residual branch
32        self.norm1 = nn.BatchNorm2d(ninput)
33        self.conv1 = nn.Conv2d(ninput, noutput, 3, 1, 1)
34        self.norm2 = nn.BatchNorm2d(noutput)
35        self.conv2 = nn.Conv2d(noutput, noutput, 3, 1, 1)
```
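The parameter budget of a block follows directly from the constructor. The sketch below counts trainable parameters, assuming the PyTorch defaults (bias=True for nn.Conv2d, affine=True for nn.BatchNorm2d, whose running mean and variance are buffers rather than parameters). The function name is hypothetical and not part of resnet.py.

```python
def residual_block_param_count(ninput, noutput, upsample=True):
    """Count trainable parameters of the layers built in ResidualBlock.__init__.

    Hypothetical helper; assumes default bias=True (Conv2d) and affine=True (BatchNorm2d).
    """

    def conv2d(cin, cout, k):
        # k x k weights per (input, output) channel pair plus one bias per output channel
        return cin * cout * k * k + cout

    def batchnorm2d(c):
        # one scale (weight) and one shift (bias) per channel
        return 2 * c

    # residual branch: norm1, conv1, norm2, conv2
    total = batchnorm2d(ninput) + conv2d(ninput, noutput, 3)
    total += batchnorm2d(noutput) + conv2d(noutput, noutput, 3)
    # shortcut branch: 1x1 projection, created under the same condition as in __init__
    if upsample or ninput != noutput:
        total += conv2d(ninput, noutput, 1)
    return total


print(residual_block_param_count(8, 16))                  # 3680
print(residual_block_param_count(8, 8, upsample=False))   # 1200
```

Almost all parameters sit in the two 3×3 convolutions; the optional 1×1 shortcut projection adds only ninput · noutput + noutput.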

Member Function Documentation

◆ forward()

def forward (self, x)

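Function to perform a forward pass.

Compute the layer output for a given input x of shape (N, ninput, H, W). The result has shape (N, noutput, 2H, 2W) if upsample is True and (N, noutput, H, W) otherwise.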

Definition at line 39 of file resnet.py.

```python
39    def forward(self, x):
40        """Compute the layer output for a given input."""
41        # residual branch
42        h = x
43        h = self.norm1(h)
44        h.relu_()
45        if self.upsample:
46            h = F.interpolate(h, mode="nearest", scale_factor=2)
47        h = self.conv1(h)
48        h = self.norm2(h)
49        h.relu_()
50        h = self.conv2(h)
51        # shortcut branch
52        if self.upsample:
53            x = F.interpolate(x, mode="nearest", scale_factor=2)
54        if self.conv:
55            x = self.conv(x)
56        # return sum of both
57        return h + x
```
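The shape bookkeeping in forward can be checked without PyTorch. The hypothetical helper below applies the standard Conv2d output-size formula (dilation 1) to both branches, which shows why the 3×3 convolutions with padding 1 and the 1×1 shortcut convolution keep the spatial size, so that h + x is well defined.

```python
def conv2d_out(size, kernel, stride, padding):
    # standard Conv2d output size along one spatial axis (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1


def residual_block_output_shape(shape, noutput, upsample=True):
    """Hypothetical helper: output shape of ResidualBlock for an (N, C, H, W) input."""
    n, _, h, w = shape
    if upsample:
        # nearest-neighbour interpolation with scale_factor=2 in both branches
        h, w = 2 * h, 2 * w
    # residual branch: two 3x3 convolutions, stride 1, padding 1
    rh, rw = h, w
    for _ in range(2):
        rh, rw = conv2d_out(rh, 3, 1, 1), conv2d_out(rw, 3, 1, 1)
    # shortcut branch: (optional) 1x1 convolution, stride 1, padding 0
    sh, sw = conv2d_out(h, 1, 1, 0), conv2d_out(w, 1, 1, 0)
    # the element-wise sum requires both branches to have the same size
    assert (rh, rw) == (sh, sw)
    return (n, noutput, rh, rw)


print(residual_block_output_shape((2, 8, 4, 4), 16))                  # (2, 16, 8, 8)
print(residual_block_output_shape((2, 8, 5, 7), 8, upsample=False))   # (2, 8, 5, 7)
```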

Member Data Documentation

◆ conv

conv

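Convolutional layer in the shortcut branch: a 1×1 convolution mapping ninput to noutput channels. It is None when upsample is False and ninput == noutput, in which case the shortcut is the identity.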

Definition at line 28 of file resnet.py.

◆ conv1

conv1

First convolutional layer in the residual branch (3×3 kernel, stride 1, padding 1, ninput to noutput channels).

Definition at line 33 of file resnet.py.

◆ conv2

conv2

Second convolutional layer in the residual branch (3×3 kernel, stride 1, padding 1, noutput to noutput channels).

Definition at line 35 of file resnet.py.

◆ norm1

norm1

First batch normalization layer in the residual branch, applied to the input (ninput channels) before the first ReLU.

Definition at line 32 of file resnet.py.

◆ norm2

norm2

Second batch normalization layer in the residual branch, applied to the output of conv1 (noutput channels) before the second ReLU.

Definition at line 34 of file resnet.py.

◆ upsample

upsample

Whether to double the height and width of the input, using nearest-neighbour interpolation in both the residual and the shortcut branch.

Definition at line 26 of file resnet.py.
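As a minimal pure-Python illustration of the upsampling step, assuming a single-channel image stored as nested lists: F.interpolate with mode="nearest" and scale_factor=2 repeats every pixel into a 2×2 block, applied independently to each sample and channel. The helper name is hypothetical.

```python
def upsample_nearest_2x(image):
    """Double height and width by repeating every pixel into a 2x2 block."""
    out = []
    for row in image:
        wide = [v for v in row for _ in range(2)]  # repeat each entry along the width
        out.append(wide)
        out.append(list(wide))  # repeat the widened row along the height
    return out


print(upsample_nearest_2x([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```

Nearest-neighbour interpolation has no learnable parameters; any smoothing of the duplicated pixels is left to the following 3×3 convolution.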


The documentation for this class was generated from the following file: