11""" This script includes the core of 'CNN_PID_ECL' module
20import torch.nn.functional
as F
25from .cnn_pid_conv_net
import ConvNet
26from .cnn_pid_cluster_image
import ClusterImage
29class CNN_PID_ECL(b2.Module):
30 """ Extracts CNN value for an image
32 A model based on convolutional neural network was trained
33 on pions and muons
' pixel images together with 3 more information which were added after convolutional layer. This module uses
34 a pretrained model which is used
as inference on pions
and muons
37 This module works under the following conditions:
38 1. The extrapolated tracks are inside ECL Barrel based
39 on ThetaId of central pixel.
40 2. Transverse momentum (pt) of extrapolated tracks are
43 The goal
is better muon
and pion separation
in the ECL barrel
for
47 from ROOT
import Belle2
55 thetaId_range=(13, 58),
60 self.image_length = image_length
61 self.threshold = threshold
62 self.thetaId_range = thetaId_range
63 self.pt_range = pt_range
66 self.device = torch.device(
'cpu')
69 """ Initialize necessary arrays and/or objects """
71 self.eclCalDigits = self.Belle2.PyStoreArray(
'ECLCalDigits')
72 self.mapping = self.Belle2.PyStoreObj(
'ECLCellIdMapping')
73 self.eclCnnMuon = self.Belle2.PyStoreArray(self.Belle2.ECLCNNPid.Class())
74 self.eclCnnMuon.registerInDataStore()
75 self.tracks = self.Belle2.PyStoreArray(
'Tracks')
76 self.tracks.registerRelationTo(self.eclCnnMuon)
79 """ Read CNN payloads which include weights and biases """
82 self.model_plus = self.read_model()
85 self.model_minus = self.read_model()
90 This function goes through the particleList and find eclcaldigit of
91 an extrapolated track. Then it extracts the energy
for a 7x7 pixel image
92 around the entering position of the track. In addition it extracts
93 ThetaId
and PhiId of central pixel
in a 7x7 image together
with Pt of
94 the track. Then, it calculates the CNN values
for an extrapolated being
98 variable_muon = 'cnn_pid_ecl_muon'
100 for track
in self.tracks:
102 fit_result = track.getTrackFitResultWithClosestMass(self.Belle2.Const.pion)
106 charge = fit_result.getChargeSign()
110 extHit_dict = self.getExtCell(track)
112 if (track
and extHit_dict):
113 maxCellId = extHit_dict[
'cellid']
114 self.pt = extHit_dict[
'pt']
115 self.pt = np.array([self.pt])
118 model = self.model_minus
120 model = self.model_plus
122 self.extThetaId = self.mapping.getCellIdToThetaId(maxCellId)
123 self.extThetaId = np.array([self.extThetaId])
124 self.extPhiId = self.mapping.getCellIdToPhiId(maxCellId)
125 self.extPhiId = np.array([self.extPhiId])
128 b2.B2WARNING(
'maxCellId is less 0.')
133 if (self.extThetaId > self.thetaId_range[0]
and
134 self.extThetaId < self.thetaId_range[1]
and
135 self.pt >= self.pt_range[0]
and
136 self.pt <= self.pt_range[1]):
139 neighbours = self.mapping.getCellIdToNeighbour7(maxCellId)
141 for posid
in range(self.image_length ** 2):
142 if posid < neighbours.size():
143 neighbourid = neighbours[posid]
145 storearraypos = self.mapping.getCellIdToStoreArray(neighbourid)
147 if storearraypos >= 0:
148 energy = self.eclCalDigits[storearraypos].getEnergy()
149 energy_list.append(energy)
151 self.energy_array = np.array(energy_list)
153 prob_CNN_dict = self.extract_cnn_value(model)
154 prob_CNN_muon = prob_CNN_dict[
'cnn_muon']
156 eclCnnMuon = self.eclCnnMuon.appendNew()
157 eclCnnMuon.setEclCnnMuon(prob_CNN_muon)
158 track.addRelationTo(eclCnnMuon)
160 b2.B2DEBUG(22, f
'{variable_muon}: {prob_CNN_muon}')
162 b2.B2DEBUG(22,
'Track is either outside ECL Barrel or Pt is outside [0.2, 1.0] GeV/c. No CNN value.')
164 def getExtCell(self, track):
165 """ Extract cellId and pt of an extrapolated track
167 The output is dictionary which has cellId
and pt.
170 myDetID = self.Belle2.Const.EDetector.ECL
171 hypothesis = self.Belle2.Const.pion
172 pdgCode = abs(hypothesis.getPDGCode())
174 extHits = track.getRelationsTo('ExtHits')
176 for extHit
in extHits:
177 if abs(extHit.getPdgCode()) != pdgCode:
179 if extHit.getDetectorID() != myDetID:
181 if extHit.getStatus() != self.Belle2.ExtHitStatus.EXT_EXIT:
183 if extHit.isBackwardPropagated():
185 copyid = extHit.getCopyID()
190 px = extHit.getMomentum().X()
191 py = extHit.getMomentum().Y()
192 pt = math.sqrt(px**2 + py**2)
def extract_cnn_value(self, model):
    """ Extract CNN values for an extrapolated track

    The images produced by ``self.prepare_images()`` are passed through
    ``model`` with gradient tracking disabled, and a softmax over
    dimension 1 turns the model output into probabilities.

    Parameters:
        model: callable invoked as ``model(energy, theta, phi, pt)``;
            column 0 of its output is read as the pion score and
            column 1 as the muon score.

    Returns:
        Dictionary which includes two probabilities:

        cnn_pion: Probability of an extrapolated track being pion
        cnn_muon: Probability of an extrapolated track being muon

        NOTE: cnn_pion and cnn_muon are floats.
    """
    # NOTE(review): this block was restored from a damaged listing. The
    # dictionary name, its braces and the return statement were not
    # visible; they follow the docstring contract and the caller, which
    # indexes the result with 'cnn_muon'. Confirm against the pristine file.
    test_loader = self.prepare_images()

    with torch.no_grad():
        for energy, theta, phi, pt in test_loader:
            # Inputs must live on the same device as the model.
            energy = energy.to(self.device)
            theta = theta.to(self.device)
            phi = phi.to(self.device)
            pt = pt.to(self.device)

            output = model(energy, theta, phi, pt)
            output = F.softmax(output, dim=1)

            # Only the first row of the output is read.
            prob_CNN_pion = output[0][0].item()
            prob_CNN_muon = output[0][1].item()

            prob_CNN_dict = {
                'cnn_pion': prob_CNN_pion,
                'cnn_muon': prob_CNN_muon,
            }

    # NOTE(review): if the loader yields nothing, prob_CNN_dict is unbound
    # here; the original's behaviour for that case is not visible.
    return prob_CNN_dict
237 def prepare_images(self):
240 A dictionary is passed to a function (ClusterImage)
241 in order to prepare proper format
for CNN inputs.
245 'image_length': self.image_length,
246 'energy_image': self.energy_array,
247 'thetaId': self.extThetaId,
248 'phiId': self.extPhiId,
250 'threshold': self.threshold
253 dataset = ClusterImage(params_image)
254 infer_loader = DataLoader(dataset, shuffle=
False)
258 def model_cnn_name(self):
259 """ Create CNN model name
261 The outputs of this function are:
263 2. Models parameters which is important
for
264 initializing ConvNet.
266 Models parameters should be exactly the same
as the CNN model
270 model_name = f'ECLCNNPID_charge_{self.charge}'
273 'input_shape': (1, self.image_length, self.image_length),
274 'initial_filters': 64,
286 return(model_name, params_model)
288 def read_model(self):
289 """ Load the CNN model
291 This function receives model's name and
292 CNN parameters, then reads .pt file which
293 includes weights and biases
for CNN.
296 model_name, params_model = self.model_cnn_name()
297 model = ConvNet(params_model)
298 model = model.to(self.device)
300 payload = f'{model_name}.pt'
301 accessor = self.Belle2.DBAccessorBase(
302 self.Belle2.DBStoreEntry.c_RawFile, payload,
True)
303 checkpoint = accessor.getFilename()
305 model.load_state_dict(torch.load(checkpoint))