Belle II Software development
resnet.py
1
"""
This module implements the ResNet generator model.
"""
11
12import torch
13import torch.nn as nn
14import torch.nn.functional as F
15
16
17
class ResidualBlock(nn.Module):
    """Residual block layer.

    The residual branch applies batch normalization, ReLU, an optional
    2x nearest-neighbour upsampling and a 3x3 convolution, followed by a
    second batch normalization, ReLU and 3x3 convolution. The shortcut
    branch applies the same optional upsampling and, when present, a 1x1
    convolution. The layer output is the sum of both branches.
    """

    def __init__(self, ninput: int, noutput: int, upsample: bool = True):
        """Create a new residual block layer.

        Args:
            ninput: Number of channels of the layer input.
            noutput: Number of channels of the layer output.
            upsample: Whether to double the height and width of input.
        """
        super().__init__()
        # whether to double the height and width of input
        self.upsample = upsample
        # shortcut branch
        # convolutional layer in the shortcut branch; it stays None when the
        # block neither upsamples nor changes the number of channels, in
        # which case the shortcut is the unmodified input
        self.conv = None
        if upsample or (ninput != noutput):
            self.conv = nn.Conv2d(
                ninput, noutput, kernel_size=1, stride=1, padding=0
            )
        # residual branch
        # first batch normalization layer in the residual branch
        self.norm1 = nn.BatchNorm2d(ninput)
        # first convolutional layer in the residual branch
        self.conv1 = nn.Conv2d(
            ninput, noutput, kernel_size=3, stride=1, padding=1
        )
        # second batch normalization layer in the residual branch
        self.norm2 = nn.BatchNorm2d(noutput)
        # second convolutional layer in the residual branch
        self.conv2 = nn.Conv2d(
            noutput, noutput, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the layer output for a given input.

        Args:
            x: Input tensor with ``ninput`` channels, laid out as expected
                by the 2D batch normalization and convolution layers.

        Returns:
            Tensor with ``noutput`` channels; height and width are doubled
            when ``upsample`` is set and unchanged otherwise.
        """
        # residual branch
        # the in-place ReLU acts on the fresh batch-norm output, not on x,
        # so the caller's input tensor is left untouched
        h = self.norm1(x)
        h.relu_()
        if self.upsample:
            h = F.interpolate(h, mode="nearest", scale_factor=2)
        h = self.conv1(h)
        h = self.norm2(h)
        h.relu_()
        h = self.conv2(h)
        # shortcut branch
        if self.upsample:
            x = F.interpolate(x, mode="nearest", scale_factor=2)
        # explicit None check: the shortcut convolution is optional
        if self.conv is not None:
            x = self.conv(x)
        # return sum of both
        return h + x
58
59
62
63
66
67
70
71
74
75
78
79
82
83
84
class Model(nn.Module):
    """ResNet generator model.

    A fully-connected layer expands a 96-dimensional input vector to
    256 x 8 x 24 features, which a stack of five upsampling residual blocks
    turns into a 16 x 256 x 768 feature map. A final batch normalization,
    ReLU, 3x3 convolution and tanh produce a single-channel output.
    """

    def __init__(self):
        """Create a new model instance."""
        super().__init__()
        # fully-connected inputs
        # 49152 = 256 * 8 * 24, the shape the features are viewed as
        # at the start of the forward pass
        self.fc = nn.Linear(96, 49152)
        # stack of residual blocks
        # sequence of residual block layers; the comments give the
        # (channels, height, width) before and after each block
        self.blocks = nn.ModuleList(
            [
                # (256, 8, 24)
                ResidualBlock(256, 256),
                # (256, 16, 48)
                ResidualBlock(256, 128),
                # (128, 32, 96)
                ResidualBlock(128, 64),
                # (64, 64, 192)
                ResidualBlock(64, 32),
                # (32, 128, 384)
                ResidualBlock(32, 16),
                # (16, 256, 768)
            ]
        )
        # norm and conv outputs
        # batch normalization layer
        self.norm = nn.BatchNorm2d(16)
        # convolutional layer reducing the 16 channels to a single one
        self.conv = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Compute the model output for a given input.

        Args:
            z: Input tensor whose last dimension has size 96, as required
                by the fully-connected layer.

        Returns:
            Tensor of shape (batch, 1, 256, 768) with values in [-1, 1]
            because of the final tanh.
        """
        z = self.fc(z)
        # reinterpret the flat features as a (256, 8, 24) feature map
        z = z.view(-1, 256, 8, 24)
        for block in self.blocks:
            z = block(z)
        z = self.norm(z)
        z.relu_()
        z = self.conv(z)
        return z.tanh_()
127
128
131
132
135
136
139
140
143
144
145
def generate(model: nn.Module) -> torch.Tensor:
    """Produce one pseudo-random image for each PXD module
    using the ResNet generator model.

    A batch of 40 standard-normal input vectors of size 96 is passed
    through ``model`` without computing gradients. The first channel of the
    output is cropped by three rows at the top and bottom, moved to the
    CPU and mapped from [-1, 1] to integer values in [0, 255].

    Args:
        model: Generator mapping a (40, 96) tensor to a 4D tensor whose
            third dimension has more than six rows.

    Returns:
        CPU tensor of dtype ``torch.uint8`` with shape
        (40, height - 6, width); for the ``Model`` in this file that is
        (40, 250, 768).

    NOTE(review): the train/eval mode of ``model`` is not changed here, so
    any batch normalization layers behave according to the mode the caller
    left the model in - presumably callers invoke ``model.eval()`` first;
    confirm.
    """
    # infer the device that is in use; a model without any parameters has
    # no device of its own, so fall back to the CPU instead of leaking a
    # bare StopIteration from next()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # without computing gradients
    with torch.no_grad():
        # initialize the model input(s)
        z = torch.randn(40, 96, device=device)
        # evaluate the model output and crop
        x = model(z)[:, 0, 3:-3, :]
        # delete the reference(s) to model input(s)
        del z
        # always transfer to the CPU
        x = x.cpu()
        # apply the inverse logarithmic transformation:
        # [-1, 1] -> [0, 1] -> 256**x - 1 in [0, 255]
        x = x.mul_(0.5).add_(0.5).clamp_(0.0, 1.0)
        x = torch.pow(256.0, x).sub_(1.0).clamp_(0.0, 255.0)
        # convert to unsigned 8-bit integer data type (fraction is dropped)
        return x.to(torch.uint8)
Model
Class for the ResNet generator model.
Definition: resnet.py:86
norm
Batch normalization layer.
Definition: resnet.py:112
blocks
Sequence of residual block layers.
Definition: resnet.py:96
def forward(self, z)
Function to perform a forward pass.
Definition: resnet.py:117
def __init__(self)
Constructor to create a new model instance.
Definition: resnet.py:91
ResidualBlock
Class for the residual block layer.
Definition: resnet.py:19
conv
Convolutional layer in the shortcut branch.
Definition: resnet.py:28
conv1
First convolutional layer in the residual branch.
Definition: resnet.py:33
def forward(self, x)
Function to perform a forward pass.
Definition: resnet.py:39
conv2
Second convolutional layer in the residual branch.
Definition: resnet.py:35
def __init__(self, ninput, noutput, upsample=True)
Constructor to create a new residual block layer.
Definition: resnet.py:24
norm1
First batch normalization layer in the residual branch.
Definition: resnet.py:32
upsample
Whether to double the height and width of input.
Definition: resnet.py:26
norm2
Second batch normalization layer in the residual branch.
Definition: resnet.py:34