Belle II Software  release-08-01-10
resnet.py
1 
8 """
9 This module implements the ResNet generator model.
10 """
11 
12 import torch
13 import torch.nn as nn
14 import torch.nn.functional as F
15 
16 
17 
class ResidualBlock(nn.Module):
    """Residual block layer.

    Residual branch: BatchNorm -> ReLU -> (optional 2x nearest upsampling)
    -> 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv.
    Shortcut branch: the same optional upsampling, followed by a 1x1 conv
    whenever the shortcut cannot be the identity.
    The block returns the sum of both branches.
    """

    def __init__(self, ninput, noutput, upsample=True):
        """Create a new residual block layer.

        Args:
            ninput: number of input channels.
            noutput: number of output channels.
            upsample: whether to double the height and width of the input.
        """
        super().__init__()
        # whether to double the height and width of the input
        self.upsample = upsample
        # 1x1 convolution in the shortcut branch; required only when the
        # shortcut changes the spatial size or the channel count
        self.conv = None
        if upsample or (ninput != noutput):
            self.conv = nn.Conv2d(ninput, noutput, 1, 1, 0)
        # residual branch: first batch norm and 3x3 convolution
        self.norm1 = nn.BatchNorm2d(ninput)
        self.conv1 = nn.Conv2d(ninput, noutput, 3, 1, 1)
        # residual branch: second batch norm and 3x3 convolution
        self.norm2 = nn.BatchNorm2d(noutput)
        self.conv2 = nn.Conv2d(noutput, noutput, 3, 1, 1)

    def forward(self, x):
        """Compute the layer output for a given input.

        Args:
            x: input of shape (batch, ninput, height, width), the 4-D layout
                required by BatchNorm2d.

        Returns:
            Tensor with ``noutput`` channels; height and width are doubled
            when ``upsample`` is set.
        """
        # residual branch
        h = self.norm1(x)
        # in-place ReLU is safe: h is a fresh tensor produced by norm1
        h.relu_()
        if self.upsample:
            h = F.interpolate(h, mode="nearest", scale_factor=2)
        h = self.conv1(h)
        h = self.norm2(h)
        h.relu_()
        h = self.conv2(h)
        # shortcut branch, brought to the same shape as the residual branch
        if self.upsample:
            x = F.interpolate(x, mode="nearest", scale_factor=2)
        if self.conv is not None:
            x = self.conv(x)
        # return sum of both branches
        return h + x
58 
59 
62 
63 
66 
67 
70 
71 
74 
75 
78 
79 
82 
83 
84 
class Model(nn.Module):
    """ResNet generator model.

    A fully-connected layer maps 96 input features to a (256, 8, 24)
    feature map. Five upsampling residual blocks grow it to
    (16, 256, 768). A final batch norm, ReLU, 3x3 convolution and tanh then
    produce a single-channel image with values in [-1, 1].
    """

    def __init__(self):
        """Create a new model instance."""
        super().__init__()
        # fully-connected input layer: 96 -> 256 * 8 * 24 (= 49152) features
        self.fc = nn.Linear(96, 256 * 8 * 24)
        # sequence of residual block layers; each uses the default
        # upsample=True and so doubles the height and width
        self.blocks = nn.ModuleList(
            [
                # (256, 8, 24)
                ResidualBlock(256, 256),
                # (256, 16, 48)
                ResidualBlock(256, 128),
                # (128, 32, 96)
                ResidualBlock(128, 64),
                # (64, 64, 192)
                ResidualBlock(64, 32),
                # (32, 128, 384)
                ResidualBlock(32, 16),
                # (16, 256, 768)
            ]
        )
        # batch normalization layer before the output convolution
        self.norm = nn.BatchNorm2d(16)
        # output convolution: 16 channels -> 1 channel, same spatial size
        self.conv = nn.Conv2d(16, 1, 3, 1, 1)

    def forward(self, z):
        """Compute the model output for a given input.

        Args:
            z: latent vectors with 96 features each; assumed shape
                (batch, 96).

        Returns:
            Tensor of shape (batch, 1, 256, 768) with values in [-1, 1].
        """
        z = self.fc(z)
        # reshape the flat features into a (256, 8, 24) feature map
        z = z.view(-1, 256, 8, 24)
        for block in self.blocks:
            z = block(z)
        z = self.norm(z)
        z.relu_()
        z = self.conv(z)
        return z.tanh_()
127 
128 
131 
132 
135 
136 
139 
140 
143 
144 
145 
def generate(model):
    """Produce one pseudo-random image for each PXD module
    using the ResNet generator model.

    Args:
        model: generator module. Its first parameter decides the device on
            which the latent vectors are drawn.

    Returns:
        CPU tensor of dtype uint8 holding 40 images. Each image is channel 0
        of the model output with 3 rows removed at the top and bottom.
    """
    # draw the latent vectors on the same device as the model weights
    target_device = next(model.parameters()).device
    with torch.no_grad():
        latent = torch.randn(40, 96, device=target_device)
        # keep channel 0, then drop 3 rows at each end of the height axis
        images = model(latent).select(1, 0)[:, 3:-3]
        # release the latent batch before moving data off the device
        del latent
        images = images.cpu()
        # map the [-1, 1] output range onto [0, 1]
        images.mul_(0.5)
        images.add_(0.5)
        images.clamp_(0.0, 1.0)
        # invert the logarithmic scaling: [0, 1] -> [0, 255]
        counts = torch.pow(256.0, images)
        counts.sub_(1.0)
        counts.clamp_(0.0, 255.0)
        # casting to uint8 truncates any fractional part
        return counts.to(torch.uint8)
Class for the ResNet generator model.
Definition: resnet.py:86
norm
Batch normalization layer applied before the output convolution.
Definition: resnet.py:112
blocks
Sequence of residual block layers.
Definition: resnet.py:96
def forward(self, z)
Function to perform a forward pass.
Definition: resnet.py:117
def __init__(self)
Constructor to create a new model instance.
Definition: resnet.py:91
Class for the residual block layer.
Definition: resnet.py:19
conv
Convolutional layer in the shortcut branch.
Definition: resnet.py:28
conv1
First convolutional layer in the residual branch.
Definition: resnet.py:33
def forward(self, x)
Function to perform a forward pass.
Definition: resnet.py:39
conv2
Second convolutional layer in the residual branch.
Definition: resnet.py:35
def __init__(self, ninput, noutput, upsample=True)
Constructor to create a new residual block layer.
Definition: resnet.py:24
norm1
First batch normalization layer in the residual branch.
Definition: resnet.py:32
upsample
Whether to double the height and width of the input.
Definition: resnet.py:26
norm2
Second batch normalization layer in the residual branch.
Definition: resnet.py:34