Belle II Software  release-08-01-10
cnn_pid_conv_net.py
1 # !/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 # @cond
13 
14 import torch
15 import numpy as np
16 import torch.nn as nn
17 import torch.nn.functional as F
18 
19 
def findConv2dOutShape(
    H_in,
    W_in,
    conv,
    pool
):
    """ Find proper height and width of an image in each convolution step.

    The spatial output size of ``conv`` is computed with the Conv2d
    output-shape formula, using the layer's own ``kernel_size``, ``stride``,
    ``padding`` and ``dilation``. When ``pool`` is truthy, the result is
    additionally divided by ``pool`` and floored. That matches a max-pooling
    layer whose kernel size and stride both equal ``pool``.

    Inputs:

    @param H_in: Height of the input image (in this case 7)
    @param W_in: Width of the input image (in this case 7)
    @param conv: convolutional layer (a ``torch.nn.Conv2d``-like object;
        ``padding`` may be a pair of ints or the strings 'same'/'valid')
    @param pool: integer down-sampling factor of a following max-pooling
        step; a falsy value (0 / None) means no pooling is applied

    Outputs:
    - H_out: Height of the image after convolution (and pooling)
    - W_out: Width of the image after convolution (and pooling)

    Raises:
    - ValueError: if the resulting height or width is not positive
    """
    kernel_size = conv.kernel_size
    stride = conv.stride
    dilation = conv.dilation
    padding = conv.padding

    if padding == 'same':
        # PyTorch only accepts padding='same' together with stride 1 and then
        # pads so that the output keeps the input size.
        H_out, W_out = H_in, W_in
    else:
        if padding == 'valid':
            # 'valid' is PyTorch's spelling of "no padding".
            padding = (0, 0)
        # Integer floor division gives floor(x / s) + 1 exactly, without a
        # float round-trip.
        H_out = (H_in + 2 * padding[0]
                 - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
        W_out = (W_in + 2 * padding[1]
                 - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

    if pool:
        H_out //= pool
        W_out //= pool

    # A non-positive size means the layer does not fit the image; fail here
    # instead of building a network with an empty or negative flatten size.
    if H_out <= 0 or W_out <= 0:
        raise ValueError(
            f"convolution/pooling yields a non-positive output shape "
            f"({H_out}, {W_out}) for an input of shape ({H_in}, {W_in})")

    return int(H_out), int(W_out)
55 
56 
class ConvNet(nn.Module):
    """ ConvNet classifier model.

    A single ReLU-activated 3x3 convolution is applied to the image input.
    Its flattened output is concatenated with:
    - the external input ``pt``,
    - a learned embedding of ``theta_input``,
    - a learned embedding of ``phi_input``.

    The result goes through two fully-connected layers (ReLU and dropout in
    between), which return one unnormalized score per class.
    """

    def __init__(
        self,
        params
    ):
        """ Constructor to create a new model instance.

        fc represents fully-connected layer
        conv represents convolutional layer

        @param params: dict providing the keys
            - 'input_shape': (C_in, H_in, W_in) of the image input
            - 'num_emb_theta', 'dim_emb_theta': number of entries and vector
              size of the theta embedding
            - 'num_emb_phi', 'dim_emb_phi': number of entries and vector
              size of the phi embedding
            - 'num_ext_input': number of external (non-image) features
            - 'initial_filters': number of output channels of conv1
            - 'num_fc1': width of the hidden fully-connected layer
            - 'num_classes': number of output scores
            - 'dropout_rate': dropout probability applied after fc1
        """
        super().__init__()
        C_in, H_in, W_in = params['input_shape']
        num_emb_theta = params['num_emb_theta']
        dim_emb_theta = params['dim_emb_theta']
        num_emb_phi = params['num_emb_phi']
        dim_emb_phi = params['dim_emb_phi']
        num_ext_input = params['num_ext_input']
        init_f = params['initial_filters']
        num_fc1 = params['num_fc1']
        num_classes = params['num_classes']
        self.dropout_rate = params['dropout_rate']

        self.emb_theta = nn.Embedding(num_emb_theta, dim_emb_theta)
        self.emb_phi = nn.Embedding(num_emb_phi, dim_emb_phi)

        self.conv1 = nn.Conv2d(C_in, init_f, kernel_size=3, padding=1, stride=1)
        h, w = findConv2dOutShape(H_in, W_in, self.conv1, pool=1)

        # Number of features obtained by flattening one sample's conv1 output.
        self.num_flatten = h * w * init_f

        # fc1 consumes the flattened image features, the external inputs and
        # both embedding vectors, concatenated in forward().
        self.fc1 = nn.Linear(
            self.num_flatten + num_ext_input + dim_emb_theta + dim_emb_phi,
            num_fc1)

        self.fc2 = nn.Linear(num_fc1, num_classes)

    def forward(
        self,
        energy,
        theta_input,
        phi_input,
        pt
    ):
        """ Function to perform a forward pass.

        It computes the model output for a given input.

        @param energy: image tensor fed to conv1; presumably shaped
            (batch, C_in, H_in, W_in) -- TODO confirm with callers
        @param theta_input: integer indices for the theta embedding;
            presumably shaped (batch,) -- TODO confirm
        @param phi_input: integer indices for the phi embedding;
            presumably shaped (batch,) -- TODO confirm
        @param pt: external input(s); reshaped to one row per sample, so a
            single sample may pass a scalar
        @return: tensor of shape (batch, num_classes) with unnormalized scores
        """
        x1 = F.relu(self.conv1(energy))
        x1 = x1.view(-1, self.num_flatten)

        # One row per sample, so batches larger than one work too. A single
        # value still becomes shape (1, 1).
        pt = torch.reshape(pt, (x1.size(0), -1))
        x = torch.cat(
            (x1,
             pt,
             self.emb_theta(theta_input),
             self.emb_phi(phi_input)),
            dim=1
        )

        x = F.relu(self.fc1(x))
        x = F.dropout(x, self.dropout_rate, training=self.training)
        output = self.fc2(x)

        return output
127 
128 # @endcond