Belle II Software development
nntd.py
1
8import basf2
9from ROOT import Belle2
10from ROOT import TVector3
11from ROOT.Math import XYZVector
12import numpy as np
13import pickle
14import os
15
16
class nntd(basf2.Module):
    '''
    basf2 module that collects, for every event, one track vector per track
    (reconstructed, simulated/hardware/software neuro and 2D finder tracks)
    into ``self.eventlist`` and pickles it in terminate().

    Each event is stored as exactly ``maxtracks`` track vectors; the index,
    axis label and unit of every entry of a track vector are given by
    ``varnum``. Entries that could not be filled stay None.
    '''
    version = 2  # bump whenever the layout of the stored track vectors changes
    maxtracks = 100  # max number of tracks per event to be stored
    # content of each entry in a track vector:
    # name -> [index in the track vector, axis label, unit label]
    varnum = {}
    varnum["recoz"] = [0, r'$Z_{Reco}$', r'$[cm]$']
    varnum["recotheta"] = [1, r'$\theta_{Reco}$', r'$[°]$']
    varnum["recophi"] = [2, r'$\phi_{Reco}$', r'$[°]$']
    varnum["recopt"] = [3, r'$P_{t, Reco}$', r'$[GeV]$']
    varnum["recop"] = [4, r'$P_{Reco}$', r'$[GeV]$']
    varnum["neuroz"] = [5, r'$Z_{Neuro}$', r'$[cm]$']
    varnum["neurotheta"] = [6, r'$\theta_{Neuro}$', r'$[°]$']
    varnum["neurophi"] = [7, r'$\phi_{Neuro}$', r'$[°]$']
    # labels of neuropt/neurop were swapped; now consistent with reco/hw/sw
    varnum["neuropt"] = [8, r'$P_{t, Neuro}$', r'$[GeV]$']
    varnum["neurop"] = [9, r'$P_{Neuro}$', r'$[GeV]$']
    varnum["neuroval"] = [10, r'Validity', '']
    varnum["neuroqual"] = [11, r'Quality', '']
    varnum["neurots"] = [12, r'TSVector', '']
    varnum["neuroexp"] = [13, r'Expert Number', '']
    varnum["neurodriftth"] = [14, r'Driftthreshold', '']
    varnum["neuroquad"] = [15, r'Quadrant', '']
    varnum["neurofp"] = [16, r'Fastestpriority Eventtime', 'clocks']
    varnum["neuroetf"] = [17, r'ETF Eventtime', 'clocks']
    varnum["twodphi"] = [18, r'$\phi_{2D}$', r'$[°]$']
    varnum["twodpt"] = [19, r'$P_{t, 2D}$', r'$[GeV]$']
    varnum["twodfot"] = [20, r'FoundOldTrack', '']
    varnum["hwneuroz"] = [21, r'$Z_{HWNeuro}$', r'$[cm]$']
    varnum["hwneurotheta"] = [22, r'$\theta_{HWNeuro}$', r'$[°]$']
    varnum["hwneurophi"] = [23, r'$\phi_{HWNeuro}$', r'$[°]$']
    varnum["hwneuropt"] = [24, r'$P_{t, HWNeuro}$', r'$[GeV]$']
    varnum["hwneurop"] = [25, r'$P_{HWNeuro}$', r'$[GeV]$']
    varnum["hwneuroval"] = [26, r'Validity', '']
    varnum["hwneuroqual"] = [27, r'Quality', '']
    varnum["hwneurots"] = [28, r'TSVector', '']
    varnum["hwneuroexp"] = [29, r'Expert Number', '']
    varnum["hwneurodriftth"] = [30, r'Driftthreshold', '']
    varnum["hwneuroquad"] = [31, r'Quadrant', '']
    varnum["hwneurofp"] = [32, r'Fastestpriority Eventtime', 'clocks']
    varnum["hwneuroetf"] = [33, r'ETF Eventtime', 'clocks']
    varnum["swneuroz"] = [34, r'$Z_{SWNeuro}$', r'$[cm]$']
    varnum["swneurotheta"] = [35, r'$\theta_{SWNeuro}$', r'$[°]$']
    varnum["swneurophi"] = [36, r'$\phi_{SWNeuro}$', r'$[°]$']
    varnum["swneuropt"] = [37, r'$P_{t, SWNeuro}$', r'$[GeV]$']
    varnum["swneurop"] = [38, r'$P_{SWNeuro}$', r'$[GeV]$']
    varnum["swneuroval"] = [39, r'Validity', '']
    varnum["swneuroqual"] = [40, r'Quality', '']
    varnum["swneurots"] = [41, r'TSVector', '']
    varnum["swneuroexp"] = [42, r'Expert Number', '']
    varnum["swneurodriftth"] = [43, r'Driftthreshold', '']
    varnum["swneuroquad"] = [44, r'Quadrant', '']
    varnum["swneurofp"] = [45, r'Fastestpriority Eventtime', 'clocks']
    varnum["swneuroetf"] = [46, r'ETF Eventtime', 'clocks']
    varnum["swtwodphi"] = [47, r'$\phi_{SW2D}$', r'$[°]$']
    varnum["swtwodpt"] = [48, r'$P_{t, SW2D}$', r'$[GeV]$']
    varnum["swtwodfot"] = [49, r'FoundOldTrack', '']
    varnum["neuroats"] = [50, r'NumberOfAxials', '']
    varnum["hwneuroats"] = [51, r'NumberOfAxials', '']
    varnum["swneuroats"] = [52, r'NumberOfAxials', '']
    varnum["neuroetfcc"] = [53, r'ETF Eventtime from CC', 'clocks']
    varnum["neurohwtime"] = [54, r'Reconstructed HW Eventtime', 'clocks']
    varnum["hwneuroetfcc"] = [55, r'ETF Eventtime from CC', 'clocks']
    varnum["hwneurohwtime"] = [56, r'Reconstructed HW Eventtime', 'clocks']
    # template for an empty track vector: one None per varnum entry
    # (built directly, so no loop variable leaks into the class namespace)
    nonelist = [None] * len(varnum)

    def param(self, params):
        '''
        Store every key/value pair of ``params`` as an attribute of this
        module, e.g. ``{"filename": "out.pkl", "netname": "mynet"}``.

        This overrides the inherited ``param`` of basf2.Module.
        '''
        for key, value in params.items():
            setattr(self, key, value)

    @staticmethod
    def _optional(wrapper, name):
        '''Return ``wrapper(name)``, or None if the wrapper raises ValueError.'''
        try:
            return wrapper(name)
        except ValueError:
            return None

    @staticmethod
    def _tryget(getter):
        '''Return ``getter()``, or None if it raises an Exception (best-effort relation lookup).'''
        try:
            return getter()
        except Exception:
            return None

    @staticmethod
    def _entries(storearray):
        '''Return ``storearray`` for iteration, or an empty tuple if it is None.'''
        return () if storearray is None else storearray

    def initialize(self):
        '''
        Reset the collected data, fill in default object names and attach to
        the DataStore objects read in event().

        Names already set through param() are kept; only missing ones get the
        defaults below. Optional StoreArrays/StoreObjs whose wrapper raises
        ValueError are set to None.
        '''
        # TODO:
        # check if folder is present or create it
        # initialize all plots somehow
        # initialize filters somehow, so they can be looped over in the event function
        # setup histograms
        # dict of plots, which should be plotted during the processing and updated every 5000 events.
        self.data = None
        self.eventlist = []
        defaults = {
            "networkname": "unspecified net",
            "dataname": "unspecified runs",
            "recotracksname": "RecoTracks",
            "neurotracksname": "TSimNeuroTracks",  # alternative: "TRGCDCNeuroTracks"
            "hwneurotracksname": "CDCTriggerNeuroTracks",
            "swneurotracksname": "TRGCDCNeuroTracks",
            "twodtracksname": "CDCTriggerNNInput2DFinderTracks",  # alternative: "TRGCDC2DFinderTracks"
            "swtwodtracksname": "TRGCDC2DFinderTracks",
            "etfname": "CDCTriggerNeuroETFT0",
            "tsname": "CDCTriggerNNInputSegmentHits",
        }
        for attr, default in defaults.items():
            # These used to be assigned unconditionally, silently overwriting
            # values passed via param().
            if not hasattr(self, attr):
                setattr(self, attr, default)

        # storearrays (recotracks is required; all others are optional)
        self.recotracks = Belle2.PyStoreArray(self.recotracksname)
        self.neurotracks = self._optional(Belle2.PyStoreArray, self.neurotracksname)
        self.hwneurotracks = self._optional(Belle2.PyStoreArray, self.hwneurotracksname)
        self.swneurotracks = self._optional(Belle2.PyStoreArray, self.swneurotracksname)
        self.twodtracks = self._optional(Belle2.PyStoreArray, self.twodtracksname)
        self.swtwodtracks = self._optional(Belle2.PyStoreArray, self.swtwodtracksname)
        self.ts = self._optional(Belle2.PyStoreArray, self.tsname)
        self.etf = self._optional(Belle2.PyStoreObj, self.etfname)

        self.varnum = nntd.varnum

        self.debuglist = []

    def costotheta(self, x):
        '''
        Convert cos(theta) to theta in degrees.

        :param x: a cosine value, None, or a (nested) list of them
        :return: theta in degrees with the same nesting as ``x``; None stays None.

        Values slightly outside [-1, 1] (floating point noise) are clipped,
        so arccos never returns NaN. cos(theta) == 0 correctly yields 90
        degrees (it used to return None because 0 is falsy).
        '''
        if isinstance(x, list):
            return [self.costotheta(y) for y in x]
        if x is None:
            return None
        return 180. / np.pi * np.arccos(np.clip(x, -1., 1.))

    def getrecovalsold(self, evlist, fitres):
        '''
        Fill the reco entries of ``evlist`` from a TrackFitResult (no-op if
        ``fitres`` is falsy) and return ``evlist``.
        '''
        if fitres:
            var = self.varnum
            mom = fitres.getMomentum()
            evlist[var["recoz"][0]] = fitres.getPosition().Z()
            # NOTE(review): Theta()/Phi() give radians while the labels say [°]
            # and neurotheta is stored in degrees -- confirm downstream handling.
            evlist[var["recotheta"][0]] = mom.Theta()
            evlist[var["recophi"][0]] = mom.Phi()
            evlist[var["recopt"][0]] = fitres.getTransverseMomentum()
            evlist[var["recop"][0]] = mom.R()
        return evlist

    def getrecovals(self, evlist, state):
        '''
        Fill the reco entries of ``evlist`` from a measured state on plane
        (no-op if ``state`` is falsy) and return ``evlist``.
        '''
        if state:
            var = self.varnum
            mom = state.getMom()
            evlist[var["recoz"][0]] = state.getPos().Z()
            # NOTE(review): radians, see getrecovalsold.
            evlist[var["recotheta"][0]] = mom.Theta()
            evlist[var["recophi"][0]] = mom.Phi()
            evlist[var["recopt"][0]] = mom.Pt()
            evlist[var["recop"][0]] = state.getMomMag()
        return evlist

    def getneurovals(self, evlist, neuro, status=""):
        '''
        Fill the neuro entries of ``evlist`` for a neuro track and return it.

        :param evlist: track vector (list indexed via ``self.varnum``)
        :param neuro: neuro track object, or a falsy value to leave evlist untouched
        :param status: varnum key prefix: "" (simulated), "hw" or "sw"
        '''
        if not neuro:
            return evlist
        pre = status
        var = self.varnum
        cot = neuro.getCotTheta()
        evlist[var[pre + "neuroz"][0]] = neuro.getZ0()
        evlist[var[pre + "neurotheta"][0]] = self.costotheta(cot / np.sqrt(1 + cot**2))
        evlist[var[pre + "neurophi"][0]] = neuro.getPhi0()
        evlist[var[pre + "neuropt"][0]] = neuro.getPt()
        # p = pt / sin(theta) with sin(theta) = 1 / sqrt(1 + cot(theta)^2).
        # The previous code passed theta in *degrees* to np.sin.
        evlist[var[pre + "neurop"][0]] = neuro.getPt() * np.sqrt(1 + cot**2)
        evlist[var[pre + "neuroval"][0]] = neuro.getValidStereoBit()
        evlist[var[pre + "neuroqual"][0]] = neuro.getQualityVector()
        tsvector = list(neuro.getTSVector())
        # digits of the TS vector concatenated into one integer
        evlist[var[pre + "neurots"][0]] = int("".join(str(x) for x in tsvector))
        # number of non-zero entries at even TS-vector positions ("NumberOfAxials")
        evlist[var[pre + "neuroats"][0]] = sum(int(i != 0) for i in tsvector[::2])
        evlist[var[pre + "neuroexp"][0]] = neuro.getExpert()
        evlist[var[pre + "neurodriftth"][0]] = int("".join(str(int(x)) for x in neuro.getDriftThreshold()))
        evlist[var[pre + "neuroquad"][0]] = neuro.getQuadrant()
        # fastest priority time among related track segments, capped at 9999
        fpt = min([9999] + [ts.priorityTime() for ts in neuro.getRelationsTo(self.tsname)])
        eft = None
        # self.etf is None when its wrapper could not be created in initialize()
        if self.etf is not None and self.etf.isValid() and self.etf.hasBinnedEventT0(Belle2.Const.CDC):
            eft = self.etf.getBinnedEventT0(Belle2.Const.CDC)
        evlist[var[pre + "neurofp"][0]] = fpt
        evlist[var[pre + "neuroetf"][0]] = eft
        if pre != "sw":
            # varnum has etfcc/hwtime columns only for the "" and "hw" prefixes
            evlist[var[pre + "neuroetfcc"][0]] = neuro.getETF_unpacked()
            evlist[var[pre + "neurohwtime"][0]] = neuro.getETF_recalced()
        return evlist

    def gettwodvals(self, evlist, twod):
        '''Fill the 2D finder entries of ``evlist`` (no-op if ``twod`` is falsy) and return it.'''
        if twod:
            evlist[self.varnum["twodphi"][0]] = twod.getPhi0()
            evlist[self.varnum["twodpt"][0]] = twod.getPt()
            # 'twodfot' (FoundOldTrack) is currently not filled
        return evlist

    def getswtwodvals(self, evlist, twod):
        '''Fill the software 2D finder entries of ``evlist`` (no-op if ``twod`` is falsy) and return it.'''
        if twod:
            evlist[self.varnum["swtwodphi"][0]] = twod.getPhi0()
            evlist[self.varnum["swtwodpt"][0]] = twod.getPt()
            # 'swtwodfot' (FoundOldTrack) is currently not filled
        return evlist

    def event(self):
        '''
        Build the track vectors of the current event and append them, padded
        or truncated to ``maxtracks`` vectors, to ``self.eventlist``.

        Rows are created for every RecoTrack with a usable fit (together with
        its related neuro/2D tracks), then for every neuro, sw neuro, 2D and
        sw 2D track that is not already covered through a relation.
        '''
        # TODO: update the plots every nth time
        # method should be either 'old' for the old method or anything else for the new one
        method = 'old'
        event = []
        for reco in self.recotracks:
            track = reco.getRelatedFrom("Tracks")
            fitres = None
            state = None
            if method == 'old':
                # old way: pion-hypothesis fit result of the related Track
                if not track:
                    print("no track found for recotrack")
                    continue
                whishPdg = 211  # pion
                fitres = track.getTrackFitResultWithClosestMass(Belle2.Const.ChargedStable(whishPdg))
                if not fitres:
                    continue
            else:
                # new way: state of the last successfully fitted representation
                for rep in reco.getRepresentations():
                    if not reco.wasFitSuccessful(rep):
                        continue
                    try:
                        state = reco.getMeasuredStateOnPlaneClosestTo(XYZVector(0, 0, 0), rep)
                        rep.extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000))
                    except Exception:
                        continue
                if not state:
                    continue

            neuro = self._tryget(lambda: reco.getRelatedTo(self.neurotracksname))
            swneuro = self._tryget(lambda: reco.getRelatedTo(self.swneurotracksname))
            hwneuro = self._tryget(lambda: neuro.getRelatedFrom(self.hwneurotracksname))
            twod = self._tryget(lambda: reco.getRelatedTo(self.twodtracksname))
            row = self.nonelist.copy()
            if method == 'old':
                row = self.getrecovalsold(row, fitres)
            else:
                row = self.getrecovals(row, state)
            row = self.getneurovals(row, neuro)
            row = self.gettwodvals(row, twod)
            row = self.getneurovals(row, hwneuro, status="hw")
            row = self.getneurovals(row, swneuro, status="sw")
            event.append(row)

        for neuro in self._entries(self.neurotracks):
            if len(neuro.getRelationsFrom(self.recotracksname)) > 0:
                # this track is already stored in a recoline
                continue
            # look up the 2D track of *this* neuro track (previously the stale
            # `reco` of the loop above was used)
            twod = self._tryget(lambda: neuro.getRelatedTo(self.twodtracksname))
            hwneuro = self._tryget(lambda: neuro.getRelatedFrom(self.hwneurotracksname))
            row = self.nonelist.copy()
            row = self.getneurovals(row, neuro)
            row = self.gettwodvals(row, twod)
            row = self.getneurovals(row, hwneuro, status="hw")
            event.append(row)

        for swneuro in self._entries(self.swneurotracks):
            if len(swneuro.getRelationsFrom(self.recotracksname)) > 0:
                # this track is already stored in a recoline
                continue
            swtwod = self._tryget(lambda: swneuro.getRelatedTo(self.swtwodtracksname))
            row = self.nonelist.copy()
            row = self.getneurovals(row, swneuro, status="sw")
            row = self.getswtwodvals(row, swtwod)
            event.append(row)

        for twod in self._entries(self.twodtracks):
            if len(twod.getRelationsFrom(self.neurotracksname)) > 0:
                # this track is already stored in a recoline or neuroline
                continue
            event.append(self.gettwodvals(self.nonelist.copy(), twod))

        for swtwod in self._entries(self.swtwodtracks):
            if len(swtwod.getRelationsFrom(self.swneurotracksname)) > 0:
                # this track is already stored in a recoline or swneuroline
                continue
            event.append(self.getswtwodvals(self.nonelist.copy(), swtwod))

        # every event is stored as exactly maxtracks track vectors
        del event[self.maxtracks:]
        event.extend(self.nonelist.copy() for _ in range(self.maxtracks - len(event)))
        self.eventlist.append(event)

    def terminate(self):
        '''Save the collected events (see save()).'''
        # TODO: apply event filters, convert eventlist to data array,
        # fill histograms and both save and show them
        self.save()

    def save(self, filename=None, netname=None, dataname=None):
        '''
        Pickle the collected events, ``varnum``, network/data names and the
        nntd version into ``filename``.

        Missing arguments default to ``self.filename``, ``self.netname``
        (falling back to ``self.networkname``) and ``self.dataname``.
        '''
        if not filename:
            filename = self.filename
        if not netname:
            # 'netname' is only present when passed via param(); it used to be
            # read unconditionally, raising AttributeError otherwise.
            netname = getattr(self, "netname", None) or self.networkname
        if not dataname:
            dataname = self.dataname
        savedict = {
            "eventlist": self.eventlist,
            "varnum": self.varnum,
            "networkname": netname,
            "dataname": dataname,
            "version": nntd.version,
        }
        with open(filename, 'wb') as f:
            pickle.dump(savedict, f)
        print('file ' + str(filename) + ' has been saved. ')

    def _checkversion(self, savedict):
        '''Exit with status 1 if ``savedict`` was written by a different nntd version.'''
        if self.version != savedict["version"]:
            print("Error! loaded file was made with different version of nntd! exiting ... ")
            # exit() is meant for the interactive interpreter and reported success
            raise SystemExit(1)

    def loadmore(self, filenames):
        '''
        Append the events of several saved files to ``self.eventlist``, at
        most NNTD_EVLIMIT (environment variable, default 50000) events in
        total, then rebuild ``self.data``.

        If more events are available than the limit, only every
        int(total / limit)-th event of each file is taken.

        Security: pickle.load can execute arbitrary code -- only load
        trusted files.
        '''
        evlim = int(os.environ.get("NNTD_EVLIMIT", 50000))
        # first, count the available events (files are closed after reading)
        evnumber = 0
        for i, x in enumerate(filenames):
            print("checking file: " + str(i) + "/" + str(len(filenames)))
            with open(x, 'rb') as f:
                evnumber += len(pickle.load(f)["eventlist"])
        if evnumber > evlim:
            print("total number of available events is " + str(evnumber))
            skipev = int(evnumber / evlim)
            print("Number of events more than " + str(evlim) + " only taking every " + str(skipev) + " event")
        else:
            skipev = 1
        for x in filenames:
            with open(x, 'rb') as f:
                savedict = pickle.load(f)
            self._checkversion(savedict)
            self.networkname = savedict["networkname"]
            if "dataname" in savedict:
                self.dataname = savedict["dataname"]
            templim = evlim - len(self.eventlist)
            self.eventlist += savedict["eventlist"][::skipev][:templim]
            self.varnum = savedict["varnum"]
            print("Loaded file: " + x)
            print("length of eventlist: " + str(len(self.eventlist)))
            if evlim <= len(self.eventlist):
                print("stop loading, maximum event number reached")
                break
        self.makearray(self.eventlist)
        print("all files loaded, array.size: " + str(self.data.size) + ", array.shape: " + str(self.data.shape))

    def load(self, filename):
        '''
        Replace the collected events with the content of a saved file and
        rebuild ``self.data``.

        Security: pickle.load can execute arbitrary code -- only load
        trusted files.
        '''
        with open(filename, 'rb') as f:
            savedict = pickle.load(f)
        self._checkversion(savedict)
        self.eventlist = savedict["eventlist"]
        self.varnum = savedict["varnum"]
        self.networkname = savedict["networkname"]
        # optional key, handled the same way as in loadmore()
        if "dataname" in savedict:
            self.dataname = savedict["dataname"]
        self.makearray(self.eventlist)

    def makearray(self, evlist):
        '''Convert ``evlist`` to a numpy array stored in ``self.data``.'''
        # TODO: apply filters
        self.data = np.array(evlist)
Provides a type-safe way to pass members of the chargedStableSet set.
Definition: Const.h:589
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
A (simplified) Python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
tsname
Definition: nntd.py:111
list nonelist
Definition: nntd.py:82
def getswtwodvals(self, evlist, twod)
Definition: nntd.py:228
dataname
Definition: nntd.py:99
def getneurovals(self, evlist, neuro, status="")
Definition: nntd.py:183
swtwodtracksname
Definition: nntd.py:109
neurotracksname
Definition: nntd.py:105
def getrecovals(self, evlist, state)
Definition: nntd.py:174
def save(self, filename=None, netname=None, dataname=None)
Definition: nntd.py:375
etfname
Definition: nntd.py:110
data
Definition: nntd.py:96
recotracksname
Definition: nntd.py:103
eventlist
Definition: nntd.py:97
swtwodtracks
Definition: nntd.py:132
networkname
Definition: nntd.py:98
neurotracks
Definition: nntd.py:116
swneurotracks
Definition: nntd.py:124
swneurotracksname
Definition: nntd.py:107
def makearray(self, evlist)
Definition: nntd.py:450
def costotheta(self, x)
Definition: nntd.py:150
int maxtracks
Definition: nntd.py:22
int version
Definition: nntd.py:21
def gettwodvals(self, evlist, twod)
Definition: nntd.py:221
twodtracksname
Definition: nntd.py:108
varnum
Definition: nntd.py:144
hwneurotracksname
Definition: nntd.py:106
debuglist
Definition: nntd.py:146
hwneurotracks
Definition: nntd.py:120
dict varnum
Definition: nntd.py:24
recotracks
Definition: nntd.py:114
def getrecovalsold(self, evlist, fitres)
Definition: nntd.py:165
twodtracks
Definition: nntd.py:128