16 from basf2
import B2INFO, B2FATAL, B2WARNING
20 import modularAnalysis
as ma
22 from variables
import utils
29 def getBelleOrBelle2():
31 Gets the global ModeCode.
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Sets the interaction with the database: download trained weight files or upload weight files after training.

    Both switches are stored in module-level variables (``downloadFlag`` and
    ``uploadFlag``) so that the other functions of this module can read them.

    @param downloadFromDatabaseIfNotFound value stored in ``downloadFlag``
    @param uploadToDatabaseAfterTraining  value stored in ``uploadFlag``
    """
    # The flags live at module scope; without this declaration the
    # assignments below would only create function locals.
    global downloadFlag, uploadFlag

    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
# Variable names grouped into the 'flavor_tagging' collection: the combined
# outputs first, followed by the three per-category aliases (qp<Category>,
# hasTrueTarget<Category>, isRightCategory<Category>) in category order.
flavor_tagging = [
    'FBDT_qrCombined',
    'FANN_qrCombined',
    'qrMC',
    'mcFlavorOfOtherB',
    'qrGNN',
] + [
    prefix + category
    for category in (
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion',
    )
    for prefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
def add_default_FlavorTagger_aliases():
    """
    This function adds the default aliases for flavor tagging variables
    and defines the collection of flavor tagging variables.

    For every key of ``AvailableCategories`` three aliases are registered:
    ``qp<Category>``, ``hasTrueTarget<Category>`` and ``isRightCategory<Category>``.
    Finally the ``flavor_tagging`` list is registered as the collection
    ``'flavor_tagging'``.
    """
    # Aliases for the combined outputs.
    combined_output_aliases = (
        ('FBDT_qrCombined', 'qrOutput(FBDT)'),
        ('FANN_qrCombined', 'qrOutput(FANN)'),
        ('qrMC', 'isRelatedRestOfEventB0Flavor'),
        ('qrGNN', 'extraInfo(qrGNN)'),
    )
    for alias, expression in combined_output_aliases:
        variables.variables.addAlias(alias, expression)

    # (alias prefix, name of the variable taking the category as argument)
    per_category_aliases = (
        ('qp', 'qpCategory'),
        ('hasTrueTarget', 'hasTrueTargets'),
        ('isRightCategory', 'isTrueFTCategory'),
    )
    for category in AvailableCategories:
        for alias_prefix, variable_name in per_category_aliases:
            variables.variables.addAlias(alias_prefix + category,
                                         variable_name + '(' + category + ')')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
def set_FlavorTagger_pid_aliases():
    """
    This function adds the pid aliases needed by the flavor tagger.

    The per-detector aliases are identical for Belle and Belle II. The dE/dx
    aliases depend on the result of ``getBelleOrBelle2()``: for "Belle" they
    combine CDC and SVD and are wrapped in ``ifNANgiveX(..., 0.5)``, otherwise
    they use the CDC only and are not wrapped.
    """
    detector_aliases = (
        ('eid_TOP', 'pidPairProbabilityExpert(11, 211, TOP)'),
        ('eid_ARICH', 'pidPairProbabilityExpert(11, 211, ARICH)'),
        ('eid_ECL', 'pidPairProbabilityExpert(11, 211, ECL)'),
        ('muid_TOP', 'pidPairProbabilityExpert(13, 211, TOP)'),
        ('muid_ARICH', 'pidPairProbabilityExpert(13, 211, ARICH)'),
        ('muid_KLM', 'pidPairProbabilityExpert(13, 211, KLM)'),
        ('piid_TOP', 'pidPairProbabilityExpert(211, 321, TOP)'),
        ('piid_ARICH', 'pidPairProbabilityExpert(211, 321, ARICH)'),
        ('Kid_TOP', 'pidPairProbabilityExpert(321, 211, TOP)'),
        ('Kid_ARICH', 'pidPairProbabilityExpert(321, 211, ARICH)'),
    )
    for alias, expression in detector_aliases:
        variables.variables.addAlias(alias, expression)

    if getBelleOrBelle2() == "Belle":
        dedx_aliases = (
            ('eid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, CDC, SVD), 0.5)'),
            ('muid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, CDC, SVD), 0.5)'),
            ('piid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(211, 321, CDC, SVD), 0.5)'),
            ('pi_vs_edEdxid', 'ifNANgiveX(pidPairProbabilityExpert(211, 11, CDC, SVD), 0.5)'),
            ('Kid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(321, 211, CDC, SVD), 0.5)'),
        )
    else:
        dedx_aliases = (
            ('eid_dEdx', 'pidPairProbabilityExpert(11, 211, CDC)'),
            ('muid_dEdx', 'pidPairProbabilityExpert(13, 211, CDC)'),
            ('piid_dEdx', 'pidPairProbabilityExpert(211, 321, CDC)'),
            ('pi_vs_edEdxid', 'pidPairProbabilityExpert(211, 11, CDC)'),
            ('Kid_dEdx', 'pidPairProbabilityExpert(321, 211, CDC)'),
        )
    for alias, expression in dedx_aliases:
        variables.variables.addAlias(alias, expression)
def set_FlavorTagger_pid_aliases_legacy():
    """
    This function adds the pid aliases needed by the flavor tagger trained for MC13.

    Every alias is wrapped in ``ifNANgiveX(..., 0.5)``. The dE/dx aliases
    combine CDC and SVD when ``getBelleOrBelle2()`` returns "Belle" and use
    the CDC only otherwise.
    """
    detector_aliases = (
        ('eid_TOP', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, TOP), 0.5)'),
        ('eid_ARICH', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, ARICH), 0.5)'),
        ('eid_ECL', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, ECL), 0.5)'),
        ('muid_TOP', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, TOP), 0.5)'),
        ('muid_ARICH', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, ARICH), 0.5)'),
        ('muid_KLM', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, KLM), 0.5)'),
        ('piid_TOP', 'ifNANgiveX(pidPairProbabilityExpert(211, 321, TOP), 0.5)'),
        ('piid_ARICH', 'ifNANgiveX(pidPairProbabilityExpert(211, 321, ARICH), 0.5)'),
        ('Kid_TOP', 'ifNANgiveX(pidPairProbabilityExpert(321, 211, TOP), 0.5)'),
        ('Kid_ARICH', 'ifNANgiveX(pidPairProbabilityExpert(321, 211, ARICH), 0.5)'),
    )
    for alias, expression in detector_aliases:
        variables.variables.addAlias(alias, expression)

    if getBelleOrBelle2() == "Belle":
        dedx_aliases = (
            ('eid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, CDC, SVD), 0.5)'),
            ('muid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, CDC, SVD), 0.5)'),
            ('piid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(211, 321, CDC, SVD), 0.5)'),
            ('pi_vs_edEdxid', 'ifNANgiveX(pidPairProbabilityExpert(211, 11, CDC, SVD), 0.5)'),
            ('Kid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(321, 211, CDC, SVD), 0.5)'),
        )
    else:
        dedx_aliases = (
            ('eid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(11, 211, CDC), 0.5)'),
            ('muid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(13, 211, CDC), 0.5)'),
            ('piid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(211, 321, CDC), 0.5)'),
            ('pi_vs_edEdxid', 'ifNANgiveX(pidPairProbabilityExpert(211, 11, CDC), 0.5)'),
            ('Kid_dEdx', 'ifNANgiveX(pidPairProbabilityExpert(321, 211, CDC), 0.5)'),
        )
    for alias, expression in dedx_aliases:
        variables.variables.addAlias(alias, expression)
156 def set_GNNFlavorTagger_aliases(categories):
158 This function adds aliases for the GNN-based flavor tagger.
162 variables.variables.addAlias(
'qrCombined_bit',
'(qrCombined+1)/2')
163 alias_list = [
'qrCombined_bit']
178 'electronID_c':
'electronID*charge',
179 'muonID_c':
'muonID*charge',
180 'pionID_c':
'pionID*charge',
181 'kaonID_c':
'kaonID*charge',
182 'protonID_c':
'protonID*charge',
183 'deuteronID_c':
'deuteronID*charge',
184 'electronID_noSVD_noTOP_c':
'electronID_noSVD_noTOP*charge',
188 for rank
in range(1, 17):
190 for cat
in categories:
191 listName = AvailableCategories[cat].particleList
192 varName = f
'QpTrack({listName}, isRightCategory({cat}), isRightCategory({cat}))'
194 varWithRank = f
'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {varName}, {rank}), 0)'
195 aliasWithRank = f
'{cat}_rank{rank}'
197 variables.variables.addAlias(aliasWithRank, varWithRank)
198 alias_list.append(aliasWithRank)
200 for alias, var
in var_dict.items():
201 varWithRank = f
'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {var}, {rank}), 0)'
202 aliasWithRank = f
'{alias}_rank{rank}'
204 variables.variables.addAlias(aliasWithRank, varWithRank)
205 alias_list.append(aliasWithRank)
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for input variables with ROE mask.

    For each of ``pMissTag``, ``cosTPTO``, ``ptTracksRoe`` and ``pt2TracksRoe``
    an alias ``<variable>_withMask`` is registered which evaluates
    ``<variable>(<maskName>)``.

    @param maskName name of the ROE mask, inserted verbatim as the argument of
                    each variable (must be a string)
    """
    # Each alias is registered exactly once; 'ptTracksRoe_withMask' used to be
    # added a second time with the very same definition, which was redundant.
    for variable in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(variable + '_withMask', variable + '(' + maskName + ')')
def getFastBDTCategories():
    """
    Helper function for getting the FastBDT categories.
    It's necessary for removing top-level ROOT imports.

    @return a ``basf2_mva.FastBDTOptions`` object with the settings listed below
    """
    options = basf2_mva.FastBDTOptions()
    settings = (
        ('m_nTrees', 500),
        ('m_nCuts', 8),
        ('m_nLevels', 3),
        ('m_shrinkage', 0.10),
        ('m_randRatio', 0.5),
    )
    for attribute, value in settings:
        setattr(options, attribute, value)
    return options
def getFastBDTCombiner():
    """
    Helper function for getting the FastBDT combiner.
    It's necessary for removing top-level ROOT imports.

    @return a ``basf2_mva.FastBDTOptions`` object with the settings listed below
    """
    combiner_options = basf2_mva.FastBDTOptions()
    for attribute, value in {
        'm_nTrees': 500,
        'm_nCuts': 8,
        'm_nLevels': 3,
        'm_shrinkage': 0.10,
        'm_randRatio': 0.5,
    }.items():
        setattr(combiner_options, attribute, value)
    return combiner_options
def getMlpFANNCombiner():
    """
    Helper function for getting the MLP FANN combiner.
    It's necessary for removing top-level ROOT imports.

    @return a ``basf2_mva.FANNOptions`` object with the settings listed below
    """
    fann_options = basf2_mva.FANNOptions()
    # The attribute names (including the 'activiation' spelling) are used
    # exactly as in the original code and must not be "corrected" here.
    settings = (
        ('m_max_epochs', 10000),
        ('m_hidden_layers_architecture', "3*N"),
        ('m_hidden_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_output_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_error_function', "FANN_ERRORFUNC_LINEAR"),
        ('m_training_method', "FANN_TRAIN_RPROP"),
        ('m_validation_fraction', 0.5),
        ('m_random_seeds', 10),
        ('m_test_rate', 500),
        ('m_number_of_threads', 8),
        ('m_scale_features', True),
        ('m_scale_target', False),
    )
    for attribute, value in settings:
        setattr(fann_options, attribute, value)
    return fann_options
# Record type holding the parameters of one flavor tagger category.
# Fields, in positional order: particleList, trackName, eventName,
# variableName, code.
FTCategoryParameters = collections.namedtuple(
    'FTCategoryParameters',
    'particleList trackName eventName variableName code',
)
285 AvailableCategories = {
287 FTCategoryParameters(
'e+:inRoe',
'Electron',
'Electron',
288 'QpOf(e+:inRoe, isRightCategory(Electron), isRightCategory(Electron))',
290 'IntermediateElectron':
291 FTCategoryParameters(
'e+:inRoe',
'IntermediateElectron',
'IntermediateElectron',
292 'QpOf(e+:inRoe, isRightCategory(IntermediateElectron), isRightCategory(IntermediateElectron))',
295 FTCategoryParameters(
'mu+:inRoe',
'Muon',
'Muon',
296 'QpOf(mu+:inRoe, isRightCategory(Muon), isRightCategory(Muon))',
299 FTCategoryParameters(
'mu+:inRoe',
'IntermediateMuon',
'IntermediateMuon',
300 'QpOf(mu+:inRoe, isRightCategory(IntermediateMuon), isRightCategory(IntermediateMuon))',
303 FTCategoryParameters(
'mu+:inRoe',
'KinLepton',
'KinLepton',
304 'QpOf(mu+:inRoe, isRightCategory(KinLepton), isRightCategory(KinLepton))',
306 'IntermediateKinLepton':
307 FTCategoryParameters(
'mu+:inRoe',
'IntermediateKinLepton',
'IntermediateKinLepton',
308 'QpOf(mu+:inRoe, isRightCategory(IntermediateKinLepton), isRightCategory(IntermediateKinLepton))',
311 FTCategoryParameters(
'K+:inRoe',
'Kaon',
'Kaon',
312 'weightedQpOf(K+:inRoe, isRightCategory(Kaon), isRightCategory(Kaon))',
315 FTCategoryParameters(
'pi+:inRoe',
'SlowPion',
'SlowPion',
316 'QpOf(pi+:inRoe, isRightCategory(SlowPion), isRightCategory(SlowPion))',
319 FTCategoryParameters(
'pi+:inRoe',
'FastHadron',
'FastHadron',
320 'QpOf(pi+:inRoe, isRightCategory(FastHadron), isRightCategory(FastHadron))',
323 FTCategoryParameters(
'Lambda0:inRoe',
'Lambda',
'Lambda',
324 'weightedQpOf(Lambda0:inRoe, isRightCategory(Lambda), isRightCategory(Lambda))',
327 FTCategoryParameters(
'pi+:inRoe',
'SlowPion',
'FSC',
328 'QpOf(pi+:inRoe, isRightCategory(FSC), isRightCategory(SlowPion))',
331 FTCategoryParameters(
'pi+:inRoe',
'MaximumPstar',
'MaximumPstar',
332 'QpOf(pi+:inRoe, isRightCategory(MaximumPstar), isRightCategory(MaximumPstar))',
335 FTCategoryParameters(
'K+:inRoe',
'Kaon',
'KaonPion',
336 'QpOf(K+:inRoe, isRightCategory(KaonPion), isRightCategory(Kaon))',
341 def getTrainingVariables(category=None):
343 Helper function to get training variables.
345 NOTE: This function is not called the Expert mode. It is not necessary to be consistent with variables list of weight files.
348 KId = {
'Belle':
'ifNANgiveX(atcPIDBelle(3,2), 0.5)',
'Belle2':
'kaonID'}
349 muId = {
'Belle':
'muIDBelle',
'Belle2':
'muonID'}
350 eId = {
'Belle':
'eIDBelle',
'Belle2':
'electronID'}
353 if category ==
'Electron' or category ==
'IntermediateElectron':
354 variables = [
'useCMSFrame(p)',
359 eId[getBelleOrBelle2()],
363 'BtagToWBosonVariables(recoilMassSqrd)',
364 'BtagToWBosonVariables(pMissCMS)',
365 'BtagToWBosonVariables(cosThetaMissCMS)',
366 'BtagToWBosonVariables(EW90)',
370 if getBelleOrBelle2() ==
"Belle":
371 variables.append(
'eid_dEdx')
372 variables.append(
'ImpactXY')
373 variables.append(
'distance')
375 elif category ==
'Muon' or category ==
'IntermediateMuon':
376 variables = [
'useCMSFrame(p)',
381 muId[getBelleOrBelle2()],
385 'BtagToWBosonVariables(recoilMassSqrd)',
386 'BtagToWBosonVariables(pMissCMS)',
387 'BtagToWBosonVariables(cosThetaMissCMS)',
388 'BtagToWBosonVariables(EW90)',
391 if getBelleOrBelle2() ==
"Belle":
392 variables.append(
'muid_dEdx')
393 variables.append(
'ImpactXY')
394 variables.append(
'distance')
395 variables.append(
'chiProb')
397 elif category ==
'KinLepton' or category ==
'IntermediateKinLepton':
398 variables = [
'useCMSFrame(p)',
403 muId[getBelleOrBelle2()],
407 eId[getBelleOrBelle2()],
411 'BtagToWBosonVariables(recoilMassSqrd)',
412 'BtagToWBosonVariables(pMissCMS)',
413 'BtagToWBosonVariables(cosThetaMissCMS)',
414 'BtagToWBosonVariables(EW90)',
417 if getBelleOrBelle2() ==
"Belle":
418 variables.append(
'eid_dEdx')
419 variables.append(
'muid_dEdx')
420 variables.append(
'ImpactXY')
421 variables.append(
'distance')
422 variables.append(
'chiProb')
424 elif category ==
'Kaon':
425 variables = [
'useCMSFrame(p)',
430 KId[getBelleOrBelle2()],
434 'NumberOfKShortsInRoe',
436 'BtagToWBosonVariables(recoilMassSqrd)',
437 'BtagToWBosonVariables(pMissCMS)',
438 'BtagToWBosonVariables(cosThetaMissCMS)',
439 'BtagToWBosonVariables(EW90)',
443 if getBelleOrBelle2() ==
"Belle":
444 variables.append(
'ImpactXY')
445 variables.append(
'distance')
447 elif category ==
'SlowPion':
448 variables = [
'useCMSFrame(p)',
457 KId[getBelleOrBelle2()],
461 'NumberOfKShortsInRoe',
463 eId[getBelleOrBelle2()],
464 'BtagToWBosonVariables(recoilMassSqrd)',
465 'BtagToWBosonVariables(EW90)',
466 'BtagToWBosonVariables(cosThetaMissCMS)',
467 'BtagToWBosonVariables(pMissCMS)',
470 if getBelleOrBelle2() ==
"Belle":
471 variables.append(
'piid_dEdx')
472 variables.append(
'ImpactXY')
473 variables.append(
'distance')
474 variables.append(
'chiProb')
476 elif category ==
'FastHadron':
477 variables = [
'useCMSFrame(p)',
487 KId[getBelleOrBelle2()],
491 'NumberOfKShortsInRoe',
493 eId[getBelleOrBelle2()],
494 'BtagToWBosonVariables(recoilMassSqrd)',
495 'BtagToWBosonVariables(EW90)',
496 'BtagToWBosonVariables(cosThetaMissCMS)',
499 if getBelleOrBelle2() ==
"Belle":
500 variables.append(
'BtagToWBosonVariables(pMissCMS)')
501 variables.append(
'ImpactXY')
502 variables.append(
'distance')
503 variables.append(
'chiProb')
505 elif category ==
'Lambda':
506 variables = [
'lambdaFlavor',
507 'NumberOfKShortsInRoe',
509 'cosAngleBetweenMomentumAndVertexVector',
512 'daughter(0,useCMSFrame(p))',
514 'daughter(1,useCMSFrame(p))',
519 if getBelleOrBelle2() ==
"Belle2":
520 variables.append(
'daughter(1,protonID)')
521 variables.append(
'daughter(0,pionID)')
523 variables.append(
'distance')
525 elif category ==
'MaximumPstar':
526 variables = [
'useCMSFrame(p)',
532 if getBelleOrBelle2() ==
"Belle2":
533 variables.append(
'ImpactXY')
534 variables.append(
'distance')
536 elif category ==
'FSC':
537 variables = [
'useCMSFrame(p)',
539 KId[getBelleOrBelle2()],
540 'FSCVariables(pFastCMS)',
541 'FSCVariables(cosSlowFast)',
542 'FSCVariables(cosTPTOFast)',
543 'FSCVariables(SlowFastHaveOpositeCharges)',
545 elif category ==
'KaonPion':
546 variables = [
'extraInfo(isRightCategory(Kaon))',
547 'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
548 'KaonPionVariables(cosKaonPion)',
549 'KaonPionVariables(HaveOpositeCharges)',
550 KId[getBelleOrBelle2()]
556 def FillParticleLists(maskName='all', categories=None, path=None):
558 Fills the particle Lists for all categories.
561 from vertex
import kFit
562 readyParticleLists = []
564 if categories
is None:
567 trackCut =
'isInRestOfEvent > 0.5 and passesROEMask(' + maskName +
') > 0.5 and p >= 0'
569 for category
in categories:
570 particleList = AvailableCategories[category].particleList
572 if particleList
in readyParticleLists:
576 if particleList ==
'Lambda0:inRoe':
577 if 'pi+:inRoe' not in readyParticleLists:
578 ma.fillParticleList(
'pi+:inRoe', trackCut, path=path)
579 readyParticleLists.append(
'pi+:inRoe')
581 ma.fillParticleList(
'p+:inRoe', trackCut, path=path)
582 ma.reconstructDecay(particleList +
' -> pi-:inRoe p+:inRoe',
'1.00<=M<=1.23',
False, path=path)
583 kFit(particleList, 0.01, path=path)
584 ma.matchMCTruth(particleList, path=path)
585 readyParticleLists.append(particleList)
589 ma.fillParticleList(particleList, trackCut, path=path)
590 readyParticleLists.append(particleList)
593 if getBelleOrBelle2() ==
'Belle':
594 ma.cutAndCopyList(
'K_S0:inRoe',
'K_S0:mdst',
'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
596 if 'pi+:inRoe' not in readyParticleLists:
597 ma.fillParticleList(
'pi+:inRoe', trackCut, path=path)
598 ma.reconstructDecay(
'K_S0:inRoe -> pi+:inRoe pi-:inRoe',
'0.40<=M<=0.60',
False, path=path)
599 kFit(
'K_S0:inRoe', 0.01, path=path)
602 if getBelleOrBelle2() ==
'Belle2':
603 default_list_for_lid_BDT = [
'e+:inRoe',
'mu+:inRoe']
604 list_for_lid_BDT = []
606 for particleList
in default_list_for_lid_BDT:
607 if particleList
in readyParticleLists:
608 list_for_lid_BDT.append(particleList)
611 ma.applyChargedPidMVA(particleLists=list_for_lid_BDT, path=path,
613 binaryHypoPDGCodes=(11, 211))
614 ma.applyChargedPidMVA(particleLists=list_for_lid_BDT, path=path,
616 binaryHypoPDGCodes=(13, 211))
617 ma.applyChargedPidMVA(particleLists=list_for_lid_BDT, path=path,
621 def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, path=None):
623 Samples data for training or tests all categories all categories at event level.
626 from basf2
import create_path
627 from basf2
import register_module
629 B2INFO(
'EVENT LEVEL')
634 identifiersExtraInfosDict = dict()
635 identifiersExtraInfosKaonPion = []
637 if categories
is None:
640 for category
in categories:
641 particleList = AvailableCategories[category].particleList
643 methodPrefixEventLevel =
"FlavorTagger_" + getBelleOrBelle2() +
"_" + weightFiles +
'EventLevel' + category +
'FBDT'
644 identifierEventLevel = methodPrefixEventLevel
645 targetVariable =
'isRightCategory(' + category +
')'
646 extraInfoName = targetVariable
650 if downloadFlag
or useOnlyLocalFlag:
651 identifierEventLevel = filesDirectory +
'/' + methodPrefixEventLevel +
'_1.root'
654 if not os.path.isfile(identifierEventLevel):
655 basf2_mva.download(methodPrefixEventLevel, identifierEventLevel)
656 if not os.path.isfile(identifierEventLevel):
657 B2FATAL(
'Flavor Tagger: Weight file ' + identifierEventLevel +
658 ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')
661 if not os.path.isfile(identifierEventLevel):
662 B2FATAL(
'Flavor Tagger: ' + particleList +
' Eventlevel was not trained. Weight file ' +
663 identifierEventLevel +
' was not found. Stopped')
665 B2INFO(
'flavorTagger: MVAExpert ' + methodPrefixEventLevel +
' ready.')
667 elif mode ==
'Sampler':
669 identifierEventLevel = filesDirectory +
'/' + methodPrefixEventLevel +
'_1.root'
670 if os.path.isfile(identifierEventLevel):
671 B2INFO(
'flavorTagger: MVAExpert ' + methodPrefixEventLevel +
' ready.')
673 if 'KaonPion' in categories:
674 methodPrefixEventLevelKaonPion =
"FlavorTagger_" + getBelleOrBelle2() + \
675 "_" + weightFiles +
'EventLevelKaonPionFBDT'
676 identifierEventLevelKaonPion = filesDirectory +
'/' + methodPrefixEventLevelKaonPion +
'_1.root'
677 if not os.path.isfile(identifierEventLevelKaonPion):
680 if category !=
"SlowPion" and category !=
"Kaon":
683 if mode ==
'Expert' or (mode ==
'Sampler' and os.path.isfile(identifierEventLevel)):
685 B2INFO(
'flavorTagger: Applying MVAExpert ' + methodPrefixEventLevel +
'.')
687 if category ==
'KaonPion':
688 identifiersExtraInfosKaonPion.append((extraInfoName, identifierEventLevel))
689 elif particleList
not in identifiersExtraInfosDict:
690 identifiersExtraInfosDict[particleList] = [(extraInfoName, identifierEventLevel)]
692 identifiersExtraInfosDict[particleList].append((extraInfoName, identifierEventLevel))
697 for particleList
in identifiersExtraInfosDict:
698 eventLevelPath = create_path()
699 SkipEmptyParticleList = register_module(
"SkimFilter")
700 SkipEmptyParticleList.set_name(
'SkimFilter_EventLevel_' + particleList)
701 SkipEmptyParticleList.param(
'particleLists', particleList)
702 SkipEmptyParticleList.if_true(eventLevelPath, basf2.AfterConditionPath.CONTINUE)
703 path.add_module(SkipEmptyParticleList)
705 mvaMultipleExperts = register_module(
'MVAMultipleExperts')
706 mvaMultipleExperts.set_name(
'MVAMultipleExperts_EventLevel_' + particleList)
707 mvaMultipleExperts.param(
'listNames', [particleList])
708 mvaMultipleExperts.param(
'extraInfoNames', [row[0]
for row
in identifiersExtraInfosDict[particleList]])
709 mvaMultipleExperts.param(
'signalFraction', signalFraction)
710 mvaMultipleExperts.param(
'identifiers', [row[1]
for row
in identifiersExtraInfosDict[particleList]])
711 eventLevelPath.add_module(mvaMultipleExperts)
713 if 'KaonPion' in categories
and len(identifiersExtraInfosKaonPion) != 0:
714 eventLevelKaonPionPath = create_path()
715 SkipEmptyParticleList = register_module(
"SkimFilter")
716 SkipEmptyParticleList.set_name(
'SkimFilter_' +
'K+:inRoe')
717 SkipEmptyParticleList.param(
'particleLists',
'K+:inRoe')
718 SkipEmptyParticleList.if_true(eventLevelKaonPionPath, basf2.AfterConditionPath.CONTINUE)
719 path.add_module(SkipEmptyParticleList)
721 mvaExpertKaonPion = register_module(
"MVAExpert")
722 mvaExpertKaonPion.set_name(
'MVAExpert_KaonPion_' +
'K+:inRoe')
723 mvaExpertKaonPion.param(
'listNames', [
'K+:inRoe'])
724 mvaExpertKaonPion.param(
'extraInfoName', identifiersExtraInfosKaonPion[0][0])
725 mvaExpertKaonPion.param(
'signalFraction', signalFraction)
726 mvaExpertKaonPion.param(
'identifier', identifiersExtraInfosKaonPion[0][1])
728 eventLevelKaonPionPath.add_module(mvaExpertKaonPion)
730 if mode ==
'Sampler':
732 for category
in categories:
733 particleList = AvailableCategories[category].particleList
735 methodPrefixEventLevel =
"FlavorTagger_" + getBelleOrBelle2() +
"_" + weightFiles +
'EventLevel' + category +
'FBDT'
736 identifierEventLevel = filesDirectory +
'/' + methodPrefixEventLevel +
'_1.root'
737 targetVariable =
'isRightCategory(' + category +
')'
739 if not os.path.isfile(identifierEventLevel):
741 if category ==
'KaonPion':
742 methodPrefixEventLevelSlowPion =
"FlavorTagger_" + getBelleOrBelle2() + \
743 "_" + weightFiles +
'EventLevelSlowPionFBDT'
744 identifierEventLevelSlowPion = filesDirectory +
'/' + methodPrefixEventLevelSlowPion +
'_1.root'
745 if not os.path.isfile(identifierEventLevelSlowPion):
746 B2INFO(
"Flavor Tagger: event level weight file for the Slow Pion category is absent." +
747 "It is required to sample the training information for the KaonPion category." +
748 "An additional sampling step will be needed after the following training step.")
751 B2INFO(
'flavorTagger: file ' + filesDirectory +
'/' +
752 methodPrefixEventLevel +
"sampled" + fileId +
'.root will be saved.')
754 ma.applyCuts(particleList,
'isRightCategory(mcAssociated) > 0', path)
755 eventLevelpath = create_path()
756 SkipEmptyParticleList = register_module(
"SkimFilter")
757 SkipEmptyParticleList.set_name(
'SkimFilter_EventLevel' + category)
758 SkipEmptyParticleList.param(
'particleLists', particleList)
759 SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
760 path.add_module(SkipEmptyParticleList)
762 ntuple = register_module(
'VariablesToNtuple')
763 ntuple.param(
'fileName', filesDirectory +
'/' + methodPrefixEventLevel +
"sampled" + fileId +
".root")
764 ntuple.param(
'treeName', methodPrefixEventLevel +
"_tree")
765 variablesToBeSaved = getTrainingVariables(category) + [targetVariable,
'ancestorHasWhichFlavor',
766 'isSignal',
'mcPDG',
'mcErrors',
'genMotherPDG',
767 'nMCMatches',
'B0mcErrors']
768 if category !=
'KaonPion' and category !=
'FSC':
769 variablesToBeSaved = variablesToBeSaved + \
770 [
'extraInfo(isRightTrack(' + category +
'))',
771 'hasHighestProbInCat(' + particleList +
', isRightTrack(' + category +
'))']
772 ntuple.param(
'variables', variablesToBeSaved)
773 ntuple.param(
'particleList', particleList)
774 eventLevelpath.add_module(ntuple)
776 if ReadyMethods != len(categories):
782 def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None):
784 Trains all categories at event level.
787 B2INFO(
'EVENT LEVEL TEACHER')
791 if categories
is None:
794 for category
in categories:
795 methodPrefixEventLevel =
"FlavorTagger_" + getBelleOrBelle2() +
"_" + weightFiles +
'EventLevel' + category +
'FBDT'
796 targetVariable =
'isRightCategory(' + category +
')'
797 weightFile = filesDirectory +
'/' + methodPrefixEventLevel +
"_1.root"
799 if os.path.isfile(weightFile):
803 sampledFilesList = glob.glob(filesDirectory +
'/' + methodPrefixEventLevel +
'sampled*.root')
804 if len(sampledFilesList) == 0:
805 B2INFO(
'flavorTagger: eventLevelTeacher did not find any ' + methodPrefixEventLevel +
806 ".root" +
' file. Please run the flavorTagger in "Sampler" mode afterwards.')
809 B2INFO(
'flavorTagger: MVA Teacher training' + methodPrefixEventLevel +
' .')
810 trainingOptionsEventLevel = basf2_mva.GeneralOptions()
811 trainingOptionsEventLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
812 trainingOptionsEventLevel.m_treename = methodPrefixEventLevel +
"_tree"
813 trainingOptionsEventLevel.m_identifier = weightFile
814 trainingOptionsEventLevel.m_variables = basf2_mva.vector(*getTrainingVariables(category))
815 trainingOptionsEventLevel.m_target_variable = targetVariable
816 trainingOptionsEventLevel.m_max_events = maxEventsNumber
818 basf2_mva.teacher(trainingOptionsEventLevel, getFastBDTCategories())
821 basf2_mva.upload(weightFile, methodPrefixEventLevel)
823 if ReadyMethods != len(categories):
829 def combinerLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None,
830 variablesCombinerLevel=None, categoriesCombinationCode=None, path=None):
832 Samples the input data or tests the combiner according to the selected categories.
835 B2INFO(
'COMBINER LEVEL')
837 if categories
is None:
839 if variablesCombinerLevel
is None:
840 variablesCombinerLevel = []
842 B2INFO(
"Flavor Tagger: Required Combiner for Categories:")
843 for category
in categories:
846 B2INFO(
"Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)
848 methodPrefixCombinerLevel =
"FlavorTagger_" + getBelleOrBelle2() +
"_" + weightFiles +
'Combiner' \
849 + categoriesCombinationCode
851 if mode ==
'Sampler':
853 if os.path.isfile(filesDirectory +
'/' + methodPrefixCombinerLevel +
'FBDT' +
'_1.root')
or \
854 os.path.isfile(filesDirectory +
'/' + methodPrefixCombinerLevel +
'FANN' +
'_1.root'):
855 B2FATAL(
'flavorTagger: File' + methodPrefixCombinerLevel +
'FBDT' +
"_1.root" +
' or ' + methodPrefixCombinerLevel +
856 'FANN' +
'_1.root found. Please run the "Expert" mode or delete the file if a new sampling is desired.')
858 B2INFO(
'flavorTagger: Sampling Data on Combiner Level. File' +
859 methodPrefixCombinerLevel +
".root" +
' will be saved')
861 ntuple = basf2.register_module(
'VariablesToNtuple')
862 ntuple.param(
'fileName', filesDirectory +
'/' + methodPrefixCombinerLevel +
"sampled" + fileId +
".root")
863 ntuple.param(
'treeName', methodPrefixCombinerLevel +
'FBDT' +
"_tree")
864 ntuple.param(
'variables', variablesCombinerLevel + [
'qrCombined'])
865 ntuple.param(
'particleList',
"")
866 path.add_module(ntuple)
872 identifierFBDT = methodPrefixCombinerLevel +
'FBDT'
873 if downloadFlag
or useOnlyLocalFlag:
874 identifierFBDT = filesDirectory +
'/' + methodPrefixCombinerLevel +
'FBDT' +
'_1.root'
877 if not os.path.isfile(identifierFBDT):
878 basf2_mva.download(methodPrefixCombinerLevel +
'FBDT', identifierFBDT)
879 if not os.path.isfile(identifierFBDT):
880 B2FATAL(
'Flavor Tagger: Weight file ' + identifierFBDT +
881 ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')
884 if not os.path.isfile(identifierFBDT):
885 B2FATAL(
'flavorTagger: Combinerlevel FastBDT was not trained with this combination of categories.' +
886 ' Weight file ' + identifierFBDT +
' not found. Stopped')
888 B2INFO(
'flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel +
'FBDT' +
'_1.root')
891 identifierFANN = methodPrefixCombinerLevel +
'FANN'
892 if downloadFlag
or useOnlyLocalFlag:
893 identifierFANN = filesDirectory +
'/' + methodPrefixCombinerLevel +
'FANN' +
'_1.root'
896 if not os.path.isfile(identifierFANN):
897 basf2_mva.download(methodPrefixCombinerLevel +
'FANN', identifierFANN)
898 if not os.path.isfile(identifierFANN):
899 B2FATAL(
'Flavor Tagger: Weight file ' + identifierFANN +
900 ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')
902 if not os.path.isfile(identifierFANN):
903 B2FATAL(
'flavorTagger: Combinerlevel FANNMLP was not trained with this combination of categories. ' +
904 ' Weight file ' + identifierFANN +
' not found. Stopped')
906 B2INFO(
'flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel +
'FANN' +
'_1.root')
910 if TMVAfbdt
and not FANNmlp:
911 B2INFO(
'flavorTagger: Apply FBDTMethod ' + methodPrefixCombinerLevel +
'FBDT')
912 path.add_module(
'MVAExpert', listNames=[], extraInfoName=
'qrCombined' +
'FBDT', signalFraction=signalFraction,
913 identifier=identifierFBDT)
915 if FANNmlp
and not TMVAfbdt:
916 B2INFO(
'flavorTagger: Apply FANNMethod on combiner level')
917 path.add_module(
'MVAExpert', listNames=[], extraInfoName=
'qrCombined' +
'FANN', signalFraction=signalFraction,
918 identifier=identifierFANN)
920 if FANNmlp
and TMVAfbdt:
921 B2INFO(
'flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
922 mvaMultipleExperts = basf2.register_module(
'MVAMultipleExperts')
923 mvaMultipleExperts.set_name(
'MVAMultipleExperts_Combiners')
924 mvaMultipleExperts.param(
'listNames', [])
925 mvaMultipleExperts.param(
'extraInfoNames', [
'qrCombined' +
'FBDT',
'qrCombined' +
'FANN'])
926 mvaMultipleExperts.param(
'signalFraction', signalFraction)
927 mvaMultipleExperts.param(
'identifiers', [identifierFBDT, identifierFANN])
928 path.add_module(mvaMultipleExperts)
def _trainCombinerLevelMethod(methodName, methodPrefixCombinerLevel, sampledFilesList,
                              variablesCombinerLevel, getSpecificOptions):
    """
    Trains a single combiner-level MVA method and uploads its weight file if ``uploadFlag`` is set.

    @param methodName                 Label of the method in the weight file name (``'FBDT'`` or ``'FANN'``).
    @param methodPrefixCombinerLevel  Common prefix of the combiner-level file names.
    @param sampledFilesList           Sampled root files used as training data.
    @param variablesCombinerLevel     Names of the input variables of the combiner.
    @param getSpecificOptions         Callable returning the method-specific basf2_mva options.
                                      It is called right before the training starts.
    """

    weightFile = filesDirectory + '/' + methodPrefixCombinerLevel + methodName + "_1.root"

    trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
    trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
    # The tree name carries the 'FBDT' label for both methods, exactly as in the previous implementation.
    # NOTE(review): presumably the sampler writes a single tree under that name -- confirm.
    trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
    trainingOptionsCombinerLevel.m_identifier = weightFile
    trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
    trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
    trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

    basf2_mva.teacher(trainingOptionsCombinerLevel, getSpecificOptions())

    if uploadFlag:
        basf2_mva.upload(weightFile, methodPrefixCombinerLevel + methodName)


def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    For each enabled combiner method (globals ``TMVAfbdt`` and ``FANNmlp``) the method is trained
    if its weight file ``<prefix><FBDT|FANN>_1.root`` does not exist yet in ``filesDirectory``.
    If the weight file exists, B2FATAL is called unless the other enabled method still has to be trained.

    @param weightFiles                Weight files name, part of the prefix of all combiner-level files.
    @param variablesCombinerLevel     Names of the input variables of the combiner. ``None`` means no variables.
    @param categoriesCombinationCode  Code string of the used combination of categories. It is required:
                                      B2FATAL is called if it is ``None``.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    # The code is part of every file name; fail with a clear message instead of a TypeError on concatenation.
    if categoriesCombinationCode is None:
        B2FATAL('FlavorTagger: combinerLevelTeacher needs a categoriesCombinationCode to identify the weight files.')
        return

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if not sampledFilesList:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')

    weightFileFBDT = filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root'
    weightFileFANN = filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'

    allTrainedMessage = (
        'flavorTagger: Combinerlevel was already trained with this combination of categories. Weight files ' +
        methodPrefixCombinerLevel + 'FBDT' + '_1.root and ' +
        methodPrefixCombinerLevel + 'FANN' + '_1.root has been found. Please use the "Expert" mode')

    # The existence checks are deliberately evaluated at the point of use and not cached:
    # the FBDT training below may create the FBDT weight file before the FANN block inspects it.
    if TMVAfbdt:

        if not os.path.isfile(weightFileFBDT):

            B2INFO('flavorTagger: MVA Teacher training a FastBDT on Combiner Level')
            _trainCombinerLevelMethod('FBDT', methodPrefixCombinerLevel, sampledFilesList,
                                      variablesCombinerLevel, getFastBDTCombiner)

        elif FANNmlp and not os.path.isfile(weightFileFANN):

            B2INFO('flavorTagger: Combinerlevel FBDT was already trained with this combination of categories. Weight file ' +
                   methodPrefixCombinerLevel + 'FBDT' + '_1.root has been found.')

        else:
            B2FATAL(allTrainedMessage)

    if FANNmlp:

        if not os.path.isfile(weightFileFANN):

            B2INFO('flavorTagger: MVA Teacher training a FANN MLP on Combiner Level')
            _trainCombinerLevelMethod('FANN', methodPrefixCombinerLevel, sampledFilesList,
                                      variablesCombinerLevel, getMlpFANNCombiner)

        elif TMVAfbdt and not os.path.isfile(weightFileFBDT):

            # This branch is reached when the FANN weight file exists while the FBDT one does not,
            # so the message has to name the FANN method and its '.root' weight file.
            B2INFO('flavorTagger: Combinerlevel FANNMLP was already trained with this combination of categories. ' +
                   'Weight file ' + methodPrefixCombinerLevel + 'FANN' + '_1.root has been found.')

        else:
            B2FATAL(allTrainedMessage)
def getEventLevelParticleLists(categories=None):
    """
    Collects the event-level tuples ``(particleList, eventName, variableName)`` of the given categories.

    @param categories  Names of the categories, used as keys of ``AvailableCategories``.
                       ``None`` is treated as an empty list.
    @return            List of the tuples in the order of the given categories.
                       B2FATAL is called if a category yields a tuple that was already collected.
    """

    requestedCategories = [] if categories is None else categories

    collectedTuples = []

    for categoryName in requestedCategories:
        categoryInfo = AvailableCategories[categoryName]
        candidate = (categoryInfo.particleList, categoryInfo.eventName, categoryInfo.variableName)

        # Guard clause: a tuple that is already present means the category was given twice.
        if candidate in collectedTuples:
            B2FATAL('Flavor Tagger: ' + categoryName + ' has been already given')
            continue

        collectedTuples.append(candidate)

    return collectedTuples
1033 weightFiles='B2nunubarBGx1',
1034 workingDirectory='.',
1035 combinerMethods=['TMVA-FBDT'],
1038 'IntermediateElectron',
1042 'IntermediateKinLepton',
1050 maskName='FTDefaultMask',
1051 saveCategoriesInfo=True,
1052 useOnlyLocalWeightFiles=False,
1053 downloadFromDatabaseIfNotFound=False,
1054 uploadToDatabaseAfterTraining=False,
1056 prefix='MC15ri_light-2207-bengal_0',
1058 identifierGNN='GFlaT_MC15ri_light_2303_iriomote_0',
1062 Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
1063 The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
1064 Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
1065 This module can be used to sample the training information, to train and/or to test the flavorTagger.
1067 @param particleLists The ROEs for flavor tagging are selected from the given particle lists.
1068 @param mode The available modes are
1069 ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
1070 Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you
1071 save the variables for training. In the ``Teacher`` mode the FlavorTagger is
1072 trained, for this step you do not reconstruct any particle or do any analysis,
1073 you just run the flavorTagger alone.
1074 @param weightFiles Weight files name. Default=
1075 ``B2nunubarBGx1`` (official weight files). If users want
1076 to train the FlavorTagger themselves, the weight files name should correspond to the
1077 analysed CP channel in order to avoid confusions. The default name
1078 ``B2nunubarBGx1`` corresponds to
1079 :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`.
1080 and ``B2JpsiKs_muBGx1`` to
1081 :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
1082 BGx1 stands for events simulated with background.
1083 @param workingDirectory Path to the directory containing the FlavorTagging/ folder.
1084 @param combinerMethods MVAs for the combiner: ``TMVA-FBDT`` (default).
1085 ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
1086 @param categories Categories used for flavor tagging. By default all are used.
1087 @param maskName Gets ROE particles from a specified ROE mask.
1088 ``FTDefaultMask`` (default): tentative mask definition that will be created
1089 automatically. The definition is as follows:
1091 - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
1092 - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
1093 [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
1094 [clusterReg==3 and E>0.06]] \
1095 (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)
1097 ``all``: all ROE particles are used.
1098 Or one can give any mask name defined before calling this function.
1099 @param saveCategoriesInfo Sets to save information of individual categories.
1100 @param useOnlyLocalWeightFiles [Expert] Uses only locally saved weight files.
1101 @param downloadFromDatabaseIfNotFound [Expert] Weight files are downloaded from
1102 the conditions database if not available in workingDirectory.
1103 @param uploadToDatabaseAfterTraining [Expert] For librarians only: uploads weight files to localdb after training.
1104 @param samplerFileId Identifier to parallelize
1105 sampling. Only used in ``Sampler`` mode. If you are training by yourself and
1106 want to parallelize the sampling, you can run several sampling scripts in
1107 parallel. By changing this parameter you will not overwrite an older sample.
1108 @param prefix Prefix of weight files.
1109 ``MC15ri_light-2207-bengal_0`` (default): Weight files trained for MC15ri samples.
1110 ``''``: Weight files trained for MC13 samples.
1111 @param useGNN Use GNN-based Flavor Tagger in addition to the FastBDT-based one.
1112 Please specify the weight file with the option ``identifierGNN``.
1113 [Expert] In the sampler mode, training files for the GNN-based Flavor Tagger are produced.
1114 @param identifierGNN The name of weight file of the GNN-based Flavor Tagger.
1115 [Expert] Multiple identifiers can be given with list(str).
1116 @param path Modules are added to this path
1120 if (
not isinstance(particleLists, list)):
1121 particleLists = [particleLists]
1123 if mode !=
'Sampler' and mode !=
'Teacher' and mode !=
'Expert':
1124 B2FATAL(
'flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')
1126 if len(categories) != len(set(categories)):
1127 dup = [cat
for cat
in set(categories)
if categories.count(cat) > 1]
1128 B2WARNING(
'Flavor Tagger: There are duplicate elements in the given categories list. '
1129 <<
'The following duplicate elements are removed; ' <<
', '.join(dup))
1130 categories = list(set(categories))
1132 if len(categories) < 2:
1133 B2FATAL(
'Flavor Tagger: Invalid amount of categories. At least two are needed.')
1135 'Flavor Tagger: Possible categories are "Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
1136 '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron",'
1137 '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')
1139 for category
in categories:
1140 if category
not in AvailableCategories:
1141 B2FATAL(
'Flavor Tagger: ' + category +
' is not a valid category name given')
1142 B2FATAL(
'Flavor Tagger: Available categories are "Electron", "IntermediateElectron", '
1143 '"Muon", "IntermediateMuon", "KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
1144 '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')
1146 if mode ==
'Expert' and useGNN
and identifierGNN ==
'':
1147 B2FATAL(
'Please specify the name of the weight file with ``identifierGNN``')
1152 basf2.find_file(workingDirectory)
1154 global filesDirectory
1155 filesDirectory = workingDirectory +
'/FlavorTagging/TrainedMethods'
1157 if mode ==
'Sampler' or (mode ==
'Expert' and downloadFromDatabaseIfNotFound):
1158 if not basf2.find_file(workingDirectory +
'/FlavorTagging', silent=
True):
1159 os.mkdir(workingDirectory +
'/FlavorTagging')
1160 os.mkdir(workingDirectory +
'/FlavorTagging/TrainedMethods')
1161 elif not basf2.find_file(workingDirectory +
'/FlavorTagging/TrainedMethods', silent=
True):
1162 os.mkdir(workingDirectory +
'/FlavorTagging/TrainedMethods')
1163 filesDirectory = workingDirectory +
'/FlavorTagging/TrainedMethods'
1165 if len(combinerMethods) < 1
or len(combinerMethods) > 2:
1166 B2FATAL(
'flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')
1174 for method
in combinerMethods:
1175 if method ==
'TMVA-FBDT':
1177 elif method ==
'FANN-MLP':
1180 B2FATAL(
'flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')
1183 fileId = samplerFileId
1185 global useOnlyLocalFlag
1186 useOnlyLocalFlag = useOnlyLocalWeightFiles
1188 B2INFO(
'*** FLAVOR TAGGING ***')
1190 B2INFO(
' Working directory is: ' + filesDirectory)
1193 setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)
1196 set_FlavorTagger_pid_aliases_legacy()
1198 set_FlavorTagger_pid_aliases()
1200 alias_list_for_GNN = []
1202 alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories)
1204 setInputVariablesWithMask()
1206 weightFiles = prefix +
'_' + weightFiles
1209 trackLevelParticleLists = []
1210 eventLevelParticleLists = []
1211 variablesCombinerLevel = []
1212 categoriesCombination = []
1213 categoriesCombinationCode =
'CatCode'
1214 for category
in categories:
1215 ftCategory = AvailableCategories[category]
1217 track_tuple = (ftCategory.particleList, ftCategory.trackName)
1218 event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)
1220 if track_tuple
not in trackLevelParticleLists
and category !=
'MaximumPstar':
1221 trackLevelParticleLists.append(track_tuple)
1223 if event_tuple
not in eventLevelParticleLists:
1224 eventLevelParticleLists.append(event_tuple)
1225 variablesCombinerLevel.append(ftCategory.variableName)
1226 categoriesCombination.append(ftCategory.code)
1228 B2FATAL(
'Flavor Tagger: ' + category +
' has been already given')
1230 for code
in sorted(categoriesCombination):
1231 categoriesCombinationCode = categoriesCombinationCode +
'%02d' % code
1234 if maskName ==
'FTDefaultMask':
1237 'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
1238 'thetaInCDCAcceptance and clusterNHits>1.5 and [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or \
1239 [E>0.06 and clusterReg==3]]')
1240 for name
in particleLists:
1241 ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)
1244 roe_path = basf2.create_path()
1245 deadEndPath = basf2.create_path()
1247 if mode ==
'Sampler':
1249 ma.signalSideParticleListsFilter(
1251 'nROE_Charged(' + maskName +
', 0) > 0 and abs(qrCombined) == 1',
1255 FillParticleLists(maskName, categories, roe_path)
1258 if eventLevel(
'Expert', weightFiles, categories, roe_path):
1260 ma.rankByHighest(
'pi+:inRoe',
'p', numBest=0, allowMultiRank=
False,
1261 outputVariable=
'FT_p_rank', overwriteRank=
True, path=roe_path)
1262 ma.fillParticleListFromDummy(
'vpho:dummy', path=roe_path)
1263 ma.variablesToNtuple(
'vpho:dummy',
1266 filename=f
'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
1267 signalSideParticleList=particleLists[0],
1271 if eventLevel(mode, weightFiles, categories, roe_path):
1272 combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)
1274 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)
1276 elif mode ==
'Expert':
1279 ma.signalSideParticleListsFilter(particleLists,
'nROE_Charged(' + maskName +
', 0) > 0', roe_path, deadEndPath)
1282 flavorTaggerInfoBuilder = basf2.register_module(
'FlavorTaggerInfoBuilder')
1283 path.add_module(flavorTaggerInfoBuilder)
1285 FillParticleLists(maskName, categories, roe_path)
1287 if eventLevel(mode, weightFiles, categories, roe_path):
1288 combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)
1290 flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
1291 flavorTaggerInfoFiller.param(
'trackLevelParticleLists', trackLevelParticleLists)
1292 flavorTaggerInfoFiller.param(
'eventLevelParticleLists', eventLevelParticleLists)
1293 flavorTaggerInfoFiller.param(
'TMVAfbdt', TMVAfbdt)
1294 flavorTaggerInfoFiller.param(
'FANNmlp', FANNmlp)
1295 flavorTaggerInfoFiller.param(
'qpCategories', saveCategoriesInfo)
1296 flavorTaggerInfoFiller.param(
'istrueCategories', saveCategoriesInfo)
1297 flavorTaggerInfoFiller.param(
'targetProb',
False)
1298 flavorTaggerInfoFiller.param(
'trackPointers',
False)
1299 roe_path.add_module(flavorTaggerInfoFiller)
1300 add_default_FlavorTagger_aliases()
1303 ma.rankByHighest(
'pi+:inRoe',
'p', numBest=0, allowMultiRank=
False,
1304 outputVariable=
'FT_p_rank', overwriteRank=
True, path=roe_path)
1305 ma.fillParticleListFromDummy(
'vpho:dummy', path=roe_path)
1307 if isinstance(identifierGNN, str):
1308 roe_path.add_module(
'MVAExpert',
1309 listNames=
'vpho:dummy',
1310 extraInfoName=
'qrGNN_raw',
1311 identifier=identifierGNN)
1313 ma.variableToSignalSideExtraInfo(
'vpho:dummy', {
'extraInfo(qrGNN_raw)*2-1':
'qrGNN'},
1316 elif isinstance(identifierGNN, list):
1317 identifierGNN = list(set(identifierGNN))
1319 extraInfoNames = [f
'qrGNN_{i_id}' for i_id
in identifierGNN]
1320 roe_path.add_module(
'MVAMultipleExperts',
1321 listNames=
'vpho:dummy',
1322 extraInfoNames=extraInfoNames,
1323 identifiers=identifierGNN)
1326 for extraInfoName
in extraInfoNames:
1327 extraInfoDict[f
'extraInfo({extraInfoName})*2-1'] = extraInfoName
1328 variables.variables.addAlias(extraInfoName, f
'extraInfo({extraInfoName})')
1330 ma.variableToSignalSideExtraInfo(
'vpho:dummy', extraInfoDict,
1333 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)
1335 elif mode ==
'Teacher':
1336 if eventLevelTeacher(weightFiles, categories):
1337 combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1340 if __name__ ==
'__main__':
1344 function = globals()[
"flavorTagger"]
1345 signature = inspect.formatargspec(*inspect.getfullargspec(function))
1346 desc_list.append((function.__name__, signature +
'\n' + function.__doc__))
1348 from terminal_utils
import Pager
1349 from basf2.utils import pretty_print_description_list
1350 with Pager(
'Flavor Tagger function accepts the following arguments:'):
1351 pretty_print_description_list(desc_list)