Belle II Software development
cnn_pid_ecl_module.py
1#!/usr/bin/env python3
2
3
10
11""" This script includes the core of 'CNN_PID_ECL' module
12
13"""
14
15# @cond
16
17import math
18import numpy as np
19import torch
20import torch.nn.functional as F
21from torch.utils.data import DataLoader
22
23import basf2 as b2
24
25from .cnn_pid_conv_net import ConvNet
26from .cnn_pid_cluster_image import ClusterImage
27
28
class CNN_PID_ECL(b2.Module):
    """ Extracts CNN value for an image

    A model based on a convolutional neural network was trained on pixel
    images of pions and muons, together with three additional inputs
    (thetaId, phiId and pt) which are added after the convolutional layer.
    This module runs the pretrained model in inference mode on the tracks
    of each event and stores the muon probability in an ECLCNNPid object
    related to the track.

    This module works under the following conditions:
        1. The extrapolated tracks are inside the ECL barrel, based
           on the ThetaId of the central pixel (strictly inside
           ``thetaId_range``).
        2. The transverse momentum (pt) of the extrapolated track is
           inside ``pt_range`` (bounds included, [0.2, 1.0] by default).

    The goal is better muon and pion separation in the ECL barrel for
    low pt tracks.
    """

    # Imported inside the class body, so both names are class attributes.
    # The methods below reach the Belle2 namespace through self.Belle2;
    # pdg is not used by the code in this class.
    from ROOT import Belle2
    import pdg

    def __init__(
        self,
        path,
        image_length=7,
        threshold=0.001,
        thetaId_range=(13, 58),
        pt_range=(0.2, 1.0)
    ):
        """ Store the configuration of the module

        Parameters:
            path: stored as self.path; not used elsewhere in this class.
            image_length: side length of the square energy image in pixels.
            threshold: forwarded to ClusterImage under the key 'threshold'.
            thetaId_range: (low, high) thetaId limits, both excluded.
            pt_range: (low, high) pt limits, both included.
        """
        super().__init__()

        self.image_length = image_length
        self.threshold = threshold
        self.thetaId_range = thetaId_range
        self.pt_range = pt_range
        self.path = path

        # Inference always runs on the CPU.
        self.device = torch.device('cpu')

    def initialize(self):
        """ Initialize necessary arrays and/or objects """

        self.eclCalDigits = self.Belle2.PyStoreArray('ECLCalDigits')
        self.mapping = self.Belle2.PyStoreObj('ECLCellIdMapping')

        # Output array of this module, related to the tracks.
        self.eclCnnMuon = self.Belle2.PyStoreArray(self.Belle2.ECLCNNPid.Class())
        self.eclCnnMuon.registerInDataStore()

        self.tracks = self.Belle2.PyStoreArray('Tracks')
        self.tracks.registerRelationTo(self.eclCnnMuon)

    def beginRun(self):
        """ Read CNN payloads which include weights and biases """

        # model_cnn_name() builds the payload name from self.charge, so the
        # attribute has to be set before each call to read_model().
        self.charge = 'plus'
        self.model_plus = self.read_model()

        self.charge = 'minus'
        self.model_minus = self.read_model()

    def event(self):
        """ Event processing

        This function loops over the Tracks store array and looks up the
        ECL cell in which each extrapolated track ends up. It then collects
        the ECLCalDigit energies of the pixel image around that cell,
        together with the ThetaId and PhiId of the central pixel and the pt
        of the track, and calculates the CNN value for the track being
        muon-like. The value is stored in a new ECLCNNPid object which is
        related to the track.

        Tracks without a fit result, with zero charge, without a matching
        ExtHit or outside the thetaId/pt acceptance are skipped.
        """

        variable_muon = 'cnn_pid_ecl_muon'

        for track in self.tracks:

            fit_result = track.getTrackFitResultWithClosestMass(self.Belle2.Const.pion)
            if not fit_result:
                continue

            charge = fit_result.getChargeSign()
            if charge == 0:
                continue

            extHit_dict = self.getExtCell(track)
            if not (track and extHit_dict):
                continue

            maxCellId = extHit_dict['cellid']
            # Validate the cell id before it is handed to the mapping.
            if maxCellId < 0:
                b2.B2WARNING('maxCellId is less 0.')
                continue

            # charge is non-zero here, so its sign selects the model.
            model = self.model_minus if charge < 0 else self.model_plus

            pt = extHit_dict['pt']
            thetaId = self.mapping.getCellIdToThetaId(maxCellId)
            phiId = self.mapping.getCellIdToPhiId(maxCellId)

            # prepare_images() reads these attributes; each one is wrapped
            # in a one-element array.
            self.pt = np.array([pt])
            self.extThetaId = np.array([thetaId])
            self.extPhiId = np.array([phiId])

            # Since CNN can only predict PID of the tracks inside the barrel,
            # there are two hard-coded numbers in the __init__ (13 and 58),
            # representing the thetaID limit.
            in_barrel = self.thetaId_range[0] < thetaId < self.thetaId_range[1]
            in_pt_range = self.pt_range[0] <= pt <= self.pt_range[1]
            if not (in_barrel and in_pt_range):
                b2.B2DEBUG(22, 'Track is either outside ECL Barrel or Pt is outside [0.2, 1.0] GeV/c. No CNN value.')
                continue

            neighbours = self.mapping.getCellIdToNeighbour7(maxCellId)
            # Never read past the end of the neighbour list.
            n_pixels = min(self.image_length ** 2, neighbours.size())

            energy_list = []
            for posid in range(n_pixels):
                storearraypos = self.mapping.getCellIdToStoreArray(neighbours[posid])
                # A negative store array position leaves the pixel at zero energy.
                energy = 0.0
                if storearraypos >= 0:
                    energy = self.eclCalDigits[storearraypos].getEnergy()
                energy_list.append(energy)

            self.energy_array = np.array(energy_list)

            prob_CNN_muon = self.extract_cnn_value(model)['cnn_muon']

            eclCnnMuon = self.eclCnnMuon.appendNew()
            eclCnnMuon.setEclCnnMuon(prob_CNN_muon)
            track.addRelationTo(eclCnnMuon)

            b2.B2DEBUG(22, f'{variable_muon}: {prob_CNN_muon}')

    def getExtCell(self, track):
        """ Extract cellId and pt of an extrapolated track

        The first ExtHit related to the track is used which was created
        with the pion hypothesis, belongs to the ECL, has the status
        EXT_EXIT, is not backward propagated and has a copy id other
        than -1.

        The output is a dictionary with the keys 'cellid' (copy id + 1) and
        'pt' (transverse momentum of the ExtHit), or None if no ExtHit
        fulfils the requirements.
        """

        myDetID = self.Belle2.Const.EDetector.ECL
        pdgCode = abs(self.Belle2.Const.pion.getPDGCode())
        exit_status = self.Belle2.ExtHitStatus.EXT_EXIT

        for extHit in track.getRelationsTo('ExtHits'):
            if abs(extHit.getPdgCode()) != pdgCode:
                continue
            if extHit.getDetectorID() != myDetID:
                continue
            if extHit.getStatus() != exit_status:
                continue
            if extHit.isBackwardPropagated():
                continue
            copyid = extHit.getCopyID()
            if copyid == -1:
                continue

            momentum = extHit.getMomentum()
            px = momentum.X()
            py = momentum.Y()

            return {
                'cellid': copyid + 1,
                'pt': math.sqrt(px**2 + py**2),
            }

        # No suitable ExtHit: the caller skips the track.
        return None

    def extract_cnn_value(self, model):
        """ Extract CNN values for an extrapolated track

        The image is built by prepare_images() from the attributes which
        were set in event(), and is passed through ``model``.

        The output of this function is a dictionary
        which includes two probabilities:

            cnn_pion: Probability of an extrapolated track being pion
            cnn_muon: Probability of an extrapolated track being muon

        NOTE: cnn_pion and cnn_muon are floats.
        """

        test_loader = self.prepare_images()
        model.eval()

        with torch.no_grad():
            # If the loader yields more than one batch, the values of the
            # last batch are the ones which are returned.
            for energy, theta, phi, pt in test_loader:

                energy = energy.to(self.device)
                theta = theta.to(self.device)
                phi = phi.to(self.device)
                pt = pt.to(self.device)

                # Softmax over the class dimension turns the network
                # output into probabilities.
                output = F.softmax(model(energy, theta, phi, pt), dim=1)

                prob_CNN_pion = output[0][0].item()
                prob_CNN_muon = output[0][1].item()

        return {
            'cnn_pion': prob_CNN_pion,
            'cnn_muon': prob_CNN_muon,
        }

    def prepare_images(self):
        """ Prepare images

        A dictionary is passed to a function (ClusterImage)
        in order to prepare proper format for CNN inputs.
        The resulting dataset is wrapped in an unshuffled DataLoader,
        which is returned.
        """

        params_image = {
            'image_length': self.image_length,
            'energy_image': self.energy_array,
            'thetaId': self.extThetaId,
            'phiId': self.extPhiId,
            'pt': self.pt,
            'threshold': self.threshold
        }

        dataset = ClusterImage(params_image)

        return DataLoader(dataset, shuffle=False)

    def model_cnn_name(self):
        """ Create CNN model name

        The outputs of this function are:
            1. CNN model's name, built from self.charge
            2. Models parameters which is important for
               initializing ConvNet.

        Models parameters should be exactly the same as the CNN model
        which was trained.
        """

        model_name = f'ECLCNNPID_charge_{self.charge}'

        params_model = {
            'input_shape': (1, self.image_length, self.image_length),
            'initial_filters': 64,
            'num_emb_theta': 44,
            'dim_emb_theta': 22,
            'num_emb_phi': 144,
            'dim_emb_phi': 18,
            'num_ext_input': 1,
            'num_fc1': 128,
            'dropout_rate': 0.1,
            'num_classes': 2,
            'energy': True
        }

        return model_name, params_model

    def read_model(self):
        """ Load the CNN model

        This function receives model's name and
        CNN parameters, then reads .pt file which
        includes weights and biases for CNN.

        The output is the ConvNet on self.device with the weights loaded.
        """

        model_name, params_model = self.model_cnn_name()
        model = ConvNet(params_model).to(self.device)

        # The weights are a raw-file payload; the accessor gives its file name.
        payload = f'{model_name}.pt'
        accessor = self.Belle2.DBAccessorBase(
            self.Belle2.DBStoreEntry.c_RawFile, payload, True)
        checkpoint = accessor.getFilename()

        # map_location makes sure that a checkpoint which was saved from
        # another device (e.g. a GPU) can be loaded on self.device.
        model.load_state_dict(torch.load(checkpoint, map_location=self.device))

        return model
308
309# @endcond
310