cnn_pid_conv_net.py
#!/usr/bin/env python3


# @cond

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def findConv2dOutShape(
    H_in,
    W_in,
    conv,
    pool
):
    """ Compute the height and width of an image after one convolution step.

    Inputs:

    @param H_in: Height of the image (in this case 7)
    @param W_in: Width of the image (in this case 7)
    @param conv: Convolutional layer
    @param pool: Max-pooling window size; the output shape is divided by this factor (a falsy value skips pooling)

    Outputs:
    - H_out: Height of the image after convolution
    - W_out: Width of the image after convolution
    """
    kernel_size = conv.kernel_size
    stride = conv.stride
    padding = conv.padding
    dilation = conv.dilation

    # Standard Conv2d output-shape formula (see the torch.nn.Conv2d docs)
    H_out = np.floor(
        (H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    W_out = np.floor(
        (W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

    if pool:
        # Max-pooling shrinks each spatial dimension by the pooling factor
        H_out /= pool
        W_out /= pool

    return int(H_out), int(W_out)

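# A worked example of the shape formula above (an illustration, not part of
# the original module; the channel counts are chosen for demonstration): for
# the 7 x 7 pixel images used by this model, a convolution with
# kernel_size=3, padding=1 and stride=1 leaves the spatial size unchanged,
# since floor((7 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 7 in each dimension:
#
#     conv = nn.Conv2d(1, 64, kernel_size=3, padding=1, stride=1)
#     findConv2dOutShape(7, 7, conv, pool=1)  # -> (7, 7)
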
class ConvNet(nn.Module):
    """ ConvNet generator model """

    def __init__(
        self,
        params
    ):
        """ Constructor to create a new model instance.

        @param params: Dictionary of model hyperparameters

        In the attribute names below, fc denotes a fully-connected layer
        and conv a convolutional layer.
        """
        super().__init__()
        C_in, H_in, W_in = params['input_shape']
        num_emb_theta = params['num_emb_theta']
        dim_emb_theta = params['dim_emb_theta']
        num_emb_phi = params['num_emb_phi']
        dim_emb_phi = params['dim_emb_phi']
        num_ext_input = params['num_ext_input']
        init_f = params['initial_filters']
        num_fc1 = params['num_fc1']
        num_classes = params['num_classes']
        self.dropout_rate = params['dropout_rate']
        # Flags for the input channels in use; currently only a single
        # channel (the energy image) is enabled
        C_in_array = np.array(
            [True]
        )
        count_C_in = np.count_nonzero(C_in_array)

        # Embeddings for the discrete theta and phi inputs
        self.emb_theta = nn.Embedding(num_emb_theta, dim_emb_theta)
        self.emb_phi = nn.Embedding(num_emb_phi, dim_emb_phi)

        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, padding=1, stride=1)
        h, w = findConv2dOutShape(H_in, W_in, self.conv1, pool=1)

        # Number of features after flattening the convolution output
        self.num_flatten = h * w * init_f

        # fc1 takes the flattened convolution features concatenated with the
        # external input(s) and both embedding vectors
        self.fc1 = nn.Linear(
            self.num_flatten * count_C_in + num_ext_input + dim_emb_theta + dim_emb_phi,
            num_fc1)

        self.fc2 = nn.Linear(num_fc1, num_classes)

    def forward(
        self,
        energy,
        theta_input,
        phi_input,
        pt
    ):
        """ Perform a forward pass: compute the model output for a given input.

        @param energy: Pixel image passed through the convolutional layer
        @param theta_input: Integer index fed to the theta embedding
        @param phi_input: Integer index fed to the phi embedding
        @param pt: Transverse momentum (external scalar input)
        """
        x1 = F.relu(self.conv1(energy))
        x1 = x1.view(-1, self.num_flatten)

        # Reshape pt to a (1, 1) tensor; note that this assumes a single
        # sample per forward pass
        pt = torch.reshape(pt, (1, 1))
        x = torch.cat(
            (x1,
             pt,
             self.emb_theta(theta_input),
             self.emb_phi(phi_input)),
            dim=1
        )

        x = F.relu(self.fc1(x))
        x = F.dropout(x, self.dropout_rate, training=self.training)
        output = self.fc2(x)

        return output

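# A minimal usage sketch (an illustration, not part of the original module).
# All hyperparameter values below are assumptions chosen for demonstration,
# not the tuned Belle II settings.
if __name__ == "__main__":
    example_params = {
        'input_shape': (1, 7, 7),   # (C_in, H_in, W_in): one 7 x 7 pixel image
        'num_emb_theta': 44,        # assumed theta vocabulary size
        'dim_emb_theta': 8,
        'num_emb_phi': 144,         # assumed phi vocabulary size
        'dim_emb_phi': 8,
        'num_ext_input': 1,         # one external scalar input (pt)
        'initial_filters': 64,
        'num_fc1': 128,
        'num_classes': 2,
        'dropout_rate': 0.1,
    }
    model = ConvNet(example_params)

    # One sample: a 7 x 7 energy image, integer theta/phi indices, and pt
    energy = torch.randn(1, 1, 7, 7)
    theta_idx = torch.tensor([0])
    phi_idx = torch.tensor([0])
    pt = torch.tensor(0.5)

    output = model(energy, theta_idx, phi_idx, pt)
    print(output.shape)  # -> torch.Size([1, 2])
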
# @endcond