Belle II Software development
Model Class Reference

Class for the ResNet generator model.

Public Member Functions

 __init__ (self)
 Constructor to create a new model instance.
 
 forward (self, z)
 Function to perform a forward pass.
 

Public Attributes

 fc = nn.Linear(96, 49152)
 Fully-connected layer.
 
 blocks
 Sequence of residual block layers.
 
 norm = nn.BatchNorm2d(16)
 Batch normalization layer.
 
 conv = nn.Conv2d(16, 1, 3, 1, 1)
 Convolutional layer.
 

Detailed Description

Class for the ResNet generator model.

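The generator maps a 96-dimensional latent vector to a single-channel 256 × 768 image. A fully-connected layer produces a 256 × 8 × 24 feature map, five residual blocks upsample it to 16 × 256 × 768, and batch normalization, ReLU, a 3 × 3 convolution and a final tanh produce the output, with values in [-1, 1].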

Definition at line 86 of file resnet.py.
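
As a quick check on the overall input/output contract, the following stdlib-only sketch reproduces the shape arithmetic implied by the layer sizes on this page. It assumes, as the shape comments in the constructor listing indicate, that each of the five residual blocks doubles both spatial dimensions. The helper `generator_io_shapes` and the constant names are hypothetical and not part of resnet.py.

```python
# Shape arithmetic for the ResNet generator, derived from the layer sizes
# documented on this page. `generator_io_shapes` is a hypothetical helper,
# not part of resnet.py.

LATENT_DIM = 96             # nn.Linear(96, 49152) input size
SEED_SHAPE = (256, 8, 24)   # z.view(-1, 256, 8, 24)
NUM_BLOCKS = 5              # number of ResidualBlock entries in self.blocks
OUT_CHANNELS = 1            # nn.Conv2d(16, 1, 3, 1, 1) output channels


def generator_io_shapes(batch_size):
    """Return (input_shape, output_shape) for a batch of latent vectors."""
    _, h, w = SEED_SHAPE
    # Assumption read off the source comments: every residual block
    # doubles height and width, so five blocks upsample by 2**5 = 32.
    scale = 2 ** NUM_BLOCKS
    return (batch_size, LATENT_DIM), (batch_size, OUT_CHANNELS, h * scale, w * scale)


print(generator_io_shapes(4))  # ((4, 96), (4, 1, 256, 768))
```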

Constructor & Destructor Documentation

◆ __init__()

__init__(self)

Constructor to create a new model instance.

Definition at line 91 of file resnet.py.

def __init__(self):
    super().__init__()
    # fully-connected inputs
    self.fc = nn.Linear(96, 49152)
    # stack of residual blocks; the comments give the (channels, height, width)
    # of the tensor entering and leaving each block
    self.blocks = nn.ModuleList(
        [
            # (256, 8, 24)
            ResidualBlock(256, 256),
            # (256, 16, 48)
            ResidualBlock(256, 128),
            # (128, 32, 96)
            ResidualBlock(128, 64),
            # (64, 64, 192)
            ResidualBlock(64, 32),
            # (32, 128, 384)
            ResidualBlock(32, 16),
            # (16, 256, 768)
        ]
    )
    # norm and conv outputs
    self.norm = nn.BatchNorm2d(16)
    self.conv = nn.Conv2d(16, 1, 3, 1, 1)
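
The sketch below counts the trainable parameters of the three non-residual layers using the standard PyTorch shapes: weight plus bias for `nn.Linear` and `nn.Conv2d` (default `bias=True`), and the per-channel `weight`/`bias` pair of `nn.BatchNorm2d` (default `affine=True`; running statistics are buffers, not parameters). The residual blocks are left out because `ResidualBlock` is defined elsewhere in resnet.py and not shown here. The helper names are hypothetical.

```python
# Trainable-parameter counts for the fc, norm and conv layers of Model.
# Formulas follow PyTorch defaults (bias=True, affine=True); the helper
# names are hypothetical and not part of resnet.py.


def linear_params(in_features, out_features):
    # weight: (out_features, in_features), bias: (out_features,)
    return in_features * out_features + out_features


def conv2d_params(in_channels, out_channels, kernel_size):
    # weight: (out_channels, in_channels, k, k), bias: (out_channels,)
    return in_channels * out_channels * kernel_size ** 2 + out_channels


def batchnorm2d_params(num_features):
    # learnable weight (gamma) and bias (beta); running stats are buffers
    return 2 * num_features


counts = {
    "fc": linear_params(96, 49152),
    "norm": batchnorm2d_params(16),
    "conv": conv2d_params(16, 1, 3),
}
print(counts)                # {'fc': 4767744, 'norm': 32, 'conv': 145}
print(sum(counts.values()))  # 4767921
```

With PyTorch installed, the same numbers can be cross-checked on a model instance with `sum(p.numel() for p in model.fc.parameters())`, and likewise for `norm` and `conv`. The fully-connected layer dominates this count.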

Member Function Documentation

◆ forward()

forward(self, z)

Function to perform a forward pass.

Compute the model output for a given input.

Definition at line 117 of file resnet.py.

def forward(self, z):
    """Compute the model output for a given input."""
    z = self.fc(z)                # (N, 96) -> (N, 49152)
    z = z.view(-1, 256, 8, 24)    # -> (N, 256, 8, 24)
    for block in self.blocks:
        z = block(z)              # -> (N, 16, 256, 768) after the last block
    z = self.norm(z)
    z.relu_()                     # in-place ReLU
    z = self.conv(z)              # -> (N, 1, 256, 768)
    return z.tanh_()              # in-place tanh, values in [-1, 1]
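
`z.view(-1, 256, 8, 24)` reinterprets each 49152-element row produced by `fc` as a 256-channel 8 × 24 feature map without copying data. This only works because 256 × 8 × 24 = 49152, and on a contiguous tensor `view` uses row-major order, so the last dimension varies fastest. The stdlib sketch below shows which flat position ends up at channel `c`, row `h`, column `w`; `flat_index` and `unflatten` are hypothetical helpers.

```python
# Row-major index mapping used when z.view(-1, 256, 8, 24) reinterprets
# the flat fc output of one sample. `flat_index` and `unflatten` are
# hypothetical helpers, not part of resnet.py.

C, H, W = 256, 8, 24
assert C * H * W == 49152  # must match the output size of nn.Linear(96, 49152)


def flat_index(c, h, w):
    """Position in the flat 49152-vector of element [c, h, w]."""
    return (c * H + h) * W + w


def unflatten(i):
    """Inverse mapping: flat position -> (c, h, w)."""
    c, rest = divmod(i, H * W)
    h, w = divmod(rest, W)
    return c, h, w


print(flat_index(0, 0, 1))  # 1 (w varies fastest)
print(flat_index(1, 0, 0))  # 192, i.e. one full 8 x 24 channel
print(unflatten(49151))     # (255, 7, 23)
```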

Member Data Documentation

◆ blocks

blocks
Initial value:
= nn.ModuleList(
    [
        # (256, 8, 24)
        ResidualBlock(256, 256),
        # (256, 16, 48)
        ResidualBlock(256, 128),
        # (128, 32, 96)
        ResidualBlock(128, 64),
        # (64, 64, 192)
        ResidualBlock(64, 32),
        # (32, 128, 384)
        ResidualBlock(32, 16),
        # (16, 256, 768)
    ]
)

Sequence of residual block layers.

Definition at line 96 of file resnet.py.
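
The shape comments in the initial value describe the (channels, height, width) of the tensor between consecutive blocks. The sketch below traces those shapes, assuming (from the comments, since `ResidualBlock` is not reproduced here) that `ResidualBlock(in_ch, out_ch)` maps `(in_ch, H, W)` to `(out_ch, 2H, 2W)`. It also checks that each block's input channels match the previous block's output, starting from the 256 channels of the `fc` view and ending at the 16 channels expected by `nn.BatchNorm2d(16)`. The function name is hypothetical.

```python
# (in_channels, out_channels) of each ResidualBlock, copied from Model.blocks
BLOCK_CHANNELS = [(256, 256), (256, 128), (128, 64), (64, 32), (32, 16)]


def trace_block_shapes(shape=(256, 8, 24)):
    """Return the (C, H, W) shape before the first block and after each block.

    Hypothetical helper; assumes each block doubles H and W, as the
    shape comments in resnet.py indicate.
    """
    shapes = [shape]
    for in_ch, out_ch in BLOCK_CHANNELS:
        c, h, w = shapes[-1]
        if c != in_ch:
            raise ValueError(f"block expects {in_ch} channels, got {c}")
        shapes.append((out_ch, 2 * h, 2 * w))
    return shapes


for s in trace_block_shapes():
    print(s)
# (256, 8, 24)
# (256, 16, 48)
# (128, 32, 96)
# (64, 64, 192)
# (32, 128, 384)
# (16, 256, 768)
```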

◆ conv

conv = nn.Conv2d(16, 1, 3, 1, 1)

Convolutional layer.

Definition at line 113 of file resnet.py.
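
Positionally, `nn.Conv2d(16, 1, 3, 1, 1)` means `in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1`. Using the output-size formula from the PyTorch `Conv2d` documentation, a 3 × 3 kernel with stride 1 and padding 1 leaves the spatial size unchanged, so this layer only reduces the 16 channels to one. The helper `conv2d_out_size` below is a hypothetical stdlib sketch of that formula.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, following the PyTorch Conv2d formula."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# conv = nn.Conv2d(16, 1, 3, 1, 1) applied to the (16, 256, 768) block output
print(conv2d_out_size(256, 3, stride=1, padding=1))  # 256
print(conv2d_out_size(768, 3, stride=1, padding=1))  # 768
```

Without the padding, each spatial axis would shrink by two pixels, for example 256 to 254.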

◆ fc

fc = nn.Linear(96, 49152)

Fully-connected layer.

Definition at line 94 of file resnet.py.

◆ norm

norm = nn.BatchNorm2d(16)

Batch normalization layer.

Definition at line 112 of file resnet.py.
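
In training mode, `nn.BatchNorm2d(16)` normalizes each of the 16 channels with the mean and biased variance computed over the batch and both spatial dimensions. It then applies a learnable per-channel scale and shift (`weight`, `bias`), using `eps=1e-5` by default. In `forward`, the result is passed through an in-place ReLU. The stdlib sketch below applies the training-mode formula to the values of a single channel. It ignores the running statistics used in evaluation mode, and the helper name is hypothetical.

```python
import math


def batchnorm_channel(values, weight=1.0, bias=0.0, eps=1e-5):
    """Training-mode batch norm for all values of ONE channel.

    Hypothetical helper. `values` holds that channel's entries across the
    batch and both spatial dimensions (N * H * W numbers).
    """
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [weight * (v - mean) * inv_std + bias for v in values]


out = batchnorm_channel([1.0, 2.0, 3.0, 4.0])
print([round(v, 3) for v in out])  # [-1.342, -0.447, 0.447, 1.342]
```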


The documentation for this class was generated from the following file: