16import torch.nn.functional
as F
def findConv2dOutShape(H_in, W_in, conv, pool=2):
    """Find the height and width of an image after a convolution step.

    Applies the Conv2d output-size formula using the layer's kernel size,
    stride, padding and dilation. The result is then divided by the
    max-pooling factor ``pool``, if one is given.

    @param H_in: Height of the input image (in this case 7)
    @param W_in: Width of the input image (in this case 7)
    @param conv: convolutional layer; only its ``kernel_size``, ``stride``,
        ``padding`` and ``dilation`` attributes are read. Each may be a
        2-tuple or a single int; ``padding`` may also be ``'same'`` or
        ``'valid'``.
    @param pool: max-pooling factor applied after the convolution; a falsy
        value (0 or None) means no pooling

    Returns:
        - H_out: Height of the image after convolution (int)
        - W_out: Width of the image after convolution (int)
    """
    # NOTE(review): the default pool=2 is reconstructed because the original
    # signature line was not visible. Confirm against the original source.

    def _pair(value):
        # Conv2d stores these as 2-tuples; also accept a bare int.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    kernel_size = _pair(conv.kernel_size)
    stride = _pair(conv.stride)
    dilation = _pair(conv.dilation)
    padding = conv.padding

    if padding == 'same':
        # PyTorch pads so that the output has the same size as the input
        # (this mode only supports stride 1).
        H_out, W_out = H_in, W_in
    else:
        if padding == 'valid':
            # 'valid' means no padding at all.
            padding = (0, 0)
        padding = _pair(padding)
        # Integer floor division is exact, unlike np.floor on a float quotient.
        H_out = (H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
        W_out = (W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    if pool:
        H_out //= pool
        W_out //= pool
    # Callers use the result as a layer size (e.g. nn.Linear in_features), so return ints.
    return int(H_out), int(W_out)
56class ConvNet(nn.Module):
57 """ ConvNet generator model """
63 """ Constructor to create a new model instance.
65 fc represents fully-connected layer
66 conv represents convolutional layer
69 C_in, H_in, W_in = params['input_shape']
70 num_emb_theta = params[
'num_emb_theta']
71 dim_emb_theta = params[
'dim_emb_theta']
72 num_emb_phi = params[
'num_emb_phi']
73 dim_emb_phi = params[
'dim_emb_phi']
74 num_ext_input = params[
'num_ext_input']
75 init_f = params[
'initial_filters']
76 num_fc1 = params[
'num_fc1']
77 num_classes = params[
'num_classes']
78 self.dropout_rate = params[
'dropout_rate']
79 C_in_array = np.array(
82 count_C_in = np.count_nonzero(C_in_array)
84 self.emb_theta = nn.Embedding(num_emb_theta, dim_emb_theta)
85 self.emb_phi = nn.Embedding(num_emb_phi, dim_emb_phi)
87 self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, padding=1, stride=1)
88 h, w = findConv2dOutShape(H_in, W_in, self.conv1, pool=1)
90 self.num_flatten = h * w * init_f
93 self.num_flatten * count_C_in + num_ext_input + dim_emb_theta + dim_emb_phi,
96 self.fc2 = nn.Linear(num_fc1, num_classes)
105 """ Function to perform a forward pass.
107 It computes the model output for a given input.
109 x1 = F.relu(self.conv1(energy))
110 x1 = x1.view(-1, self.num_flatten)
112 pt = torch.reshape(pt, (1, 1))
116 self.emb_theta(theta_input),
117 self.emb_phi(phi_input)),
121 x = F.relu(self.fc1(x))
122 x = F.dropout(x, self.dropout_rate, training=self.training)