Belle II Software  release-08-01-10
cnn_pid_ecl_module.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 """ This script includes the core of 'CNN_PID_ECL' module
13 
14 """
15 
16 # @cond
17 
18 import math
19 import numpy as np
20 import torch
21 import torch.nn.functional as F
22 from torch.utils.data import DataLoader
23 
24 import basf2 as b2
25 
26 from .cnn_pid_conv_net import ConvNet
27 from .cnn_pid_cluster_image import ClusterImage
28 
29 
class CNN_PID_ECL(b2.Module):
    """ Attach a CNN muon probability to tracks extrapolated into the ECL barrel.

    A convolutional neural network was trained on pixel images of the ECL
    energy deposits of pions and muons. Three additional inputs (thetaId and
    phiId of the central crystal, and the track pt) are fed to the network
    after the convolutional layers. This module runs the pretrained network
    in inference mode on every track of the ``Tracks`` store array. It stores
    the resulting muon probability in an ``ECLCNNPid`` object that is related
    to the track.

    A track is only evaluated when:
      1. the thetaId of the crystal hit by the extrapolated track lies
         strictly inside ``thetaId_range`` (default (13, 58), the ECL
         barrel), and
      2. the transverse momentum (pt) of the extrapolated track lies in the
         closed interval ``pt_range`` (default [0.2, 1.0] GeV/c).

    The goal is better muon and pion separation in the ECL barrel for
    low pt tracks.
    """

    # Class attributes; the methods below access Belle2 as ``self.Belle2``.
    from ROOT import Belle2
    # NOTE(review): ``pdg`` is not referenced in this class; kept so that the
    # class attribute ``CNN_PID_ECL.pdg`` remains available.
    import pdg

    def __init__(
        self,
        path,
        image_length=7,
        threshold=0.001,
        thetaId_range=(13, 58),
        pt_range=(0.2, 1.0)
    ):
        """ Store the module configuration.

        Args:
            path: basf2 path; stored as ``self.path`` (not used elsewhere in
                this class).
            image_length: side length, in crystals, of the square energy
                image fed to the CNN.
            threshold: forwarded unchanged to ``ClusterImage``.
            thetaId_range: exclusive (low, high) bounds on the thetaId of the
                central crystal.
            pt_range: inclusive (low, high) bounds on the pt of the
                extrapolated track.
        """
        super().__init__()

        self.image_length = image_length
        self.threshold = threshold
        self.thetaId_range = thetaId_range
        self.pt_range = pt_range
        self.path = path

        # Inference always runs on the CPU.
        self.device = torch.device('cpu')

    def initialize(self):
        """ Set up DataStore access and register the output array and relation. """

        self.eclCalDigits = self.Belle2.PyStoreArray('ECLCalDigits')
        self.mapping = self.Belle2.PyStoreObj('ECLCellIdMapping')
        self.eclCnnMuon = self.Belle2.PyStoreArray(self.Belle2.ECLCNNPid.Class())
        self.eclCnnMuon.registerInDataStore()
        self.tracks = self.Belle2.PyStoreArray('Tracks')
        self.tracks.registerRelationTo(self.eclCnnMuon)

    def beginRun(self):
        """ Read the charge-specific CNN payloads (weights and biases).

        ``self.charge`` selects the payload name in ``model_cnn_name``.
        """

        self.charge = 'plus'
        self.model_plus = self.read_model()

        self.charge = 'minus'
        self.model_minus = self.read_model()

    def event(self):
        """ Compute and store the CNN muon probability for each track.

        For every track in ``Tracks`` that has a pion-hypothesis fit result
        and non-zero charge:
          * take the extrapolated ECL crystal and pt from ``getExtCell``;
          * skip the track unless it is inside ``thetaId_range`` and
            ``pt_range``;
          * build the energy image around that crystal, together with its
            thetaId and phiId and the track pt;
          * evaluate the charge-specific CNN;
          * store the muon probability in a new ``ECLCNNPid`` object related
            to the track.
        """

        variable_muon = 'cnn_pid_ecl_muon'

        for track in self.tracks:

            fit_result = track.getTrackFitResultWithClosestMass(self.Belle2.Const.pion)
            if not fit_result:
                continue

            charge = fit_result.getChargeSign()
            if charge == 0:
                continue

            extHit_dict = self.getExtCell(track)
            if not extHit_dict:
                # No usable ECL extrapolation for this track.
                continue

            maxCellId = extHit_dict['cellid']

            # Validate the cell id before using it in any mapping lookup.
            if maxCellId < 0:
                b2.B2WARNING('maxCellId is less 0.')
                continue

            # Separate networks were trained for each charge.
            model = self.model_minus if charge == -1 else self.model_plus

            theta_id = self.mapping.getCellIdToThetaId(maxCellId)
            phi_id = self.mapping.getCellIdToPhiId(maxCellId)
            pt = extHit_dict['pt']

            # The CNN can only predict the PID of tracks inside the barrel
            # (thetaId strictly inside thetaId_range, default (13, 58)) and
            # within the pt window it was trained for.
            in_barrel = self.thetaId_range[0] < theta_id < self.thetaId_range[1]
            in_pt_window = self.pt_range[0] <= pt <= self.pt_range[1]
            if not (in_barrel and in_pt_window):
                b2.B2DEBUG(
                    22,
                    'Track is either outside ECL Barrel or Pt is outside '
                    f'[{self.pt_range[0]}, {self.pt_range[1]}] GeV/c. No CNN value.')
                continue

            # prepare_images() reads these attributes; the network inputs are
            # 1-element arrays.
            self.extThetaId = np.array([theta_id])
            self.extPhiId = np.array([phi_id])
            self.pt = np.array([pt])
            self.energy_array = self._energy_image(maxCellId)

            prob_CNN_muon = self.extract_cnn_value(model)['cnn_muon']

            eclCnnMuon = self.eclCnnMuon.appendNew()
            eclCnnMuon.setEclCnnMuon(prob_CNN_muon)
            track.addRelationTo(eclCnnMuon)

            b2.B2DEBUG(22, f'{variable_muon}: {prob_CNN_muon}')

    def _energy_image(self, maxCellId):
        """ Return the flattened crystal energies around ``maxCellId``.

        One entry is filled per neighbour returned by
        ``getCellIdToNeighbour7``, up to ``image_length ** 2`` entries.
        A crystal without an ECLCalDigit contributes 0.0.

        NOTE(review): the neighbour lookup is the fixed "7" variant, so an
        ``image_length`` other than 7 is presumably inconsistent with it.
        """

        neighbours = self.mapping.getCellIdToNeighbour7(maxCellId)
        n_pixels = min(self.image_length ** 2, neighbours.size())

        energy_list = []
        for posid in range(n_pixels):
            storearraypos = self.mapping.getCellIdToStoreArray(neighbours[posid])
            energy = 0.0
            if storearraypos >= 0:
                energy = self.eclCalDigits[storearraypos].getEnergy()
            energy_list.append(energy)

        return np.array(energy_list)

    def getExtCell(self, track):
        """ Extract cellId and pt of the track's extrapolation into the ECL.

        Scans the ExtHits related to ``track`` and selects the first one that:
          * has the pion PDG code (either sign),
          * lies in the ECL,
          * has status EXT_EXIT,
          * is forward-propagated,
          * has a valid copy id.

        Returns:
            ``{'cellid': copy id + 1, 'pt': transverse momentum of the
            ExtHit}``, or ``None`` if no ExtHit qualifies.
        """

        myDetID = self.Belle2.Const.EDetector.ECL
        pdgCode = abs(self.Belle2.Const.pion.getPDGCode())

        for extHit in track.getRelationsTo('ExtHits'):
            if abs(extHit.getPdgCode()) != pdgCode:
                continue
            if extHit.getDetectorID() != myDetID:
                continue
            if extHit.getStatus() != self.Belle2.ExtHitStatus.EXT_EXIT:
                continue
            if extHit.isBackwardPropagated():
                continue
            copyid = extHit.getCopyID()
            if copyid == -1:
                continue

            momentum = extHit.getMomentum()
            return {
                'cellid': copyid + 1,
                'pt': math.hypot(momentum.X(), momentum.Y()),
            }

        # Explicit "not found" so the caller can test the result for truthiness.
        return None

    def extract_cnn_value(self, model):
        """ Evaluate ``model`` on the current track's image.

        Returns:
            dict with float entries:
              cnn_pion: softmax probability of the track being a pion
              cnn_muon: softmax probability of the track being a muon

        Raises:
            RuntimeError: if the image loader yields no input.
        """

        test_loader = self.prepare_images()
        model.eval()

        with torch.no_grad():
            # One image is prepared per call (one track); use the first batch.
            for energy, theta, phi, pt in test_loader:
                output = model(
                    energy.to(self.device),
                    theta.to(self.device),
                    phi.to(self.device),
                    pt.to(self.device),
                )
                output = F.softmax(output, dim=1)

                return {
                    'cnn_pion': output[0][0].item(),
                    'cnn_muon': output[0][1].item(),
                }

        raise RuntimeError('CNN_PID_ECL: image loader is empty, no CNN value computed.')

    def prepare_images(self):
        """ Wrap the current track's inputs in a DataLoader for the CNN.

        The image parameters are passed as a dict to ``ClusterImage``, which
        produces the CNN input format.
        """

        params_image = {
            'image_length': self.image_length,
            'energy_image': self.energy_array,
            'thetaId': self.extThetaId,
            'phiId': self.extPhiId,
            'pt': self.pt,
            'threshold': self.threshold
        }

        dataset = ClusterImage(params_image)
        infer_loader = DataLoader(dataset, shuffle=False)

        return infer_loader

    def model_cnn_name(self):
        """ Return the payload name and ConvNet parameters for ``self.charge``.

        Returns:
            (model_name, params_model): the payload base name
            ``ECLCNNPID_charge_<charge>`` and the dict used to construct
            ``ConvNet``.

        The parameters must be exactly those of the trained model, otherwise
        the stored weights will not match the network layout.
        """

        model_name = f'ECLCNNPID_charge_{self.charge}'

        params_model = {
            'input_shape': (1, self.image_length, self.image_length),
            'initial_filters': 64,
            'num_emb_theta': 44,
            'dim_emb_theta': 22,
            'num_emb_phi': 144,
            'dim_emb_phi': 18,
            'num_ext_input': 1,
            'num_fc1': 128,
            'dropout_rate': 0.1,
            'num_classes': 2,
            'energy': True
        }

        return (model_name, params_model)

    def read_model(self):
        """ Build ConvNet and load its weights from the ``<model_name>.pt`` payload.

        The payload file is obtained through ``DBAccessorBase`` as a raw file.
        The returned model is on ``self.device`` and in eval mode.
        """

        model_name, params_model = self.model_cnn_name()
        model = ConvNet(params_model).to(self.device)

        payload = f'{model_name}.pt'
        accessor = self.Belle2.DBAccessorBase(
            self.Belle2.DBStoreEntry.c_RawFile, payload, True)
        checkpoint = accessor.getFilename()

        # map_location remaps the stored tensors onto self.device (CPU), so a
        # checkpoint saved from a GPU session can still be loaded here.
        # NOTE(review): torch.load unpickles the file; this assumes payloads
        # come only from the trusted conditions database.
        model.load_state_dict(torch.load(checkpoint, map_location=self.device))
        model.eval()

        return model
310 
311 # @endcond