from VXDTF.setup_modules import (setup_gfTCtoSPTCConverters,
                                 setup_qualityEstimators)
from VXDTF.setup_modules_ml import add_fbdtclassifier_training, add_ml_threehitfilters
# --- Run configuration ------------------------------------------------------

# Input ROOT file with the events to process (passed to RootInput).
# Alternative sample: "seed12345nEv1000pGun1_20T.root".
rootInputFileName = "MyRootFile.root"

# NOTE(review): setFilterType and useOldTFinstead are not read anywhere in the
# visible part of this script -- confirm they are still needed.
setFilterType = 'hopfield'
useOldTFinstead = False

# When True, the legacy VXDTF gets 'filterOverlappingTCs' = 'none'.
oldTFNoSubsetSelection = True
# When True, a SegmentNetworkAnalyzer is registered and added to the path.
activateSegNetAnalizer = False
# When True, SPTCNetworkProducer and TrackSetEvaluatorHopfieldNN are configured.
doNewSubsetSelection = True
# Selects whether the virtual-IP remover is added before or after the
# quality-estimator fit.
doVirtualIPRemovalb4Fit = True
# Passed to SegmentNetworkProducer's 'allFiltersOff' parameter.
switchFiltersOff = False

# File names.
segNetAnaRFN = 'SegNetAnalyzer_SM_train.root'  # SegmentNetworkAnalyzer 'rootFileName'
fbdtSamplesFN = 'FBDTClassifier_samples_train_10k.dat'  # given to add_fbdtclassifier_training
fbdtFN = 'FBDTClassifier_1000_3.dat'  # given to add_ml_threehitfilters
# Global log threshold and the random seed for this run.
# NOTE(review): initialValue is defined outside the visible part of the script.
b2.set_log_level(b2.LogLevel.ERROR)
b2.set_random_seed(initialValue)

# Per-component log and debug levels used by the modules configured below.
trainerVXDTFLogLevel, trainerVXDTFDebugLevel = b2.LogLevel.INFO, 10
TFlogLevel = b2.LogLevel.INFO
CAlogLevel = b2.LogLevel.DEBUG
AnalizerlogLevel, AnalizerDebugLevel = b2.LogLevel.INFO, 1
# Choose the raw sector-map file(s) handed to RawSecMapMerger ('rootFileNames'),
# keyed by initialValue (the same value seeds the RNG).  Each entry holds the
# remark printed for that setup and the list of raw sector-map ROOT files.
rawSecMapSetups = {
    0: ("chosen initialvalue 0! " + rootInputFileName,
        ['lowTestRedesign.root']),
    2: ("chosen initialvalue 2! " + rootInputFileName,
        ['lowTestRedesign_1373026662.root']),
    3: ("chosen initialvalue 3! " + rootInputFileName,
        ['lowTestRedesign_202608818.root']),
    4: ("chosen initialvalue 4!! " + rootInputFileName,
        ['lowTestRedesign_293660864.root']),
    5: ("chosen initialvalue 5! " + rootInputFileName,
        ['lowTestRedesign_1120112796.root']),
    6: ("chosen initialvalue 6! " + rootInputFileName,
        ['lowTestRedesign_1120112796.root']),
    7: ("chosen initialvalue 7! (skipCluster-setting=True) " + rootInputFileName,
        ['lowTestRedesign_1332084337.root']),
    8: ("chosen initialvalue 8! (skipCluster-setting=True): 200k evtGen events "
        + rootInputFileName,
        ['lowTestRedesign_1332084337.root']),
    11: ("chosen initialvalue 11! (skipCluster-setting=True): 100 pGun events "
         + rootInputFileName,
         ['lowTestRedesign_1017144726.root']),
    12: ("chosen initialvalue 12! (skipCluster-setting=True): 200 pGun events "
         + rootInputFileName,
         ['lowTestRedesign_1196763558.root']),
    13: ("chosen initialvalue 13! (skipCluster-setting=True): 100k pGun events "
         + rootInputFileName,
         ['lowTestRedesign_1874442389.root']),
    57: ("chosen initialvalue 57! setup remark: train: 10k events, 10 tracks per event, "
         "theta 60-85°, phi 0-360°, pT 100-145MeV.",
         ['lowTestRedesign_779994078.root']),
    12345: ("chosen initialvalue 12345! some dummy setup!",
            ['lowTestRedesign_349397772.root']),
}

if initialValue in rawSecMapSetups:
    setupRemark, acceptedRawSecMapFiles = rawSecMapSetups[initialValue]
else:
    # Only an unknown initialValue may reach this fallback; it must not
    # overwrite a valid selection.
    setupRemark = "ERROR! no valid initialvalue chosen!"
    acceptedRawSecMapFiles = [""]
print(setupRemark)
# Event input from rootInputFileName.
rootInputM = b2.register_module('RootInput')
rootInputM.param('inputFileName', rootInputFileName)

eventinfoprinter = b2.register_module('EventInfoPrinter')
gearbox = b2.register_module('Gearbox')

# Sector-map bootstrap configured with ReadSectorMap=False, WriteSectorMap=True.
secMapBootStrap = b2.register_module('SectorMapBootstrap')
for flagName, flagValue in (('ReadSectorMap', False), ('WriteSectorMap', True)):
    secMapBootStrap.param(flagName, flagValue)
# Sector-map trainer (fed with the 'checkedSPTCs' array) and the merger that
# reads the raw sector-map files selected in acceptedRawSecMapFiles.
newSecMapTrainerBase = b2.register_module('SecMapTrainerBase')
merger = b2.register_module('RawSecMapMerger')

# Both trainer-side modules share the same log/debug levels.
for trainerModule in (newSecMapTrainerBase, merger):
    trainerModule.logging.log_level = trainerVXDTFLogLevel
    trainerModule.logging.debug_level = trainerVXDTFDebugLevel

for paramName, paramValue in (('spTCarrayName', 'checkedSPTCs'),
                              ('allowTraining', True)):
    newSecMapTrainerBase.param(paramName, paramValue)

merger.param('rootFileNames', acceptedRawSecMapFiles)
# Geometry module restricted to a list of detector components.
# NOTE(review): the 'components' list below is cut off in this view -- the
# remaining entries and the closing "])" are missing. Restore them from the
# full steering file; the visible text is left untouched.
186 geometry = b2.register_module(
'Geometry')
187 geometry.param(
'components', [
'BeamPipe',
'MagneticFieldConstant4LimitedRSVD',
# EventCounter logging at INFO with step size evtStepSize.
# NOTE(review): evtStepSize is defined outside the visible part of the script.
eventCounter = b2.register_module('EventCounter')
eventCounter.param('stepSize', evtStepSize)
eventCounter.logging.log_level = b2.LogLevel.INFO
# NOTE(review): these are fragments of two sector-setup name lists (SVD-only
# and PXD+SVD). Their assignment targets and opening lines are missing in this
# view. Presumably one of them becomes `secSetup`, which is passed to VXDTF's
# 'sectorSetup' parameter -- confirm against the full steering file.
198 'shiftedL3IssueTestSVDStd-moreThan400MeV_SVD',
199 'shiftedL3IssueTestSVDStd-100to400MeV_SVD',
200 'shiftedL3IssueTestSVDStd-25to100MeV_SVD']
203 [
'shiftedL3IssueTestVXDStd-moreThan400MeV_PXDSVD',
204 'shiftedL3IssueTestVXDStd-100to400MeV_PXDSVD',
205 'shiftedL3IssueTestVXDStd-25to100MeV_PXDSVD'
# Legacy VXDTF track finder; its genfit track-candidate collection is named
# 'caTracks'.  secSetup and tuneValue are defined outside this view.
vxdtf = b2.register_module('VXDTF')
vxdtf.logging.log_level = b2.LogLevel.DEBUG
vxdtf.logging.debug_level = 1
vxdtfParams = (
    ('sectorSetup', secSetup),
    ('GFTrackCandidatesColName', 'caTracks'),
    ('tuneCutoffs', tuneValue),
    ('displayCollector', 2),
)
for paramName, paramValue in vxdtfParams:
    vxdtf.param(paramName, paramValue)
if oldTFNoSubsetSelection:
    # Switch off the legacy finder's overlap filtering.
    vxdtf.param('filterOverlappingTCs', 'none')
# Analyzer for the legacy VXDTF output collection 'caTracks'.
oldAnalyzer = b2.register_module('TFAnalizer')
oldAnalyzer.logging.log_level = b2.LogLevel.INFO
oldAnalyzerParams = {
    'printExtentialAnalysisData': False,
    'caTCname': 'caTracks',
    'acceptedTCname': 'VXDTFoldAcceptedTCS',
    'lostTCname': 'VXDTFoldLostTCS',
}
for paramName, paramValue in oldAnalyzerParams.items():
    oldAnalyzer.param(paramName, paramValue)
# Converter from the genfit candidates in 'caTracks' to SpacePointTrackCands
# named 'caSPTCs', using the 'nosingleSP' (SVD) and 'pxdOnly' (PXD) space points.
# NOTE(review): TrackFinderVXDCellOMat is also configured with 'caSPTCs' as its
# SpacePointTrackCand array; confirm both producers are meant to run together.
trackCandConverter = b2.register_module('GFTC2SPTCConverter')
trackCandConverter.logging.log_level = b2.LogLevel.WARNING
converterParams = {
    'genfitTCName': 'caTracks',
    'SpacePointTCName': 'caSPTCs',
    'NoSingleClusterSVDSP': 'nosingleSP',
    'PXDClusterSP': 'pxdOnly',
    'checkNoSingleSVDSP': True,
    'checkTrueHits': False,
    'useSingleClusterSP': False,
    'skipCluster': True,
}
for paramName, paramValue in converterParams.items():
    trackCandConverter.param(paramName, paramValue)
# Segment network 'test2Hits' built from the 'nosingleSP_relTH' space points.
# NOTE(review): cNetworks, printNetworks and TFDebugLevel are defined outside
# the visible part of the script.
segNetProducer = b2.register_module('SegmentNetworkProducer')
segNetProducerParams = (
    ('CreateNeworks', cNetworks),
    ('NetworkOutputName', 'test2Hits'),
    ('printNetworks', printNetworks),
    ('allFiltersOff', switchFiltersOff),
    ('SpacePointsArrayNames', ['nosingleSP_relTH']),
)
for paramName, paramValue in segNetProducerParams:
    segNetProducer.param(paramName, paramValue)
segNetProducer.logging.log_level = TFlogLevel
segNetProducer.logging.debug_level = TFDebugLevel
if activateSegNetAnalizer:
    # Optional analyzer of the 'test2Hits' network, writing to segNetAnaRFN.
    segNetAnalyzer = b2.register_module('SegmentNetworkAnalyzer')
    for paramName, paramValue in (('networkInputName', 'test2Hits'),
                                  ('rootFileName', segNetAnaRFN)):
        segNetAnalyzer.param(paramName, paramValue)
    segNetAnalyzer.logging.log_level = b2.LogLevel.INFO
    segNetAnalyzer.logging.debug_level = 100
# Track finder on the 'test2Hits' network; its SpacePointTrackCand array is
# named 'caSPTCs'.  printNetworks, doStrictSeeds and CADebugLevel are defined
# outside the visible part of the script.
# NOTE(review): a 'TrackFinderVXDBasicPathFinder' module used to be registered
# into `cellOmat` here and was immediately overwritten by the registration
# below, so it never reached the path. It is an alternative path finder if a
# selector switch is (re)introduced.
cellOmat = b2.register_module('TrackFinderVXDCellOMat')
cellOmatParams = (
    ('printNetworks', printNetworks),
    ('SpacePointTrackCandArrayName', 'caSPTCs'),
    ('NetworkName', 'test2Hits'),
    ('removeVirtualIP', False),
    ('strictSeeding', doStrictSeeds),
)
for paramName, paramValue in cellOmatParams:
    cellOmat.param(paramName, paramValue)
cellOmat.logging.log_level = CAlogLevel
cellOmat.logging.debug_level = CADebugLevel
# Analyzer comparing the test candidates 'caSPTCs' with the reference 'SPTracks'.
# NOTE(review): nothing in the visible script produces 'SPTracks'; ignoreDeadTCs
# and doEventSummary are defined outside this view.
vxdAnal = b2.register_module('TrackFinderVXDAnalizer')
vxdAnalParams = {
    'referenceTCname': 'SPTracks',
    'testTCname': 'caSPTCs',
    'purityThreshold': 0.7,
    'ignoreDeadTCs': ignoreDeadTCs,
    'doEventSummary': doEventSummary,
}
for paramName, paramValue in vxdAnalParams.items():
    vxdAnal.param(paramName, paramValue)
vxdAnal.logging.log_level = AnalizerlogLevel
vxdAnal.logging.debug_level = AnalizerDebugLevel
# Log files named after initialValue, e.g. 'testsegNetExecute<initialValue>.log'.
# NOTE(review): two consecutive log_to_file calls; if log_to_file resets the
# existing log connections, only the second file receives output. A selector
# line between them may be missing in this view -- confirm.
logSuffix = str(initialValue) + '.log'
b2.log_to_file('testRedesign' + logSuffix, append=False)
b2.log_to_file('testsegNetExecute' + logSuffix, append=False)
main = b2.create_path()

# Input, bookkeeping, geometry and sector-map bootstrap, in this order.
for mod in (rootInputM, eventinfoprinter, gearbox, geometry, eventCounter,
            secMapBootStrap):
    main.add_module(mod)

# Space-point creation: SVD without single-cluster SPs ('nosingleSP') and
# PXD ('pxdOnly').  Both helpers are defined outside this view.
setup_spCreatorSVD(path=main, nameOutput='nosingleSP',
                   createSingleClusterSPs=False, logLevel=b2.LogLevel.INFO)
setup_spCreatorPXD(path=main, nameOutput='pxdOnly', logLevel=b2.LogLevel.INFO)
# Helper call that takes 'mcTracks' as gfTCinput and 'checkedSPTCs' as
# sptcOutput; 'checkedSPTCs' is the array SecMapTrainerBase is configured with.
# NOTE(review): this call is incomplete in this view -- several argument lines
# (presumably including the path argument) are missing between the opening
# parenthesis and gfTCinput. Restore them from the full steering file.
298 setup_gfTCtoSPTCConverters(
302 gfTCinput=
'mcTracks',
303 sptcOutput=
'checkedSPTCs',
305 logLevel=b2.LogLevel.WARNING)
# Virtual-IP remover acting on 'caSPTCs'.  It is added to the path later,
# before or after the quality-estimator fit depending on doVirtualIPRemovalb4Fit.
vIPRemover = b2.register_module('SPTCvirtualIPRemover')
for paramName, paramValue in (('maxTCLengthForVIPKeeping', 0),
                              ('tcArrayName', 'caSPTCs')):
    vIPRemover.param(paramName, paramValue)
# Connect 'pxdOnly'/'nosingleSP' space points to true hits with suffix '_relTH'
# (SegmentNetworkProducer reads 'nosingleSP_relTH').  Helper defined elsewhere.
setup_sp2thConnector(main, 'pxdOnly', 'nosingleSP', '_relTH', True,
                     b2.LogLevel.ERROR, 1)

# Training, legacy tracking, conversion, merging and segment-network building,
# added in this exact order.
genfitExtrapolationSetup = b2.register_module('SetupGenfitExtrapolation')
for mod in (newSecMapTrainerBase, genfitExtrapolationSetup, vxdtf, oldAnalyzer,
            trackCandConverter, merger, segNetProducer):
    main.add_module(mod)
# FBDT classifier training on the 'test2Hits' network, using the samples file
# fbdtSamplesFN, followed by the ML three-hit filters using fbdtFN.
# NOTE(review): training is given 'FBDTClassifier.dat' while the filters load
# fbdtFN ('FBDTClassifier_1000_3.dat') -- confirm this split is intended.
add_fbdtclassifier_training(
    main, 'test2Hits', 'FBDTClassifier.dat',
    False, True, False,
    fbdtSamplesFN, 100, 3, 0.15, 0.5,
    b2.LogLevel.DEBUG, 10,
)
add_ml_threehitfilters(main, 'test2Hits', fbdtFN, 0.989351, True)
if activateSegNetAnalizer:
    main.add_module(segNetAnalyzer)
main.add_module(cellOmat)

# The virtual-IP remover runs exactly once: before the quality-estimator fit
# when doVirtualIPRemovalb4Fit is truthy, otherwise after it.  (A truthiness
# test replaces the former `is False` identity check, which would have skipped
# the remover entirely for any falsy non-bool value.)  fitType is defined
# outside this view.
if doVirtualIPRemovalb4Fit:
    main.add_module(vIPRemover)
setup_qualityEstimators(main, fitType, 'caSPTCs', b2.LogLevel.INFO, 1)
if not doVirtualIPRemovalb4Fit:
    main.add_module(vIPRemover)
# Track-set subset selection on 'caSPTCs'.
# NOTE(review): indentation and several lines are missing in this view, so it
# cannot be determined here which statements belong to the
# `if doNewSubsetSelection:` branch. Presumably the network producer and the
# Hopfield evaluator form the new selection and SVDOverlapResolver is the
# alternative -- confirm against the full steering file.
344 if doNewSubsetSelection:
# SPTCNetworkProducer: tcArrayName 'caSPTCs', tcNetworkName 'tcNetwork'.
346 tcNetworkProducer = b2.register_module(
'SPTCNetworkProducer')
347 tcNetworkProducer.param(
'tcArrayName',
'caSPTCs')
348 tcNetworkProducer.param(
'tcNetworkName',
'tcNetwork')
349 main.add_module(tcNetworkProducer)
# Hopfield-network track-set evaluator on the same 'caSPTCs' / 'tcNetwork'.
351 tsEvaluator = b2.register_module(
'TrackSetEvaluatorHopfieldNN')
352 tsEvaluator.logging.log_level = b2.LogLevel.DEBUG
353 tsEvaluator.logging.debug_level = 3
354 tsEvaluator.param(
'tcArrayName',
'caSPTCs')
355 tsEvaluator.param(
'tcNetworkName',
'tcNetwork')
356 main.add_module(tsEvaluator)
# SVDOverlapResolver on 'caSPTCs' with resolveMethod 'greedy'.
# NOTE(review): svdOverlapResolver is configured but never added to `main` in
# this view; if it should run it needs main.add_module(svdOverlapResolver).
358 svdOverlapResolver = b2.register_module(
'SVDOverlapResolver')
359 svdOverlapResolver.param(
'NameSpacePointTrackCands',
'caSPTCs')
360 svdOverlapResolver.param(
'resolveMethod',
'greedy')
362 svdOverlapResolver.logging.log_level = b2.LogLevel.DEBUG
# Event display with 'showAllPrimaries' enabled, added after the analyzer.
display = b2.register_module('Display')
display.param('showAllPrimaries', True)
for mod in (vxdAnal, display):
    main.add_module(mod)