Belle II Software development
convnet.py
"""
This module implements the ConvNet generator model.
"""
11
12import torch
13import torch.nn as nn
14
15
16
class Model(nn.Module):
    """ConvNet generator model.

    Maps latent vectors with 96 components to single-channel images of
    size 256 x 768 whose values lie in [-1, 1] (the output is passed
    through ``tanh``).
    """

    def __init__(self):
        """Constructor to create a new model instance.

        The network consists of a fully connected projection of the latent
        vector onto a (512, 8, 24) feature map, five convolution stages that
        each double the spatial resolution, and a final convolution that
        reduces the feature map to one channel.
        """
        super().__init__()
        # latent vector -> 98304 = 512 * 8 * 24 values, later viewed as a map
        self.fc = nn.Linear(96, 98304)

        # (in_channels, out_channels) of every conv/BN/ReLU/upsample stage;
        # the spatial size evolves (8, 24) -> (16, 48) -> (32, 96)
        # -> (64, 192) -> (128, 384) -> (256, 768)
        stage_channels = [
            (512, 512),
            (512, 256),
            (256, 128),
            (128, 64),
            (64, 32),
        ]
        stack = []
        for c_in, c_out in stage_channels:
            stack += [
                nn.Conv2d(c_in, c_out, 5, 1, 2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode="nearest"),
            ]
        # collapse the 32 feature channels into a single output channel
        stack.append(nn.Conv2d(32, 1, 5, 1, 2))

        # Sequential composite layer.  Modules are appended in the same order
        # as a flat listing would have them, so the submodule indices (and
        # therefore the state_dict keys features.0 ... features.20) are stable.
        self.features = nn.Sequential(*stack)

    def forward(self, z):
        """Function to perform a forward pass.

        ``z`` must have 96 as its last dimension (the input size of ``fc``).
        The projection is viewed as (-1, 512, 8, 24) and run through
        ``features``, and ``tanh`` is applied in place on the result.
        For this architecture the output is (N, 1, 256, 768).
        """
        seed_map = self.fc(z).view(-1, 512, 8, 24)
        image = self.features(seed_map)
        return image.tanh_()
63
64
67
68
71
72
73
def generate(model):
    """Produce one pseudo-random image for each PXD module
    using the ConvNet generator model.

    A batch of 40 latent vectors with 96 standard-normal components is drawn
    on the device that holds the model's first parameter and evaluated
    without gradient tracking.  Channel 0 of the output is kept, three rows
    are cropped from the top and bottom, and the result is moved to the CPU.
    The values are then mapped from [-1, 1] back to [0, 1] and passed through
    the inverse logarithmic transformation ``256**x - 1``, with clamping to
    [0, 255].

    Returns a ``torch.uint8`` tensor of shape (40, H - 6, W), where (H, W) is
    the spatial size of the model output, i.e. (40, 250, 768) for ``Model``.
    The float-to-uint8 cast truncates toward zero; it does not round.
    """
    # sample on whatever device the model's weights live on
    target_device = next(iter(model.parameters())).device
    with torch.no_grad():
        latent = torch.randn(40, 96, device=target_device)
        raw = model(latent)
        # keep channel 0 and drop 3 rows at each vertical edge
        cropped = raw[:, 0, 3:-3, :]
        # the latent batch is no longer needed
        del latent
        # post-processing always happens on the CPU
        cropped = cropped.cpu()
        # [-1, 1] -> [0, 1]
        unit = cropped.mul_(0.5).add_(0.5).clamp_(0.0, 1.0)
        # inverse of the logarithmic compression: [0, 1] -> [0, 255]
        pixels = torch.pow(256.0, unit).sub_(1.0).clamp_(0.0, 255.0)
        return pixels.to(torch.uint8)
Model
Class for the ConvNet generator model.
Definition: convnet.py:18
features
Sequential composite layer.
Definition: convnet.py:26
def forward(self, z)
Function to perform a forward pass.
Definition: convnet.py:60
def __init__(self)
Constructor to create a new model instance.
Definition: convnet.py:23