Belle II Software  release-08-01-10
testVXDTFRelatedModules.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2 as b2
13 import time
14 
15 from VXDTF.setup_modules import (setup_gfTCtoSPTCConverters,
16  setup_spCreatorPXD,
17  setup_spCreatorSVD,
18  setup_sp2thConnector,
19  setup_qualityEstimators)
20 
21 from VXDTF.setup_modules_ml import add_fbdtclassifier_training, add_ml_threehitfilters
22 
# ---------------------------------------------------------------------------
# Steering configuration: plain flags and file names read further down.
# ---------------------------------------------------------------------------

# Input file handed to the RootInput module. Sample files that have been used
# with this script (kept for reference):
#   "seed4nEv100000pGun1_1T.root"
#   "seed14nEv100000pGun1_1T.root"    # phi 0-90, theta 60-85, pT 100-145 MeV/c, PDG 13
#   "seed13nEv100000pGun1_2T.root"    # phi 0-90, theta 60-85, pT 100-145 MeV/c, PDG 13
#   "evtGenseed6nEv100000.root"       # evtGenSinglePassTest-TrainSample. skipCluster = False
#   "evtGenseed5nEv10000.root"        # testSample. skipCluster = False
#   "evtGenseed7nEv10000.root"        # testSample. skipCluster = True
#   "evtGenseed8nEv200000.root"       # trainSample. skipCluster = True
#   "evtGenseed8nEv100000.root"       # evtGenSinglePassTest-TrainSample. skipCluster = True (TODO)
#   "seed11nEv100pGun1_1T.root"       # test- and TrainSample
#   "seed12nEv200pGun1_2T.root"       # test- and TrainSample 0-90° phi, 60-85° Theta
#   "seed12345nEv1000pGun1_20T.root"  # test- and TrainSample 0-90° phi, 60-85° Theta
#   "TestFile.root"
# Only one assignment is kept: an earlier value would be overwritten anyway.
rootInputFileName = "MyRootFile.root"

# Fit used by the quality estimators. Currently supported: 'random' and 'circleFit'.
fitType = 'circleFit'
# fitType = 'random'

# Currently supported: 'greedy' and 'hopfield'.
# NOTE(review): not read anywhere in the visible part of this script.
setFilterType = 'hopfield'

usePXD = False
useDisplay = False
newTrain = False  # if True, rawSecMap data is collected; if False, the new TF is executed
printNetworks = False  # if True, creates graphs for each network for each event
useOldTFinstead = False  # if True, the old VXDTF is used instead of the new one
# if True, the old VXDTF does not run its Hopfield part, which allows using the new modules for that
oldTFNoSubsetSelection = True
ignoreDeadTCs = True  # if True, the TrackFinderVXDAnalizer will not add dead TCs to the efficiencies
bypassCA = False  # if True, no CA will be used but the BasicPathFinder instead

activateSegNetAnalizer = False  # only needed when studying FastBDT behavior

doStrictSeeds = False  # if True, a smaller amount of TCs is created from the same segment tree
doNewSubsetSelection = True  # if True, the new subset selection will be executed
doVirtualIPRemovalb4Fit = True  # if False, the virtual IP will be removed after the fit
doEventSummary = True  # if True, the TrackFinderVXDAnalizer will produce event-wise results
switchFiltersOff = False  # if True, the SegmentNetworkProducer does not apply any filters of the sectorMap

trainFBDT = False  # with the current settings: collects samples but does not train a FastBDT!
useFBDT = False  # use the ML filter for creating the SegmentNetwork instead of the SectorMap filters

# File name into which the SegmentNetworkAnalyzer stores its output.
segNetAnaRFN = 'SegNetAnalyzer_SM_train.root'
# Sample file passed to add_fbdtclassifier_training.
fbdtSamplesFN = 'FBDTClassifier_samples_train_10k.dat'
# Classifier file passed to add_ml_threehitfilters.
fbdtFN = 'FBDTClassifier_1000_3.dat'

# Value passed to the SegmentNetworkProducer as its 'CreateNeworks' parameter.
cNetworks = 2 if useFBDT else 3
71 
# ---------------------------------------------------------------------------
# Random seed and logging verbosity
# ---------------------------------------------------------------------------

# The seed used to be parsed out of rootInputFileName (the number between
# "seed" and "nEv"); that code is disabled and the seed is fixed instead:
#   tempStringList = rootInputFileName.split('nEv', 1)
#   stringInitialValue = tempStringList[0].split("seed", 1)
#   print("found seed: " + stringInitialValue[1])
#   initialValue = int(stringInitialValue[1])
# An old warning claimed that numEvents had to be identical to the value named
# in rootInputFileName; a later remark (Martin) says this does not seem to be true.
initialValue = 0

# Global framework settings.
b2.set_log_level(b2.LogLevel.ERROR)
b2.set_random_seed(initialValue)

# (log level, debug level) pairs for the individual groups of modules.
trainerVXDTFLogLevel, trainerVXDTFDebugLevel = b2.LogLevel.INFO, 10
TFlogLevel, TFDebugLevel = b2.LogLevel.INFO, 1
CAlogLevel, CADebugLevel = b2.LogLevel.DEBUG, 1
AnalizerlogLevel, AnalizerDebugLevel = b2.LogLevel.INFO, 1
96 
# ---------------------------------------------------------------------------
# Pick the raw sector-map file(s) belonging to the chosen seed.
# ---------------------------------------------------------------------------

# seed -> (message printed to the console, list of raw sector-map files).
# Disabled alternatives from earlier studies:
#   5  (with background): ['lowTestRedesign_753986291.root']
#   6  : ['lowTestRedesign_1667035383.root']
#   14 (skipCluster-setting=True, 100k pGun events): ['lowTestRedesign_1054912153.root']
#   general: ['lowTestRedesign_202608818.root']
_secMapChoices = {
    2: ("chosen initialvalue 2! " + rootInputFileName,
        ['lowTestRedesign_1373026662.root']),
    0: ("chosen initialvalue 0! " + rootInputFileName,
        ['lowTestRedesign.root']),
    3: ("chosen initialvalue 3! " + rootInputFileName,
        ['lowTestRedesign_202608818.root']),
    4: ("chosen initialvalue 4!! " + rootInputFileName,
        ['lowTestRedesign_293660864.root']),
    5: ("chosen initialvalue 5! " + rootInputFileName,
        ['lowTestRedesign_1120112796.root']),
    6: ("chosen initialvalue 6! " + rootInputFileName,
        ['lowTestRedesign_1120112796.root']),
    7: ("chosen initialvalue 7! (skipCluster-setting=True) " + rootInputFileName,
        ['lowTestRedesign_1332084337.root']),
    8: ("chosen initialvalue 8! (skipCluster-setting=True): 200k evtGen events " + rootInputFileName,
        ['lowTestRedesign_1332084337.root']),
    11: ("chosen initialvalue 11! (skipCluster-setting=True): 100 pGun events " + rootInputFileName,
         ['lowTestRedesign_1017144726.root']),
    12: ("chosen initialvalue 12! (skipCluster-setting=True): 200 pGun events " + rootInputFileName,
         ['lowTestRedesign_1196763558.root']),
    13: ("chosen initialvalue 13! (skipCluster-setting=True): 100k pGun events " + rootInputFileName,
         ['lowTestRedesign_1874442389.root']),
    # The two entries below do not append the input file name to the message.
    57: ("chosen initialvalue 57! setup remark: train: 10k events, 10 tracks per event, theta 60-85°, phi 0-360°, pT 100-145MeV.",
         ['lowTestRedesign_779994078.root']),
    12345: ("chosen initialvalue 12345! some dummy setup!",
            ['lowTestRedesign_349397772.root']),
}

_choice = _secMapChoices.get(initialValue)
if _choice is not None:
    _message, acceptedRawSecMapFiles = _choice
    print(_message)
else:
    # Unknown seed: complain, continue with an empty file name and give the
    # user a minute to notice the message.
    print("ERROR! no valid initialvalue chosen!")
    acceptedRawSecMapFiles = [""]
    time.sleep(60)

# Short pause, framed by empty lines, so the message above can be read.
print('')
time.sleep(5)  # sleep for 5 seconds
print('')
152 
# ---------------------------------------------------------------------------
# Input, conditions and sector-map handling modules
# ---------------------------------------------------------------------------

rootInputM = b2.register_module('RootInput')
rootInputM.param('inputFileName', rootInputFileName)
# rootInputM.param('skipNEvents', int(10000))

eventinfoprinter = b2.register_module('EventInfoPrinter')

gearbox = b2.register_module('Gearbox')

secMapBootStrap = b2.register_module('SectorMapBootstrap')
for _key, _val in (('ReadSectorMap', False), ('WriteSectorMap', True)):
    secMapBootStrap.param(_key, _val)

if newTrain:
    # Training run: register the sector-map trainer.
    newSecMapTrainerBase = b2.register_module('SecMapTrainerBase')
    newSecMapTrainerBase.logging.log_level = trainerVXDTFLogLevel
    newSecMapTrainerBase.logging.debug_level = trainerVXDTFDebugLevel
    for _key, _val in (('spTCarrayName', 'checkedSPTCs'), ('allowTraining', True)):
        newSecMapTrainerBase.param(_key, _val)
else:
    # Execution run: register the merger for the previously chosen raw sector-map files.
    merger = b2.register_module('RawSecMapMerger')
    merger.logging.log_level = trainerVXDTFLogLevel
    merger.logging.debug_level = trainerVXDTFDebugLevel
    merger.param('rootFileNames', acceptedRawSecMapFiles)
    # merger.param('spTCarrayName', 'checkedSPTCs')

# Step size handed to the EventCounter: 100 when training or when the old
# track finder is used, otherwise 1.
evtStepSize = 100 if (newTrain or useOldTFinstead) else 1

geometry = b2.register_module('Geometry')
geometry.param(
    'components',
    ['BeamPipe', 'MagneticFieldConstant4LimitedRSVD', 'PXD', 'SVD'])

eventCounter = b2.register_module('EventCounter')
eventCounter.logging.log_level = b2.LogLevel.INFO
eventCounter.param('stepSize', evtStepSize)
194 
# ---------------------------------------------------------------------------
# Track finder: either the old VXDTF chain or the redesigned one
# ---------------------------------------------------------------------------

if useOldTFinstead:
    # Sector setups and cut-off tuning differ between the PXD+SVD and the SVD-only case.
    if usePXD:
        secSetup = [
            'shiftedL3IssueTestVXDStd-moreThan400MeV_PXDSVD',
            'shiftedL3IssueTestVXDStd-100to400MeV_PXDSVD',
            'shiftedL3IssueTestVXDStd-25to100MeV_PXDSVD',
        ]
        tuneValue = 0.22
    else:
        secSetup = [
            'shiftedL3IssueTestSVDStd-moreThan400MeV_SVD',
            'shiftedL3IssueTestSVDStd-100to400MeV_SVD',
            'shiftedL3IssueTestSVDStd-25to100MeV_SVD',
        ]
        tuneValue = 0.06

    vxdtf = b2.register_module('VXDTF')  # VXDTF TFRedesign
    vxdtf.logging.log_level = b2.LogLevel.DEBUG
    vxdtf.logging.debug_level = 1
    for _key, _val in (
            ('sectorSetup', secSetup),
            ('GFTrackCandidatesColName', 'caTracks'),
            ('tuneCutoffs', tuneValue),
            ('displayCollector', 2)):
        vxdtf.param(_key, _val)
    if oldTFNoSubsetSelection:
        vxdtf.param('filterOverlappingTCs', 'none')  # shall provide overlapping TCs
    # vxdtf.param('useTimeSeedAsQI', True)  # hack for storing QIs in TimeSeed-Variable for genfit::TrackCand

    oldAnalyzer = b2.register_module('TFAnalizer')
    oldAnalyzer.logging.log_level = b2.LogLevel.INFO
    for _key, _val in (
            ('printExtentialAnalysisData', False),
            ('caTCname', 'caTracks'),
            ('acceptedTCname', 'VXDTFoldAcceptedTCS'),
            ('lostTCname', 'VXDTFoldLostTCS')):
        oldAnalyzer.param(_key, _val)

    # TCConverter, genfit -> SPTC
    trackCandConverter = b2.register_module('GFTC2SPTCConverter')
    trackCandConverter.logging.log_level = b2.LogLevel.WARNING
    for _key, _val in (
            ('genfitTCName', 'caTracks'),
            ('SpacePointTCName', 'caSPTCs'),
            ('NoSingleClusterSVDSP', 'nosingleSP'),
            ('PXDClusterSP', 'pxdOnly'),
            ('checkNoSingleSVDSP', True),
            ('checkTrueHits', False),
            ('useSingleClusterSP', False),
            ('skipCluster', True)):
        trackCandConverter.param(_key, _val)
else:
    segNetProducer = b2.register_module('SegmentNetworkProducer')
    for _key, _val in (
            # NOTE(review): parameter name kept exactly as spelled in the
            # original script ('CreateNeworks'); confirm against the module.
            ('CreateNeworks', cNetworks),
            ('NetworkOutputName', 'test2Hits'),
            ('printNetworks', printNetworks),
            ('allFiltersOff', switchFiltersOff),
            # alternative: ('SpacePointsArrayNames', ['pxdOnly', 'nosingleSP'])
            ('SpacePointsArrayNames', ['nosingleSP_relTH'])):
        segNetProducer.param(_key, _val)
    segNetProducer.logging.log_level = TFlogLevel
    segNetProducer.logging.debug_level = TFDebugLevel

    if activateSegNetAnalizer:
        segNetAnalyzer = b2.register_module('SegmentNetworkAnalyzer')
        segNetAnalyzer.param('networkInputName', 'test2Hits')
        segNetAnalyzer.param('rootFileName', segNetAnaRFN)
        segNetAnalyzer.logging.log_level = b2.LogLevel.INFO
        segNetAnalyzer.logging.debug_level = 100

    # Either the basic path finder or the cellular automaton; both are
    # configured through the same set of parameters below.
    cellOmat = b2.register_module(
        'TrackFinderVXDBasicPathFinder' if bypassCA else 'TrackFinderVXDCellOMat')
    for _key, _val in (
            ('printNetworks', printNetworks),
            ('SpacePointTrackCandArrayName', 'caSPTCs'),
            ('NetworkName', 'test2Hits'),
            ('removeVirtualIP', False),
            ('strictSeeding', doStrictSeeds)):
        cellOmat.param(_key, _val)
    cellOmat.logging.log_level = CAlogLevel
    cellOmat.logging.debug_level = CADebugLevel
266 
267 
print("spot 10")

# Analyzer comparing the reference track candidates ('SPTracks') with the
# candidates under test ('caSPTCs').
vxdAnal = b2.register_module('TrackFinderVXDAnalizer')
for _key, _val in (
        ('referenceTCname', 'SPTracks'),
        ('testTCname', 'caSPTCs'),
        ('purityThreshold', 0.7),
        ('ignoreDeadTCs', ignoreDeadTCs),
        ('doEventSummary', doEventSummary)):
    vxdAnal.param(_key, _val)
vxdAnal.logging.log_level = AnalizerlogLevel
vxdAnal.logging.debug_level = AnalizerDebugLevel

print("spot 11")

# The log file name depends on the mode (training vs. execution) and on the seed.
_logPrefix = 'testRedesign' if newTrain else 'testsegNetExecute'
b2.log_to_file(_logPrefix + str(initialValue) + '.log', append=False)
# ---------------------------------------------------------------------------
# Build the path and process it
# ---------------------------------------------------------------------------
main = b2.create_path()

# Modules every configuration needs, in execution order.
for _module in (rootInputM, eventinfoprinter, gearbox, geometry, eventCounter, secMapBootStrap):
    main.add_module(_module)

setup_spCreatorSVD(
    path=main,
    nameOutput='nosingleSP',
    createSingleClusterSPs=False,
    logLevel=b2.LogLevel.INFO)
# needed since 2gftc-converter does not work without it
setup_spCreatorPXD(path=main, nameOutput='pxdOnly', logLevel=b2.LogLevel.INFO)
setup_gfTCtoSPTCConverters(
    path=main,
    pxdSPs='pxdOnly',
    svdSPs='nosingleSP',
    gfTCinput='mcTracks',
    sptcOutput='checkedSPTCs',
    usePXD=usePXD,
    logLevel=b2.LogLevel.WARNING)

vIPRemover = b2.register_module('SPTCvirtualIPRemover')
vIPRemover.param('maxTCLengthForVIPKeeping', 0)  # want to remove virtualIP for any track length
vIPRemover.param('tcArrayName', 'caSPTCs')

# connect all SpacePoints to all possible TrueHits and store them in a new
# StoreArray (to not interfere with the SpacePoints of the reference
# TrackCands)
setup_sp2thConnector(main, 'pxdOnly', 'nosingleSP', '_relTH', True, b2.LogLevel.ERROR, 1)

# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation; everything up to and including the analyzer is taken
# to belong to the non-training branch. Confirm against the original file.
if newTrain:
    main.add_module(newSecMapTrainerBase)
else:
    if useOldTFinstead:
        main.add_module(b2.register_module('SetupGenfitExtrapolation'))
        for _module in (vxdtf, oldAnalyzer, trackCandConverter):
            main.add_module(_module)
    else:
        main.add_module(merger)
        main.add_module(segNetProducer)
        if trainFBDT:  # collect in this step
            add_fbdtclassifier_training(
                main, 'test2Hits', 'FBDTClassifier.dat', False, True,
                False, fbdtSamplesFN, 100, 3, 0.15, 0.5, b2.LogLevel.DEBUG, 10)
        if useFBDT:  # apply the filters
            add_ml_threehitfilters(main, 'test2Hits', fbdtFN, 0.989351, True)
        if activateSegNetAnalizer:
            main.add_module(segNetAnalyzer)
        main.add_module(cellOmat)

    # The virtual IP is removed either before or after the quality estimation.
    if doVirtualIPRemovalb4Fit:
        main.add_module(vIPRemover)

    setup_qualityEstimators(main, fitType, 'caSPTCs', b2.LogLevel.INFO, 1)
    # setup_qualityEstimators(main, fitType, 'SPTracks', LogLevel.DEBUG, 1)

    if doVirtualIPRemovalb4Fit is False:
        main.add_module(vIPRemover)

    if doNewSubsetSelection:
        tcNetworkProducer = b2.register_module('SPTCNetworkProducer')
        tcNetworkProducer.param('tcArrayName', 'caSPTCs')
        tcNetworkProducer.param('tcNetworkName', 'tcNetwork')
        main.add_module(tcNetworkProducer)

        tsEvaluator = b2.register_module('TrackSetEvaluatorHopfieldNN')
        tsEvaluator.logging.log_level = b2.LogLevel.DEBUG
        tsEvaluator.logging.debug_level = 3
        tsEvaluator.param('tcArrayName', 'caSPTCs')
        tsEvaluator.param('tcNetworkName', 'tcNetwork')
        main.add_module(tsEvaluator)

        # NOTE(review): this module is registered and configured but never
        # added to the path in this script.
        svdOverlapResolver = b2.register_module('SVDOverlapResolver')
        svdOverlapResolver.param('NameSpacePointTrackCands', 'caSPTCs')
        svdOverlapResolver.param('resolveMethod', 'greedy')
        # svdOverlapResolver.param('resolveMethod', 'hopfield')
        svdOverlapResolver.logging.log_level = b2.LogLevel.DEBUG

    main.add_module(vxdAnal)

if useDisplay:
    display = b2.register_module('Display')
    display.param('showAllPrimaries', True)
    main.add_module(display)

# Process events
b2.process(main)

print(b2.statistics)