9 This module implements the ResNet generator model.
14 import torch.nn.functional
as F
20 """Residual block layer."""
24 def __init__(self, ninput, noutput, upsample=True):
29 if upsample
or (ninput != noutput):
30 self.
convconv = nn.Conv2d(ninput, noutput, 1, 1, 0)
32 self.
norm1norm1 = nn.BatchNorm2d(ninput)
33 self.
conv1conv1 = nn.Conv2d(ninput, noutput, 3, 1, 1)
34 self.
norm2norm2 = nn.BatchNorm2d(noutput)
35 self.
conv2conv2 = nn.Conv2d(noutput, noutput, 3, 1, 1)
40 """Compute the layer output for a given input."""
43 h = self.
norm1norm1(h)
46 h = F.interpolate(h, mode=
"nearest", scale_factor=2)
47 h = self.
conv1conv1(h)
48 h = self.
norm2norm2(h)
50 h = self.
conv2conv2(h)
53 x = F.interpolate(x, mode=
"nearest", scale_factor=2)
87 """ResNet generator model."""
94 self.
fcfc = nn.Linear(96, 49152)
112 self.
normnorm = nn.BatchNorm2d(16)
113 self.
convconv = nn.Conv2d(16, 1, 3, 1, 1)
118 """Compute the model output for a given input."""
120 z = z.view(-1, 256, 8, 24)
121 for block
in self.
blocksblocks:
149 """Produce one pseudo-random image for each PXD module
150 using the ResNet generator model.
153 device = next(model.parameters()).device
155 with torch.no_grad():
157 z = torch.randn(40, 96, device=device)
159 x = model(z)[:, 0, 3:-3, :]
165 x = x.mul_(0.5).add_(0.5).clamp_(0.0, 1.0)
166 x = torch.pow(256.0, x).sub_(1.0).clamp_(0.0, 255.0)
168 return x.to(torch.uint8)
Class for the ResNet generator model.
norm
Batch normalization layer over the 16-channel feature map.
blocks
Sequence of residual block layers.
def forward(self, z)
Function to perform a forward pass.
def __init__(self)
Constructor to create a new model instance.
Class for the residual block layer.
conv
1×1 convolutional layer in the shortcut branch; only created when upsampling or when the input and output channel counts differ.
conv1
First 3×3 convolutional layer in the residual branch, mapping ninput to noutput channels.
def forward(self, x)
Function to perform a forward pass.
conv2
Second 3×3 convolutional layer in the residual branch, keeping noutput channels.
def __init__(self, ninput, noutput, upsample=True)
Constructor to create a new residual block layer.
norm1
First batch normalization layer in the residual branch.
upsample
Whether to double the height and width of the input (nearest-neighbour interpolation).
norm2
Second batch normalization layer in the residual branch.