12 """ This script includes the core of 'CNN_PID_ECL' module
21 import torch.nn.functional
as F
22 from torch.utils.data
import DataLoader
26 from .cnn_pid_conv_net
import ConvNet
27 from .cnn_pid_cluster_image
import ClusterImage
30 class CNN_PID_ECL(b2.Module):
31 """ Extracts CNN value for an image
33 A model based on convolutional neural network was trained
34 on pions and muons' pixel images together with 3 more information
35 which were added after convolutional layer. This module uses
36 a pretrained model which is used as inference on pions and muons
39 This module works under the following conditions:
40 1. The extrapolated tracks are inside ECL Barrel based
41 on ThetaId of central pixel.
42 2. Transverse momentum (pt) of extrapolated tracks are
45 The goal is better muon and pion separation in the ECL barrel for
49 from ROOT
import Belle2
57 thetaId_range=(13, 58),
62 self.image_length = image_length
63 self.threshold = threshold
64 self.thetaId_range = thetaId_range
65 self.pt_range = pt_range
68 self.device = torch.device(
'cpu')
71 """ Initialize necessary arrays and/or objects """
73 self.eclCalDigits = self.Belle2.PyStoreArray(
'ECLCalDigits')
74 self.mapping = self.Belle2.PyStoreObj(
'ECLCellIdMapping')
75 self.eclCnnMuon = self.Belle2.PyStoreArray(self.Belle2.ECLCNNPid.Class())
76 self.eclCnnMuon.registerInDataStore()
77 self.tracks = self.Belle2.PyStoreArray(
'Tracks')
78 self.tracks.registerRelationTo(self.eclCnnMuon)
81 """ Read CNN payloads which include weights and biases """
84 self.model_plus = self.read_model()
87 self.model_minus = self.read_model()
92 This function goes through the particleList and find eclcaldigit of
93 an extrapolated track. Then it extracts the energy for a 7x7 pixel image
94 around the entering position of the track. In addition it extracts
95 ThetaId and PhiId of central pixel in a 7x7 image together with Pt of
96 the track. Then, it calculates the CNN values for an extrapolated being
100 variable_muon =
'cnn_pid_ecl_muon'
102 for track
in self.tracks:
104 fit_result = track.getTrackFitResultWithClosestMass(self.Belle2.Const.pion)
108 charge = fit_result.getChargeSign()
112 extHit_dict = self.getExtCell(track)
114 if (track
and extHit_dict):
115 maxCellId = extHit_dict[
'cellid']
116 self.pt = extHit_dict[
'pt']
117 self.pt = np.array([self.pt])
120 model = self.model_minus
122 model = self.model_plus
124 self.extThetaId = self.mapping.getCellIdToThetaId(maxCellId)
125 self.extThetaId = np.array([self.extThetaId])
126 self.extPhiId = self.mapping.getCellIdToPhiId(maxCellId)
127 self.extPhiId = np.array([self.extPhiId])
130 b2.B2WARNING(
'maxCellId is less 0.')
135 if (self.extThetaId > self.thetaId_range[0]
and
136 self.extThetaId < self.thetaId_range[1]
and
137 self.pt >= self.pt_range[0]
and
138 self.pt <= self.pt_range[1]):
141 neighbours = self.mapping.getCellIdToNeighbour7(maxCellId)
143 for posid
in range(self.image_length ** 2):
144 if posid < neighbours.size():
145 neighbourid = neighbours[posid]
147 storearraypos = self.mapping.getCellIdToStoreArray(neighbourid)
149 if storearraypos >= 0:
150 energy = self.eclCalDigits[storearraypos].getEnergy()
151 energy_list.append(energy)
153 self.energy_array = np.array(energy_list)
155 prob_CNN_dict = self.extract_cnn_value(model)
156 prob_CNN_muon = prob_CNN_dict[
'cnn_muon']
158 eclCnnMuon = self.eclCnnMuon.appendNew()
159 eclCnnMuon.setEclCnnMuon(prob_CNN_muon)
160 track.addRelationTo(eclCnnMuon)
162 b2.B2DEBUG(22, f
'{variable_muon}: {prob_CNN_muon}')
164 b2.B2DEBUG(22,
'Track is either outside ECL Barrel or Pt is outside [0.2, 1.0] GeV/c. No CNN value.')
166 def getExtCell(self, track):
167 """ Extract cellId and pt of an extrapolated track
169 The output is dictionary which has cellId and pt.
172 myDetID = self.Belle2.Const.EDetector.ECL
173 hypothesis = self.Belle2.Const.pion
174 pdgCode = abs(hypothesis.getPDGCode())
176 extHits = track.getRelationsTo(
'ExtHits')
178 for extHit
in extHits:
179 if abs(extHit.getPdgCode()) != pdgCode:
181 if extHit.getDetectorID() != myDetID:
183 if extHit.getStatus() != self.Belle2.ExtHitStatus.EXT_EXIT:
185 if extHit.isBackwardPropagated():
187 copyid = extHit.getCopyID()
192 px = extHit.getMomentum().X()
193 py = extHit.getMomentum().Y()
194 pt = math.sqrt(px**2 + py**2)
203 def extract_cnn_value(self, model):
204 """ Extract CNN values for an extrapolated track
206 The output of this function is dictionary
207 which includes two probabilities:
209 cnn_pion: Probability of an extrapolated track being pion
210 cnn_muon: Probability of an extrapolated track being muon
212 NOTE: cnn_pion and cnn_muon are floats.
215 test_loader = self.prepare_images()
218 with torch.no_grad():
219 for energy, theta, phi, pt
in test_loader:
221 energy = energy.to(self.device)
222 theta = theta.to(self.device)
223 phi = phi.to(self.device)
224 pt = pt.to(self.device)
226 output = model(energy, theta, phi, pt)
227 output = F.softmax(output, dim=1)
229 prob_CNN_pion = output[0][0].item()
230 prob_CNN_muon = output[0][1].item()
233 'cnn_pion': prob_CNN_pion,
234 'cnn_muon': prob_CNN_muon,
239 def prepare_images(self):
242 A dictionary is passed to a function (ClusterImage)
243 in order to prepare proper format for CNN inputs.
247 'image_length': self.image_length,
248 'energy_image': self.energy_array,
249 'thetaId': self.extThetaId,
250 'phiId': self.extPhiId,
252 'threshold': self.threshold
255 dataset = ClusterImage(params_image)
256 infer_loader = DataLoader(dataset, shuffle=
False)
260 def model_cnn_name(self):
261 """ Create CNN model name
263 The outputs of this function are:
265 2. Models parameters which is important for
266 initializing ConvNet.
268 Models parameters should be exactly the same as the CNN model
272 model_name = f
'ECLCNNPID_charge_{self.charge}'
275 'input_shape': (1, self.image_length, self.image_length),
276 'initial_filters': 64,
288 return(model_name, params_model)
290 def read_model(self):
291 """ Load the CNN model
293 This function receives model's name and
294 CNN parameters, then reads .pt file which
295 includes weights and biases for CNN.
298 model_name, params_model = self.model_cnn_name()
299 model = ConvNet(params_model)
300 model = model.to(self.device)
302 payload = f
'{model_name}.pt'
303 accessor = self.Belle2.DBAccessorBase(
304 self.Belle2.DBStoreEntry.c_RawFile, payload,
True)
305 checkpoint = accessor.getFilename()
307 model.load_state_dict(torch.load(checkpoint))