import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def findConv2dOutShape(H_in, W_in, conv, pool=2):
    """ Find proper height and width of an image in each convolution step.

    @param H_in: Height of the image (in this case 7)
    @param W_in: Width of the image (in this case 7)
    @param conv: convolutional layer
    @param pool: maxpooling factor applied after the convolution

    @return:
        - H_out: Height of the image after convolution
        - W_out: Width of the image after convolution
    """
    kernel_size = conv.kernel_size
    stride = conv.stride
    padding = conv.padding
    dilation = conv.dilation

    # Standard Conv2d output-size formula.
    H_out = np.floor(
        (H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    W_out = np.floor(
        (W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

    # Reduce by the pooling factor, if a pooling layer follows.
    if pool:
        H_out /= pool
        W_out /= pool

    return int(H_out), int(W_out)
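
# Worked example of the formula above (illustrative numbers, not taken from
# the model configuration): a 3x3 convolution with padding=1, stride=1 and
# dilation=1 on a 7x7 input gives floor((7 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 7,
# i.e. the spatial size is preserved, and pool=1 applies no further reduction:
#
#     conv = nn.Conv2d(1, 8, kernel_size=3, padding=1, stride=1)
#     findConv2dOutShape(7, 7, conv, pool=1)   # -> (7, 7)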
class ConvNet(nn.Module):
    """ ConvNet generator model """

    def __init__(self, params):
        """ Constructor to create a new model instance.

        fc represents a fully-connected layer,
        conv represents a convolutional layer.
        """
        super(ConvNet, self).__init__()
        C_in, H_in, W_in = params['input_shape']
        num_emb_theta = params['num_emb_theta']
        dim_emb_theta = params['dim_emb_theta']
        num_emb_phi = params['num_emb_phi']
        dim_emb_phi = params['dim_emb_phi']
        num_ext_input = params['num_ext_input']
        init_f = params['initial_filters']
        num_fc1 = params['num_fc1']
        num_classes = params['num_classes']
        self.dropout_rate = params['dropout_rate']

        # Count the non-zero input channels; the array contents are elided in
        # the original listing.
        C_in_array = np.array(...)
        count_C_in = np.count_nonzero(C_in_array)
        # Embedding layers for the theta and phi inputs.
        self.emb_theta = nn.Embedding(num_emb_theta, dim_emb_theta)
        self.emb_phi = nn.Embedding(num_emb_phi, dim_emb_phi)

        # 3x3 convolution with padding=1 and stride=1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, padding=1, stride=1)
        h, w = findConv2dOutShape(H_in, W_in, self.conv1, pool=1)

        self.num_flatten = h * w * init_f

        # Fully-connected head on the concatenation of the flattened conv
        # features, the external input(s) and the two embedding vectors.
        self.fc1 = nn.Linear(
            self.num_flatten * count_C_in + num_ext_input + dim_emb_theta + dim_emb_phi,
            num_fc1)
        self.fc2 = nn.Linear(num_fc1, num_classes)
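
        # Example of the fc1 input size (illustrative numbers only, not from
        # the original listing): with a 7x7 input, initial_filters=8 and
        # count_C_in=1 the flattened conv output has 7 * 7 * 8 = 392 features;
        # with num_ext_input=1 and 4-dimensional theta/phi embeddings, fc1
        # therefore expects 392 + 1 + 4 + 4 = 401 input features.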
106 """ Function to perform a forward pass.
108 It computes the model output for a given input.
110 x1 = F.relu(self.conv1(energy))
111 x1 = x1.view(-1, self.num_flatten)
113 pt = torch.reshape(pt, (1, 1))
117 self.emb_theta(theta_input),
118 self.emb_phi(phi_input)),
122 x = F.relu(self.fc1(x))
123 x = F.dropout(x, self.dropout_rate, training=self.training)
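

# Minimal usage sketch (all values below are illustrative assumptions, not the
# configuration used in the original code; it also assumes the elided
# C_in_array has exactly one non-zero entry, i.e. count_C_in == 1).
if __name__ == "__main__":
    params = {
        'input_shape': (1, 7, 7),      # (C_in, H_in, W_in)
        'num_emb_theta': 10, 'dim_emb_theta': 4,
        'num_emb_phi': 10, 'dim_emb_phi': 4,
        'num_ext_input': 1,
        'initial_filters': 8,
        'num_fc1': 64,
        'num_classes': 2,
        'dropout_rate': 0.3,
    }
    model = ConvNet(params)

    energy = torch.randn(1, 1, 7, 7)       # one single-channel 7x7 image
    pt = torch.tensor(0.5)                 # external scalar input
    theta_input = torch.tensor([3])        # embedding index for theta
    phi_input = torch.tensor([5])          # embedding index for phi

    out = model(energy, pt, theta_input, phi_input)
    print(out.shape)                       # torch.Size([1, 2])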