Belle II Software  release-08-01-10
convnet.py
1 
8 """
9 This module implements the ConvNet generator model.
10 """
11 
12 import torch
13 import torch.nn as nn
14 
15 
16 
class Model(nn.Module):
    """ConvNet generator model.

    Projects a batch of 96-dimensional latent vectors onto a
    (512, 8, 24) feature map. Five convolution/upsampling stages then grow
    it into a single-channel (1, 256, 768) image squashed into [-1, 1].
    """

    def __init__(self):
        """Constructor to create a new model instance."""
        super().__init__()
        # fully-connected projection of the latent vector;
        # 98304 == 512 * 8 * 24, reshaped to (512, 8, 24) in forward()
        self.fc = nn.Linear(96, 98304)
        # Sequential composite layer: each stage is a 5x5 "same" convolution
        # (stride 1, padding 2), batch normalisation, ReLU and a 2x
        # nearest-neighbour upsampling; a final convolution yields 1 channel
        self.features = nn.Sequential(
            # format: (channels, height, width)
            # (512, 8, 24)
            nn.Conv2d(512, 512, 5, 1, 2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            # (512, 16, 48)
            nn.Conv2d(512, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            # (256, 32, 96)
            nn.Conv2d(256, 128, 5, 1, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            # (128, 64, 192)
            nn.Conv2d(128, 64, 5, 1, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            # (64, 128, 384)
            nn.Conv2d(64, 32, 5, 1, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            # (32, 256, 768)
            nn.Conv2d(32, 1, 5, 1, 2),
            # (1, 256, 768)
        )

    def forward(self, z):
        """Compute the model output for a given input.

        :param z: latent input whose last dimension is 96
            (presumably shaped (batch, 96) — confirm against callers).
        :return: tensor of shape (batch, 1, 256, 768) with values in
            [-1, 1] (tanh applied in place to the final convolution output).
        """
        # reshape the flat projection into a (512, 8, 24) feature map
        return self.features(self.fc(z).view(-1, 512, 8, 24)).tanh_()
64 
67 
68 
71 
72 
73 
def generate(model):
    """Produce one pseudo-random image for each PXD module
    using the ConvNet generator model.

    Draws 40 standard-normal latent vectors of length 96 on the device of
    the model's first parameter and runs the model without gradient
    tracking. From each output it keeps channel 0 with three rows removed
    at the top and bottom, moves the result to the CPU and maps it back
    to the [0, 255] range.

    :param model: generator module; its output is indexed as
        (batch, channel, height, width).
    :return: ``torch.uint8`` tensor of shape (40, height - 6, width),
        i.e. (40, 250, 768) for the ``Model`` defined in this module.
    """
    # inference runs wherever the model's parameters live
    target_device = next(model.parameters()).device

    with torch.no_grad():
        latent = torch.randn(40, 96, device=target_device)
        # first channel only, cropping 3 rows from each vertical edge
        images = model(latent)[:, 0, 3:-3, :]
        # drop the latent batch as early as possible
        del latent

        # post-processing always happens on the CPU
        images = images.cpu()

        # tanh range [-1, 1] -> [0, 1]
        unit = images.mul_(0.5).add_(0.5).clamp_(0.0, 1.0)
        # invert the logarithmic encoding: 256**t - 1 spans [0, 255]
        charge = torch.pow(256.0, unit).sub_(1.0).clamp_(0.0, 255.0)

        # truncating cast to unsigned 8-bit
        return charge.to(torch.uint8)
Class for the ConvNet generator model.
Definition: convnet.py:18
features
Sequential composite layer.
Definition: convnet.py:26
def forward(self, z)
Function to perform a forward pass.
Definition: convnet.py:60
def __init__(self)
Constructor to create a new model instance.
Definition: convnet.py:23