Belle II Software development
Model Class Reference

Class for the ResNet generator model.


Public Member Functions

def __init__(self)
 Constructor to create a new model instance.
 
def forward(self, z)
 Function to perform a forward pass.
 

Public Attributes

 fc
 Fully-connected layer.
 
 blocks
 Sequence of residual block layers.
 
 norm
 Batch normalization layer.
 
 conv
 Convolutional layer.
 

Detailed Description

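Class for the ResNet generator model.

The generator maps a 96-dimensional input vector to a single-channel 256 × 768 output. A fully-connected layer produces a (256, 8, 24) feature map, five residual blocks upsample it to (16, 256, 768), and a final batch normalization, ReLU, and convolution reduce it to one channel, which a tanh activation bounds to [-1, 1].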

Definition at line 86 of file resnet.py.
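The shape comments in the constructor imply that every residual block doubles the height and width of its input. The pure-Python sketch below traces the per-sample feature-map shape through the generator under that assumption. The names `trace_shapes`, `START_SHAPE`, and `BLOCK_CHANNELS` are illustrative only and do not appear in resnet.py.

```python
# Shape bookkeeping for the generator, ASSUMING (as the constructor's shape
# comments indicate) that every residual block doubles height and width.
LATENT_DIM = 96
START_SHAPE = (256, 8, 24)  # (channels, height, width) after z.view(...)
# (in_channels, out_channels) of each ResidualBlock, copied from the constructor
BLOCK_CHANNELS = [(256, 256), (256, 128), (128, 64), (64, 32), (32, 16)]


def trace_shapes():
    """Return (stage, shape) pairs for one sample passing through the model."""
    c, h, w = START_SHAPE
    # the fully-connected layer must emit exactly c*h*w features for the view
    assert c * h * w == 49152
    shapes = [("fc + view", (c, h, w))]
    for i, (c_in, c_out) in enumerate(BLOCK_CHANNELS):
        assert c_in == c, "block input channels must match previous output"
        c, h, w = c_out, 2 * h, 2 * w
        shapes.append((f"block {i}", (c, h, w)))
    # final 3x3 conv (stride 1, padding 1) keeps H and W, maps 16 -> 1 channel
    shapes.append(("conv", (1, h, w)))
    return shapes


for stage, shape in trace_shapes():
    print(f"{stage:10s} {shape}")
```

The traced shapes reproduce the constructor comments, ending at (16, 256, 768) after the last block and (1, 256, 768) after the final convolution.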

Constructor & Destructor Documentation

◆ __init__()

def __init__(self)

Constructor to create a new model instance.

Definition at line 91 of file resnet.py.

def __init__(self):
    super().__init__()
    # fully-connected inputs: 96 features -> 256 * 8 * 24 = 49152 features
    self.fc = nn.Linear(96, 49152)
    # stack of residual blocks; the shape comments give (channels, height, width)
    # of the feature map between blocks, and each block doubles height and width
    self.blocks = nn.ModuleList(
        [
            # (256, 8, 24)
            ResidualBlock(256, 256),
            # (256, 16, 48)
            ResidualBlock(256, 128),
            # (128, 32, 96)
            ResidualBlock(128, 64),
            # (64, 64, 192)
            ResidualBlock(64, 32),
            # (32, 128, 384)
            ResidualBlock(32, 16),
            # (16, 256, 768)
        ]
    )
    # norm and conv outputs
    self.norm = nn.BatchNorm2d(16)
    self.conv = nn.Conv2d(16, 1, 3, 1, 1)
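The ResidualBlock class used in the constructor is not documented on this page. As a hypothetical stand-in, the sketch below shows one common way to build an upsampling residual block in PyTorch that matches the shape comments: it doubles height and width while mapping `in_ch` to `out_ch` channels. The class name `UpsamplingResidualBlock` and its layer layout (pre-activation batch norm, nearest-neighbour upsampling, 1×1 skip convolution) are assumptions for illustration, not the Belle II implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UpsamplingResidualBlock(nn.Module):
    """HYPOTHETICAL residual block: doubles H and W, maps in_ch -> out_ch.

    Illustrative stand-in only; NOT the ResidualBlock defined in resnet.py.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.norm2 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        # 1x1 convolution so the skip path matches the new channel count
        self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # main path: BN -> ReLU -> upsample x2 -> conv -> BN -> ReLU -> conv
        h = F.relu(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode="nearest")
        h = self.conv1(h)
        h = self.conv2(F.relu(self.norm2(h)))
        # skip path: upsample x2, then 1x1 conv to out_ch channels
        s = self.skip(F.interpolate(x, scale_factor=2, mode="nearest"))
        return h + s


if __name__ == "__main__":
    block = UpsamplingResidualBlock(256, 256)
    x = torch.randn(2, 256, 8, 24)
    print(block(x).shape)  # torch.Size([2, 256, 16, 48])
```

Chaining five such blocks with the channel pairs from the constructor would take a (N, 256, 8, 24) tensor to (N, 16, 256, 768), consistent with the documented shapes.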

Member Function Documentation

◆ forward()

def forward(self, z)

Function to perform a forward pass.

Compute the model output for a given input.

Definition at line 117 of file resnet.py.

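def forward(self, z):
    """Compute the model output for a given input."""
    # (N, 96) -> (N, 49152)
    z = self.fc(z)
    # reshape to (N, 256, 8, 24); -1 lets view infer the batch size N
    z = z.view(-1, 256, 8, 24)
    for block in self.blocks:
        z = block(z)
    # (N, 16, 256, 768): batch normalization, then in-place ReLU
    z = self.norm(z)
    z.relu_()
    # (N, 1, 256, 768), bounded in place to [-1, 1] by tanh
    z = self.conv(z)
    return z.tanh_()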
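The final convolution is constructed as nn.Conv2d(16, 1, 3, 1, 1), that is in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1. A small sketch of the output-size formula from the nn.Conv2d documentation shows why this layer preserves the 256 × 768 spatial size, so the output of forward has shape (N, 1, 256, 768). The helper name `conv2d_out` is hypothetical.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
               dilation: int = 1) -> int:
    """Output length along one spatial axis, per the nn.Conv2d shape formula:
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1).
    Integer floor division is exact here because the numerator is non-negative.
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# self.conv = nn.Conv2d(16, 1, 3, 1, 1): kernel 3, stride 1, padding 1
h_out = conv2d_out(256, kernel=3, stride=1, padding=1)
w_out = conv2d_out(768, kernel=3, stride=1, padding=1)
print((1, h_out, w_out))  # (1, 256, 768)
```

With kernel size 3, a padding of 1 and stride 1 is the usual "same" choice: one pixel of padding on each side compensates for the two pixels a 3×3 kernel would otherwise trim.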

Member Data Documentation

◆ blocks

blocks

Sequence of residual block layers.

Definition at line 96 of file resnet.py.

◆ conv

conv

Convolutional layer.

Definition at line 113 of file resnet.py.

◆ fc

fc

Fully-connected layer.

Definition at line 94 of file resnet.py.

◆ norm

norm

Batch normalization layer.

Definition at line 112 of file resnet.py.
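For reference, nn.BatchNorm2d(16) normalizes each of the 16 channels separately, using the mean and (biased) variance over the batch and spatial axes in training mode, or running estimates in evaluation mode, and then applies a learnable per-channel scale \gamma_c and shift \beta_c:

```latex
y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad
\mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w},
\qquad
\sigma_c^2 = \frac{1}{NHW} \sum_{n,h,w} \left( x_{n,c,h,w} - \mu_c \right)^2,
\qquad c = 1, \dots, 16
```

PyTorch's default is \epsilon = 10^{-5}.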


The documentation for this class was generated from the following file: