Belle II Software  release-08-01-10
nntd.py
1 import basf2
2 from ROOT import Belle2
3 from ROOT import TVector3
4 from ROOT.Math import XYZVector
5 import numpy as np
6 import pickle
7 import os
8 
9 
class nntd(basf2.Module):
    '''
    This class represents a dataset.

    As a basf2 module it collects, for every event, one row per track with
    the reconstructed (reco), neurotrigger (neuro, hwneuro, swneuro) and
    2D finder (twod, swtwod) track parameters listed in ``varnum``. Every
    event is stored as a list of exactly ``maxtracks`` rows (truncated, or
    padded with empty rows) and the collected ``eventlist`` is pickled in
    ``terminate``. ``load``/``loadmore`` read such pickle files back and
    convert them into the numpy array ``data``.

    Configuration is passed via ``param({...})`` before processing:
    ``filename`` and ``netname`` are used by ``save``; the store array and
    dataset names set in ``initialize`` can be overridden the same way.
    '''
    version = 2  # changes, when form of self.array changes
    maxtracks = 100  # max number of tracks per event to be stored
    # how the reco values are obtained in event(): 'old' uses the pion
    # TrackFitResult of the related Track, anything else uses the fitted
    # RecoTrack state closest to the origin (see getrecovals)
    recomethod = 'old'
    # dict to store the content for each entry in a track vector:
    # key -> [column index in a track row, plot label, unit label]
    varnum = {
        "recoz": [0, r'$Z_{Reco}$', r'$[cm]$'],
        "recotheta": [1, r'$\theta_{Reco}$', r'$[°]$'],
        "recophi": [2, r'$\phi_{Reco}$', r'$[°]$'],
        "recopt": [3, r'$P_{t, Reco}$', r'$[GeV]$'],
        "recop": [4, r'$P_{Reco}$', r'$[GeV]$'],
        "neuroz": [5, r'$Z_{Neuro}$', r'$[cm]$'],
        "neurotheta": [6, r'$\theta_{Neuro}$', r'$[°]$'],
        "neurophi": [7, r'$\phi_{Neuro}$', r'$[°]$'],
        # labels of neuropt/neurop were swapped; now consistent with reco/hw/sw
        "neuropt": [8, r'$P_{t, Neuro}$', r'$[GeV]$'],
        "neurop": [9, r'$P_{Neuro}$', r'$[GeV]$'],
        "neuroval": [10, 'Validity', ''],
        "neuroqual": [11, 'Quality', ''],
        "neurots": [12, 'TSVector', ''],
        "neuroexp": [13, 'Expert Number', ''],
        "neurodriftth": [14, 'Driftthreshold', ''],
        "neuroquad": [15, 'Quadrant', ''],
        "neurofp": [16, 'Fastestpriority Eventtime', 'clocks'],
        "neuroetf": [17, 'ETF Eventtime', 'clocks'],
        "twodphi": [18, r'$\phi_{2D}$', r'$[°]$'],
        "twodpt": [19, r'$P_{t, 2D}$', r'$[GeV]$'],
        "twodfot": [20, 'FoundOldTrack', ''],
        "hwneuroz": [21, r'$Z_{HWNeuro}$', r'$[cm]$'],
        "hwneurotheta": [22, r'$\theta_{HWNeuro}$', r'$[°]$'],
        "hwneurophi": [23, r'$\phi_{HWNeuro}$', r'$[°]$'],
        "hwneuropt": [24, r'$P_{t, HWNeuro}$', r'$[GeV]$'],
        "hwneurop": [25, r'$P_{HWNeuro}$', r'$[GeV]$'],
        "hwneuroval": [26, 'Validity', ''],
        "hwneuroqual": [27, 'Quality', ''],
        "hwneurots": [28, 'TSVector', ''],
        "hwneuroexp": [29, 'Expert Number', ''],
        "hwneurodriftth": [30, 'Driftthreshold', ''],
        "hwneuroquad": [31, 'Quadrant', ''],
        "hwneurofp": [32, 'Fastestpriority Eventtime', 'clocks'],
        "hwneuroetf": [33, 'ETF Eventtime', 'clocks'],
        "swneuroz": [34, r'$Z_{SWNeuro}$', r'$[cm]$'],
        "swneurotheta": [35, r'$\theta_{SWNeuro}$', r'$[°]$'],
        "swneurophi": [36, r'$\phi_{SWNeuro}$', r'$[°]$'],
        "swneuropt": [37, r'$P_{t, SWNeuro}$', r'$[GeV]$'],
        "swneurop": [38, r'$P_{SWNeuro}$', r'$[GeV]$'],
        "swneuroval": [39, 'Validity', ''],
        "swneuroqual": [40, 'Quality', ''],
        "swneurots": [41, 'TSVector', ''],
        "swneuroexp": [42, 'Expert Number', ''],
        "swneurodriftth": [43, 'Driftthreshold', ''],
        "swneuroquad": [44, 'Quadrant', ''],
        "swneurofp": [45, 'Fastestpriority Eventtime', 'clocks'],
        "swneuroetf": [46, 'ETF Eventtime', 'clocks'],
        "swtwodphi": [47, r'$\phi_{SW2D}$', r'$[°]$'],
        "swtwodpt": [48, r'$P_{t, SW2D}$', r'$[GeV]$'],
        "swtwodfot": [49, 'FoundOldTrack', ''],
        "neuroats": [50, 'NumberOfAxials', ''],
        "hwneuroats": [51, 'NumberOfAxials', ''],
        "swneuroats": [52, 'NumberOfAxials', ''],
        "neuroetfcc": [53, 'ETF Eventtime from CC', 'clocks'],
        "neurohwtime": [54, 'Reconstructed HW Eventtime', 'clocks'],
        "hwneuroetfcc": [55, 'ETF Eventtime', 'clocks'],
        "hwneurohwtime": [56, 'Reconstructed HW Eventtime', 'clocks'],
    }
    # one empty track row; a copy is used for every new or padding row
    # (built without a class-level for loop, which leaked a stray `x` attribute)
    nonelist = [None] * len(varnum)

    # default names used by initialize() unless already set via param()
    _namedefaults = {
        "networkname": "unspecified net",
        "dataname": "unspecified runs",
        "recotracksname": "RecoTracks",
        "neurotracksname": "TSimNeuroTracks",  # alternative: "TRGCDCNeuroTracks"
        "hwneurotracksname": "CDCTriggerNeuroTracks",
        "swneurotracksname": "TRGCDCNeuroTracks",
        "twodtracksname": "CDCTriggerNNInput2DFinderTracks",  # alternative: "TRGCDC2DFinderTracks"
        "swtwodtracksname": "TRGCDC2DFinderTracks",
        "etfname": "CDCTriggerNeuroETFT0",
        "tsname": "CDCTriggerNNInputSegmentHits",
    }

    def param(self, params):
        '''
        Set every key/value pair of the dict ``params`` as an attribute of
        this module (e.g. ``filename``, ``netname`` or any of the store array
        names). Call it before processing so initialize() keeps the values.
        '''
        for key, value in params.items():
            setattr(self, key, value)

    @staticmethod
    def _optionalstore(storetype, name):
        '''Return ``storetype(name)``, or None if the constructor raises ValueError.'''
        try:
            return storetype(name)
        except ValueError:
            return None

    def initialize(self):
        '''
        Reset the collected data, fill in default names that were not given
        via param(), and attach the store arrays / objects that are read in
        event(). Optional inputs whose construction raises ValueError are
        stored as None and skipped later.
        '''
        # TODO: initialize plots and event filters
        self.data = None
        self.eventlist = []
        for key, default in self._namedefaults.items():
            # previously these were overwritten unconditionally, silently
            # discarding values passed via param()
            if not hasattr(self, key):
                setattr(self, key, default)

        # storearrays
        self.recotracks = Belle2.PyStoreArray(self.recotracksname)
        self.neurotracks = self._optionalstore(Belle2.PyStoreArray, self.neurotracksname)
        self.hwneurotracks = self._optionalstore(Belle2.PyStoreArray, self.hwneurotracksname)
        self.swneurotracks = self._optionalstore(Belle2.PyStoreArray, self.swneurotracksname)
        self.twodtracks = self._optionalstore(Belle2.PyStoreArray, self.twodtracksname)
        self.swtwodtracks = self._optionalstore(Belle2.PyStoreArray, self.swtwodtracksname)
        self.ts = self._optionalstore(Belle2.PyStoreArray, self.tsname)
        self.etf = self._optionalstore(Belle2.PyStoreObj, self.etfname)

        self.varnum = nntd.varnum
        self.debuglist = []

    def costotheta(self, x):
        '''
        Convert a cosine (or a list of cosines, recursively) into an angle in
        degrees. None is passed through as None. Values outside [-1, 1]
        (e.g. from floating point rounding) are clipped into the domain.
        '''
        if isinstance(x, list):
            return [self.costotheta(y) for y in x]
        # explicit None check: `not x` also rejected cos = 0 (theta = 90 deg)
        if x is None:
            return None
        return 180. / np.pi * np.arccos(np.clip(x, -1., 1.))

    def getrecovalsold(self, evlist, fitres):
        '''
        Fill the reco columns of the track row ``evlist`` from a
        TrackFitResult (if given) and return the row.

        NOTE(review): theta/phi come from ROOT's Theta()/Phi() (radians),
        while the neuro theta is stored in degrees and the labels say [°];
        confirm what downstream plotting expects.
        '''
        if fitres:
            mom = fitres.getMomentum()
            evlist[self.varnum["recoz"][0]] = fitres.getPosition().Z()
            evlist[self.varnum["recotheta"][0]] = mom.Theta()
            evlist[self.varnum["recophi"][0]] = mom.Phi()
            evlist[self.varnum["recopt"][0]] = fitres.getTransverseMomentum()
            evlist[self.varnum["recop"][0]] = mom.R()
        return evlist

    def getrecovals(self, evlist, state):
        '''
        Fill the reco columns of the track row ``evlist`` from a fitted
        track state (if given) and return the row. Same unit caveat as
        getrecovalsold (Theta()/Phi() are radians).
        '''
        if state:
            mom = state.getMom()
            evlist[self.varnum["recoz"][0]] = state.getPos().Z()
            evlist[self.varnum["recotheta"][0]] = mom.Theta()
            evlist[self.varnum["recophi"][0]] = mom.Phi()
            evlist[self.varnum["recopt"][0]] = mom.Pt()
            evlist[self.varnum["recop"][0]] = state.getMomMag()
        return evlist

    def getneurovals(self, evlist, neuro, status=""):
        '''
        Fill the neuro columns of the track row ``evlist`` from a neuro
        trigger track and return the row.

        ``status`` is the column prefix: "" (neuro), "hw" or "sw". The ETF
        from CC and reconstructed HW time columns only exist for "" and "hw".
        '''
        pre = status
        if not neuro:
            return evlist
        vn = self.varnum
        cot = neuro.getCotTheta()
        theta = self.costotheta(cot / np.sqrt(1 + cot**2))  # degrees

        evlist[vn[pre + "neuroz"][0]] = neuro.getZ0()
        evlist[vn[pre + "neurotheta"][0]] = theta
        evlist[vn[pre + "neurophi"][0]] = neuro.getPhi0()
        evlist[vn[pre + "neuropt"][0]] = neuro.getPt()
        # p = pt / sin(theta); theta is in degrees, np.sin expects radians
        evlist[vn[pre + "neurop"][0]] = neuro.getPt() / np.sin(np.deg2rad(theta))
        evlist[vn[pre + "neuroval"][0]] = neuro.getValidStereoBit()
        evlist[vn[pre + "neuroqual"][0]] = neuro.getQualityVector()
        tsvector = list(neuro.getTSVector())
        evlist[vn[pre + "neurots"][0]] = int("".join(str(x) for x in tsvector))
        # non-zero entries at even positions (0, 2, 4, ...) -> "NumberOfAxials"
        evlist[vn[pre + "neuroats"][0]] = sum(int(i != 0) for i in tsvector[::2])
        evlist[vn[pre + "neuroexp"][0]] = neuro.getExpert()
        evlist[vn[pre + "neurodriftth"][0]] = int("".join(str(int(x)) for x in neuro.getDriftThreshold()))
        evlist[vn[pre + "neuroquad"][0]] = neuro.getQuadrant()

        # fastest priority time of the related track segments (9999 if none)
        fpt = min([9999] + [ts.priorityTime() for ts in neuro.getRelationsTo(self.tsname)])
        if self.etf is not None and self.etf.hasBinnedEventT0(Belle2.Const.CDC):
            eft = self.etf.getBinnedEventT0(Belle2.Const.CDC)
        else:
            eft = None
        evlist[vn[pre + "neurofp"][0]] = fpt
        evlist[vn[pre + "neuroetf"][0]] = eft
        if pre != "sw":
            evlist[vn[pre + "neuroetfcc"][0]] = neuro.getETF_unpacked()
            evlist[vn[pre + "neurohwtime"][0]] = neuro.getETF_recalced()
        return evlist

    def _filltwod(self, evlist, twod, pre):
        '''Fill the (``pre``-prefixed) 2D phi0/pt columns of ``evlist`` and return it.'''
        if twod:
            evlist[self.varnum[pre + "twodphi"][0]] = twod.getPhi0()
            evlist[self.varnum[pre + "twodpt"][0]] = twod.getPt()
            # the FoundOldTrack column (pre + "twodfot") is currently not filled
        return evlist

    def gettwodvals(self, evlist, twod):
        '''Fill the 2D finder columns of the track row ``evlist`` and return it.'''
        return self._filltwod(evlist, twod, "")

    def getswtwodvals(self, evlist, twod):
        '''Fill the software 2D finder columns of the track row ``evlist`` and return it.'''
        return self._filltwod(evlist, twod, "sw")

    @staticmethod
    def _getrelated(obj, name, to=True):
        '''
        Return the object in store array ``name`` related to (``to=True``)
        or from (``to=False``) ``obj``. Returns None if the lookup raises,
        e.g. because ``obj`` is None or the array does not exist.
        '''
        try:
            return obj.getRelatedTo(name) if to else obj.getRelatedFrom(name)
        except Exception:
            return None

    @staticmethod
    def _iterarray(array):
        '''Iterate over a store array; a missing (None) array yields nothing.'''
        return array if array is not None else ()

    @staticmethod
    def _getclosestate(reco):
        '''
        Return the measured state closest to the origin of the last
        successfully fitted representation of ``reco`` for which the
        lookup succeeded, extrapolated to the z axis (None if there is none).
        '''
        state = None
        for rep in reco.getRepresentations():
            if not reco.wasFitSuccessful(rep):
                continue
            try:
                state = reco.getMeasuredStateOnPlaneClosestTo(XYZVector(0, 0, 0), rep)
                rep.extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000))
            except Exception:
                # keep whatever state was obtained so far
                pass
        return state

    def _addrecorows(self, event):
        '''
        Append one row per usable RecoTrack, combining the reco values with
        those of the related neuro, 2D, hw neuro and sw neuro tracks.
        '''
        for reco in self.recotracks:
            fitres = None
            state = None
            if self.recomethod == 'old':
                track = reco.getRelatedFrom("Tracks")
                if not track:
                    print("no track found for recotrack")
                    continue
                pionpdg = 211
                fitres = track.getTrackFitResultWithClosestMass(Belle2.Const.ChargedStable(pionpdg))
                if not fitres:
                    continue
            else:
                state = self._getclosestate(reco)
                if not state:
                    continue

            neuro = self._getrelated(reco, self.neurotracksname)
            swneuro = self._getrelated(reco, self.swneurotracksname)
            # the hw neuro track is found via the neuro track, not the reco track
            hwneuro = self._getrelated(neuro, self.hwneurotracksname, to=False)
            twod = self._getrelated(reco, self.twodtracksname)

            row = self.nonelist.copy()
            if self.recomethod == 'old':
                row = self.getrecovalsold(row, fitres)
            else:
                row = self.getrecovals(row, state)
            row = self.getneurovals(row, neuro)
            row = self.gettwodvals(row, twod)
            row = self.getneurovals(row, hwneuro, status="hw")
            row = self.getneurovals(row, swneuro, status="sw")
            event.append(row)

    def _addneurorows(self, event):
        '''Append a row for every neuro track that has no related RecoTrack.'''
        for neuro in self._iterarray(self.neurotracks):
            if len(neuro.getRelationsFrom(self.recotracksname)) > 0:
                # this track is already stored in a recoline
                continue
            # look up the 2D track of *this* neuro track (previously the stale
            # `reco` loop variable of the reco loop was used)
            twod = self._getrelated(neuro, self.twodtracksname)
            hwneuro = self._getrelated(neuro, self.hwneurotracksname, to=False)
            row = self.nonelist.copy()
            row = self.getneurovals(row, neuro)
            row = self.gettwodvals(row, twod)
            row = self.getneurovals(row, hwneuro, status="hw")
            event.append(row)

    def _addswneurorows(self, event):
        '''Append a row for every software neuro track that has no related RecoTrack.'''
        for swneuro in self._iterarray(self.swneurotracks):
            if len(swneuro.getRelationsFrom(self.recotracksname)) > 0:
                # this track is already stored in a recoline
                continue
            swtwod = self._getrelated(swneuro, self.swtwodtracksname)
            row = self.nonelist.copy()
            row = self.getneurovals(row, swneuro, status="sw")
            row = self.getswtwodvals(row, swtwod)
            event.append(row)

    def _addtwodrows(self, event):
        '''Append a row for every 2D track that has no related neuro track.'''
        for twod in self._iterarray(self.twodtracks):
            if len(twod.getRelationsFrom(self.neurotracksname)) > 0:
                # this track is already stored in a recoline or neuroline
                continue
            event.append(self.gettwodvals(self.nonelist.copy(), twod))

    def _addswtwodrows(self, event):
        '''Append a row for every software 2D track that has no related sw neuro track.'''
        for swtwod in self._iterarray(self.swtwodtracks):
            if len(swtwod.getRelationsFrom(self.swneurotracksname)) > 0:
                # this track is already stored in a recoline or neuroline
                continue
            event.append(self.getswtwodvals(self.nonelist.copy(), swtwod))

    def event(self):
        '''
        Collect the track rows of the current event and append them to
        ``eventlist`` as a list of exactly ``maxtracks`` rows.
        '''
        # TODO: update the plots every nth event
        event = []
        self._addrecorows(event)
        self._addneurorows(event)
        self._addswneurorows(event)
        self._addtwodrows(event)
        self._addswtwodrows(event)

        # truncate / pad so that every event has the same number of rows
        event = event[:self.maxtracks]
        event.extend(self.nonelist.copy() for _ in range(self.maxtracks - len(event)))
        self.eventlist.append(event)

    def terminate(self):
        '''Save the collected dataset (see save).'''
        # TODO: apply event filters, fill and save histograms
        self.save()

    def save(self, filename=None, netname=None, dataname=None):
        '''
        Pickle the dataset (eventlist, varnum, names and version) to a file.

        Missing arguments fall back to ``self.filename``, ``self.netname``
        (or ``self.networkname`` if no netname was set) and ``self.dataname``.

        Raises:
            ValueError: if no filename is given and none was set via param().
        '''
        if not filename:
            filename = getattr(self, "filename", None)
        if not filename:
            raise ValueError("nntd.save: no filename given; pass it to save() "
                             "or set it via param({'filename': ...})")
        if not netname:
            # 'netname' is normally set via param(); previously a missing one
            # raised AttributeError
            netname = getattr(self, "netname", getattr(self, "networkname", None))
        if not dataname:
            dataname = self.dataname
        savedict = {
            "eventlist": self.eventlist,
            "varnum": self.varnum,
            "networkname": netname,
            "dataname": dataname,
            "version": nntd.version,
        }
        with open(filename, 'wb') as f:
            pickle.dump(savedict, f)
        print('file ' + str(filename) + ' has been saved. ')

    @staticmethod
    def _readpickle(filename):
        '''
        Return the unpickled content of ``filename``.

        NOTE: pickle.load can execute arbitrary code; only load trusted files.
        '''
        with open(filename, 'rb') as f:
            return pickle.load(f)

    def _checkversion(self, savedict):
        '''Abort (exit status 1) if ``savedict`` was written by another nntd version.'''
        if self.version != savedict["version"]:
            print("Error! loaded file was made with different version of nntd! exiting ... ")
            # SystemExit instead of the interactive-only site.exit(); nonzero status on error
            raise SystemExit(1)

    def loadmore(self, filenames):
        '''
        Load several pickle files written by save() and append their events.

        At most ``NNTD_EVLIMIT`` (environment variable, default 50000) events
        are kept; if more are available only every n-th event is taken.
        Afterwards ``data`` is built from the event list.

        Raises:
            ValueError: if NNTD_EVLIMIT is not a positive integer.
        '''
        evlim = int(os.environ.get("NNTD_EVLIMIT", 50000))
        if evlim < 1:
            raise ValueError("NNTD_EVLIMIT must be a positive integer, got " + str(evlim))

        # first, count the available events to decide how many to skip
        evnumber = 0
        for i, x in enumerate(filenames):
            print("checking file: " + str(i + 1) + "/" + str(len(filenames)))
            evnumber += len(self._readpickle(x)["eventlist"])
        if evnumber > evlim:
            print("total number of available events is " + str(evnumber))
            skipev = evnumber // evlim
            print("Number of events more than " + str(evlim) + " only taking every " + str(skipev) + " event")
        else:
            skipev = 1

        if not hasattr(self, "eventlist"):
            # allow loading without a prior initialize()
            self.eventlist = []
        for x in filenames:
            savedict = self._readpickle(x)
            self._checkversion(savedict)
            self.networkname = savedict["networkname"]
            if "dataname" in savedict:
                self.dataname = savedict["dataname"]
            templim = evlim - len(self.eventlist)
            self.eventlist += savedict["eventlist"][::skipev][:templim]
            self.varnum = savedict["varnum"]
            print("Loaded file: " + str(x))
            print("length of eventlist: " + str(len(self.eventlist)))
            if evlim <= len(self.eventlist):
                print("stop loading, maximum event number reached")
                break
        self.makearray(self.eventlist)
        print("all files loaded, array.size: " + str(self.data.size) + ", array.shape: " + str(self.data.shape))

    def load(self, filename):
        '''
        Load a single pickle file written by save(), replacing the current
        dataset, and build ``data`` from it. Files without a ``dataname``
        entry keep the current (or default) dataset name, as in loadmore().
        '''
        savedict = self._readpickle(filename)
        self._checkversion(savedict)
        self.eventlist = savedict["eventlist"]
        self.varnum = savedict["varnum"]
        self.networkname = savedict["networkname"]
        self.dataname = savedict.get("dataname", getattr(self, "dataname", self._namedefaults["dataname"]))
        self.makearray(self.eventlist)

    def makearray(self, evlist):
        '''Convert the event list into the numpy array ``self.data``.'''
        # TODO: apply filters
        self.data = np.array(evlist)
Provides a type-safe way to pass members of the chargedStableSet set.
Definition: Const.h:580
A (simplified) python wrapper for StoreArray.
Definition: PyStoreArray.h:72
a (simplified) python wrapper for StoreObjPtr.
Definition: PyStoreObj.h:67
tsname
Definition: nntd.py:104
list nonelist
Definition: nntd.py:75
def getswtwodvals(self, evlist, twod)
Definition: nntd.py:221
dataname
Definition: nntd.py:92
dictionary varnum
Definition: nntd.py:17
def getneurovals(self, evlist, neuro, status="")
Definition: nntd.py:176
swtwodtracksname
Definition: nntd.py:102
neurotracksname
Definition: nntd.py:98
def getrecovals(self, evlist, state)
Definition: nntd.py:167
def save(self, filename=None, netname=None, dataname=None)
Definition: nntd.py:368
etfname
Definition: nntd.py:103
data
Definition: nntd.py:89
recotracksname
Definition: nntd.py:96
eventlist
Definition: nntd.py:90
swtwodtracks
Definition: nntd.py:125
networkname
Definition: nntd.py:91
neurotracks
Definition: nntd.py:109
swneurotracks
Definition: nntd.py:117
swneurotracksname
Definition: nntd.py:100
def makearray(self, evlist)
Definition: nntd.py:443
def costotheta(self, x)
Definition: nntd.py:143
int maxtracks
Definition: nntd.py:15
int version
Definition: nntd.py:14
def gettwodvals(self, evlist, twod)
Definition: nntd.py:214
twodtracksname
Definition: nntd.py:101
varnum
Definition: nntd.py:137
hwneurotracksname
Definition: nntd.py:99
debuglist
Definition: nntd.py:139
hwneurotracks
Definition: nntd.py:113
recotracks
Definition: nntd.py:107
def getrecovalsold(self, evlist, fitres)
Definition: nntd.py:158
twodtracks
Definition: nntd.py:121