Belle II Software development
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11# ************* Flavor Tagging ************
12# * This script is needed to train *
13# * and to test the flavor tagger. *
14# ********************************************
15
16from basf2 import B2INFO, B2FATAL, B2WARNING
17import basf2
18import basf2_mva
19import modularAnalysis as ma
20import variables
21from variables import utils
22import os
23import glob
24import b2bii
25import collections
26
27
def getBelleOrBelle2():
    """
    Tells which experiment the job is running for.

    Returns 'Belle' if b2bii.isB2BII() is true and 'Belle2' otherwise.
    """
    return 'Belle' if b2bii.isB2BII() else 'Belle2'
36
37
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Stores the database interaction switches in module-level flags.

    downloadFromDatabaseIfNotFound is saved in the global downloadFlag and
    uploadToDatabaseAfterTraining in the global uploadFlag.
    """
    global downloadFlag, uploadFlag

    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
48
49
# Default list of aliases that should be used to save the flavor tagging information
# using VariablesToNtuple: five event-wide aliases followed by the three
# aliases (qp, hasTrueTarget, isRightCategory) of every category.
flavor_tagging = ['FBDT_qrCombined', 'FANN_qrCombined', 'qrMC', 'mcFlavorOfOtherB', 'qrGNN'] + [
    prefix + category
    for category in ('Electron', 'IntermediateElectron',
                     'Muon', 'IntermediateMuon',
                     'KinLepton', 'IntermediateKinLepton',
                     'Kaon', 'SlowPion', 'FastHadron', 'Lambda',
                     'FSC', 'MaximumPstar', 'KaonPion')
    for prefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
65
66
def add_default_FlavorTagger_aliases():
    """
    Registers the default aliases of the flavor tagging variables and
    publishes the flavor_tagging list as the 'flavor_tagging' collection.
    """
    addAlias = variables.variables.addAlias

    # Aliases that do not depend on a category
    fixedAliases = (
        ('FBDT_qrCombined', 'qrOutput(FBDT)'),
        ('FANN_qrCombined', 'qrOutput(FANN)'),
        ('qrMC', 'isRelatedRestOfEventB0Flavor'),
        ('qrGNN', 'extraInfo(qrGNN)'),
    )
    for aliasName, expression in fixedAliases:
        addAlias(aliasName, expression)

    # Three aliases for every available category
    for categoryName in AvailableCategories:
        addAlias(f'qp{categoryName}', f'qpCategory({categoryName})')
        addAlias(f'hasTrueTarget{categoryName}', f'hasTrueTargets({categoryName})')
        addAlias(f'isRightCategory{categoryName}', f'isTrueFTCategory({categoryName})')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
88
89
def set_FlavorTagger_pid_aliases():
    """
    Registers the pid aliases needed by the flavor tagger.

    Every alias points to pidPairProbabilityExpert for a pair of PDG codes
    and a detector. The dE/dx aliases use CDC and SVD with a NaN fallback
    of 0.5 for Belle, and CDC only for Belle II.
    """
    addAlias = variables.variables.addAlias

    # (alias, first PDG code, second PDG code, detector)
    singleDetectorAliases = (
        ('eid_TOP', 11, 211, 'TOP'),
        ('eid_ARICH', 11, 211, 'ARICH'),
        ('eid_ECL', 11, 211, 'ECL'),
        ('muid_TOP', 13, 211, 'TOP'),
        ('muid_ARICH', 13, 211, 'ARICH'),
        ('muid_KLM', 13, 211, 'KLM'),
        ('piid_TOP', 211, 321, 'TOP'),
        ('piid_ARICH', 211, 321, 'ARICH'),
        ('Kid_TOP', 321, 211, 'TOP'),
        ('Kid_ARICH', 321, 211, 'ARICH'),
    )
    for aliasName, firstPdg, secondPdg, detector in singleDetectorAliases:
        addAlias(aliasName, f'pidPairProbabilityExpert({firstPdg}, {secondPdg}, {detector})')

    # (alias, first PDG code, second PDG code) of the dE/dx based aliases
    dEdxAliases = (
        ('eid_dEdx', 11, 211),
        ('muid_dEdx', 13, 211),
        ('piid_dEdx', 211, 321),
        ('pi_vs_edEdxid', 211, 11),
        ('Kid_dEdx', 321, 211),
    )
    if getBelleOrBelle2() == "Belle":
        dEdxTemplate = 'ifNANgiveX(pidPairProbabilityExpert({}, {}, CDC, SVD), 0.5)'
    else:
        dEdxTemplate = 'pidPairProbabilityExpert({}, {}, CDC)'

    for aliasName, firstPdg, secondPdg in dEdxAliases:
        addAlias(aliasName, dEdxTemplate.format(firstPdg, secondPdg))
120
121
def set_FlavorTagger_pid_aliases_legacy():
    """
    Registers the pid aliases needed by the flavor tagger trained for MC13.

    All aliases wrap pidPairProbabilityExpert in ifNANgiveX(..., 0.5).
    """
    addAlias = variables.variables.addAlias

    def nanSafePairProbability(firstPdg, secondPdg, detectors):
        # Expression of the pair probability with 0.5 as replacement for NaN
        return f'ifNANgiveX(pidPairProbabilityExpert({firstPdg}, {secondPdg}, {detectors}), 0.5)'

    # (alias, first PDG code, second PDG code, detector)
    singleDetectorAliases = (
        ('eid_TOP', 11, 211, 'TOP'),
        ('eid_ARICH', 11, 211, 'ARICH'),
        ('eid_ECL', 11, 211, 'ECL'),
        ('muid_TOP', 13, 211, 'TOP'),
        ('muid_ARICH', 13, 211, 'ARICH'),
        ('muid_KLM', 13, 211, 'KLM'),
        ('piid_TOP', 211, 321, 'TOP'),
        ('piid_ARICH', 211, 321, 'ARICH'),
        ('Kid_TOP', 321, 211, 'TOP'),
        ('Kid_ARICH', 321, 211, 'ARICH'),
    )
    for aliasName, firstPdg, secondPdg, detector in singleDetectorAliases:
        addAlias(aliasName, nanSafePairProbability(firstPdg, secondPdg, detector))

    if getBelleOrBelle2() == "Belle":
        dEdxDetectors = 'CDC, SVD'
    else:
        # Removed SVD PID for Belle II MC and data as it is absent in release 4.
        dEdxDetectors = 'CDC'

    # (alias, first PDG code, second PDG code) of the dE/dx based aliases
    dEdxAliases = (
        ('eid_dEdx', 11, 211),
        ('muid_dEdx', 13, 211),
        ('piid_dEdx', 211, 321),
        ('pi_vs_edEdxid', 211, 11),
        ('Kid_dEdx', 321, 211),
    )
    for aliasName, firstPdg, secondPdg in dEdxAliases:
        addAlias(aliasName, nanSafePairProbability(firstPdg, secondPdg, dEdxDetectors))
153
154
def set_GNNFlavorTagger_aliases(categories):
    """
    Registers the aliases used by the GNN-based flavor tagger.

    For every rank from 1 to 16 an alias '<name>_rank<rank>' is added for each
    entry of categories and for each per-particle variable. Each alias reads
    the quantity through getVariableByRank on pi+:inRoe ranked by FT_p, with
    NaN replaced by 0.

    Returns the list of all alias names, starting with 'qrCombined_bit'.
    """
    addAlias = variables.variables.addAlias

    def rankedExpression(expression, rank):
        # Value of the expression for the particle with the given rank, 0 if NaN
        return f'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {expression}, {rank}), 0)'

    # will be used for target variable 0:B0bar, 1:B0
    addAlias('qrCombined_bit', '(qrCombined+1)/2')
    registeredAliases = ['qrCombined_bit']

    # (alias stem, expression)
    particleVariables = (
        # position
        ('dx', 'dx'),
        ('dy', 'dy'),
        ('dz', 'dz'),
        # mask
        ('E', 'E'),
        # charge
        ('charge', 'charge'),
        # feature
        ('px_c', 'px*charge'),
        ('py_c', 'py*charge'),
        ('pz_c', 'pz*charge'),
        ('electronID_c', 'electronID*charge'),
        ('muonID_c', 'muonID*charge'),
        ('pionID_c', 'pionID*charge'),
        ('kaonID_c', 'kaonID*charge'),
        ('protonID_c', 'protonID*charge'),
        ('deuteronID_c', 'deuteronID*charge'),
        ('electronID_noSVD_noTOP_c', 'electronID_noSVD_noTOP*charge'),
    )

    # 16 charged particles are used at most
    for rank in range(1, 17):

        # category outputs first
        for categoryName in categories:
            categoryList = AvailableCategories[categoryName].particleList
            qpTrack = f'QpTrack({categoryList}, isRightCategory({categoryName}), isRightCategory({categoryName}))'

            rankedAlias = f'{categoryName}_rank{rank}'
            addAlias(rankedAlias, rankedExpression(qpTrack, rank))
            registeredAliases.append(rankedAlias)

        # then the per-particle variables
        for stem, expression in particleVariables:
            rankedAlias = f'{stem}_rank{rank}'
            addAlias(rankedAlias, rankedExpression(expression, rank))
            registeredAliases.append(rankedAlias)

    return registeredAliases
207
208
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for input variables with ROE mask.

    For each of pMissTag, cosTPTO, ptTracksRoe and pt2TracksRoe the alias
    '<variable>_withMask' is registered for the expression '<variable>(<maskName>)'.
    Every alias is registered exactly once.
    """
    for variableName in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(f'{variableName}_withMask', f'{variableName}({maskName})')
218
219
def getFastBDTCategories():
    """
    Helper function returning the FastBDT options for the categories.

    The options are created inside a function so that no top-level ROOT
    import is required.
    """
    options = basf2_mva.FastBDTOptions()

    settings = (
        ('m_nTrees', 500),
        ('m_nCuts', 8),
        ('m_nLevels', 3),
        ('m_shrinkage', 0.10),
        ('m_randRatio', 0.5),
    )
    for member, value in settings:
        setattr(options, member, value)

    return options
231
232
def getFastBDTCombiner():
    """
    Helper function returning the FastBDT options for the combiner.

    The options are created inside a function so that no top-level ROOT
    import is required.
    """
    options = basf2_mva.FastBDTOptions()

    settings = (
        ('m_nTrees', 500),
        ('m_nCuts', 8),
        ('m_nLevels', 3),
        ('m_shrinkage', 0.10),
        ('m_randRatio', 0.5),
    )
    for member, value in settings:
        setattr(options, member, value)

    return options
245
246
def getMlpFANNCombiner():
    """
    Helper function returning the FANN options for the MLP combiner.

    The options are created inside a function so that no top-level ROOT
    import is required.
    """
    options = basf2_mva.FANNOptions()

    settings = (
        ('m_max_epochs', 10000),
        ('m_hidden_layers_architecture', "3*N"),
        # the member names below are spelled 'activiation' on purpose: that is the name set here
        ('m_hidden_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_output_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_error_function', "FANN_ERRORFUNC_LINEAR"),
        ('m_training_method', "FANN_TRAIN_RPROP"),
        ('m_validation_fraction', 0.5),
        ('m_random_seeds', 10),
        ('m_test_rate', 500),
        ('m_number_of_threads', 8),
        ('m_scale_features', True),
        # the target is not scaled (True is the disabled alternative)
        ('m_scale_target', False),
    )
    for member, value in settings:
        setattr(options, member, value)

    return options
267
268
# SignalFraction: FBDT feature
#   -1: smooth output, but this will break the calibration.
#   -2: correct calibration, but leads to peaky combiner output.
signalFraction = -2

# Maximal number of events to train each method.
# 0 takes all the sampled events. The number in the past was 500000.
maxEventsNumber = 0
276
277
FTCategoryParameters = collections.namedtuple('FTCategoryParameters',
                                              ['particleList', 'trackName', 'eventName', 'variableName', 'code'])

# Definition of all available categories, keyed by the standard category name.
# Each row gives: ParticleList, trackLevel category name, eventLevel category name
# (also used as key) and the function of the combinerLevel variable. The combinerLevel
# variable name is assembled from the row and the category code is the row position.
AvailableCategories = {
    eventName: FTCategoryParameters(
        particleList, trackName, eventName,
        f'{combinerFunction}({particleList}, isRightCategory({eventName}), isRightCategory({trackName}))',
        code)
    for code, (particleList, trackName, eventName, combinerFunction) in enumerate((
        ('e+:inRoe', 'Electron', 'Electron', 'QpOf'),
        ('e+:inRoe', 'IntermediateElectron', 'IntermediateElectron', 'QpOf'),
        ('mu+:inRoe', 'Muon', 'Muon', 'QpOf'),
        ('mu+:inRoe', 'IntermediateMuon', 'IntermediateMuon', 'QpOf'),
        ('mu+:inRoe', 'KinLepton', 'KinLepton', 'QpOf'),
        ('mu+:inRoe', 'IntermediateKinLepton', 'IntermediateKinLepton', 'QpOf'),
        ('K+:inRoe', 'Kaon', 'Kaon', 'weightedQpOf'),
        ('pi+:inRoe', 'SlowPion', 'SlowPion', 'QpOf'),
        ('pi+:inRoe', 'FastHadron', 'FastHadron', 'QpOf'),
        ('Lambda0:inRoe', 'Lambda', 'Lambda', 'weightedQpOf'),
        ('pi+:inRoe', 'SlowPion', 'FSC', 'QpOf'),
        ('pi+:inRoe', 'MaximumPstar', 'MaximumPstar', 'QpOf'),
        ('K+:inRoe', 'Kaon', 'KaonPion', 'QpOf'),
    ))
}
337
338
def getTrainingVariables(category=None):
    """
    Helper function to get training variables.

    Returns the list of input variable names of the given category. The list
    is empty if the category is None or not one of the known categories.

    NOTE: This function is not called in the Expert mode. It is not necessary
    to be consistent with the variables list of the weight files.
    """

    # Resolve the experiment once; it selects the global PID variables below
    experiment = getBelleOrBelle2()
    isBelle = experiment == "Belle"

    KId = {'Belle': 'ifNANgiveX(atcPIDBelle(3,2), 0.5)', 'Belle2': 'kaonID'}[experiment]
    muId = {'Belle': 'muIDBelle', 'Belle2': 'muonID'}[experiment]
    eId = {'Belle': 'eIDBelle', 'Belle2': 'electronID'}[experiment]

    # Groups of variables shared by several categories. They are only read
    # (concatenated), never modified, so every call returns a fresh list.
    momenta = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt']
    electronPid = [eId, 'eid_TOP', 'eid_ARICH', 'eid_ECL']
    muonPid = [muId, 'muid_TOP', 'muid_ARICH', 'muid_KLM']
    kaonPid = [KId, 'Kid_dEdx', 'Kid_TOP', 'Kid_ARICH']
    roeKinematics = ['NumberOfKShortsInRoe', 'ptTracksRoe']
    wBoson = ['BtagToWBosonVariables(recoilMassSqrd)',
              'BtagToWBosonVariables(pMissCMS)',
              'BtagToWBosonVariables(cosThetaMissCMS)',
              'BtagToWBosonVariables(EW90)',
              ]
    impact = ['ImpactXY', 'distance']

    # Named trainingVariables (not variables) to avoid hiding the imported variables module
    trainingVariables = []

    if category in ('Electron', 'IntermediateElectron'):
        trainingVariables = momenta + ['cosTheta'] + electronPid + wBoson + ['cosTPTO', 'chiProb']
        if isBelle:
            trainingVariables += ['eid_dEdx'] + impact

    elif category in ('Muon', 'IntermediateMuon'):
        trainingVariables = momenta + ['cosTheta'] + muonPid + wBoson + ['cosTPTO']
        if isBelle:
            trainingVariables += ['muid_dEdx'] + impact + ['chiProb']

    elif category in ('KinLepton', 'IntermediateKinLepton'):
        trainingVariables = momenta + ['cosTheta'] + muonPid + electronPid + wBoson + ['cosTPTO']
        if isBelle:
            trainingVariables += ['eid_dEdx', 'muid_dEdx'] + impact + ['chiProb']

    elif category == 'Kaon':
        trainingVariables = momenta + ['cosTheta'] + kaonPid + roeKinematics + wBoson + ['cosTPTO', 'chiProb']
        if isBelle:
            trainingVariables += impact

    elif category == 'SlowPion':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             'pionID',
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             ] + kaonPid + roeKinematics + [
                             eId,
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['piid_dEdx'] + impact + ['chiProb']

    elif category == 'FastHadron':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             'pionID',
                             'piid_dEdx',
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             ] + kaonPid + roeKinematics + [
                             eId,
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['BtagToWBosonVariables(pMissCMS)'] + impact + ['chiProb']

    elif category == 'Lambda':
        trainingVariables = ['lambdaFlavor',
                             'NumberOfKShortsInRoe',
                             'M',
                             'cosAngleBetweenMomentumAndVertexVector',
                             'lambdaZError',
                             'daughter(0,p)',
                             'daughter(0,useCMSFrame(p))',
                             'daughter(1,p)',
                             'daughter(1,useCMSFrame(p))',
                             'useCMSFrame(p)',
                             'p',
                             'chiProb',
                             ]
        if experiment == "Belle2":
            trainingVariables.append('daughter(1,protonID)')  # protonID always 0 in B2BII check in future
            trainingVariables.append('daughter(0,pionID)')  # not very powerful in B2BII
        else:
            trainingVariables.append('distance')

    elif category == 'MaximumPstar':
        trainingVariables = momenta + ['cosTPTO']
        if experiment == "Belle2":
            trainingVariables += impact

    elif category == 'FSC':
        trainingVariables = ['useCMSFrame(p)',
                             'cosTPTO',
                             KId,
                             'FSCVariables(pFastCMS)',
                             'FSCVariables(cosSlowFast)',
                             'FSCVariables(cosTPTOFast)',
                             'FSCVariables(SlowFastHaveOpositeCharges)',
                             ]

    elif category == 'KaonPion':
        trainingVariables = ['extraInfo(isRightCategory(Kaon))',
                             'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
                             'KaonPionVariables(cosKaonPion)',
                             'KaonPionVariables(HaveOpositeCharges)',
                             KId,
                             ]

    return trainingVariables
552
553
def FillParticleLists(maskName='all', categories=None, path=None):
    """
    Fills the particle lists of all requested categories.

    Every particle list is filled only once, even if several categories share
    it. Additionally a K_S0:inRoe list is prepared and, for Belle II, the
    charged PID MVA is applied to the filled e+:inRoe and mu+:inRoe lists.
    """

    from vertex import kFit

    if categories is None:
        categories = []

    filledLists = []
    roeTrackCut = 'isInRestOfEvent > 0.5 and passesROEMask(' + maskName + ') > 0.5 and p >= 0'

    for category in categories:
        listName = AvailableCategories[category].particleList

        if listName in filledLists:
            continue

        # Select particles in ROE for different categories according to mass hypothesis.
        if listName != 'Lambda0:inRoe':
            # Filling particle list for actual category
            ma.fillParticleList(listName, roeTrackCut, path=path)
            filledLists.append(listName)
            continue

        # The Lambda0 candidates are reconstructed from ROE pions and protons
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
            filledLists.append('pi+:inRoe')

        ma.fillParticleList('p+:inRoe', roeTrackCut, path=path)
        ma.reconstructDecay(listName + ' -> pi-:inRoe p+:inRoe', '1.00<=M<=1.23', False, path=path)
        kFit(listName, 0.01, path=path)
        ma.matchMCTruth(listName, path=path)
        filledLists.append(listName)

    experiment = getBelleOrBelle2()

    # Additional particleLists for K_S0
    if experiment == 'Belle':
        ma.cutAndCopyList('K_S0:inRoe', 'K_S0:mdst', 'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
    else:
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
        ma.reconstructDecay('K_S0:inRoe -> pi+:inRoe pi-:inRoe', '0.40<=M<=0.60', False, path=path)
        kFit('K_S0:inRoe', 0.01, path=path)

    # Apply BDT-based LID
    if experiment == 'Belle2':
        leptonLists = [name for name in ('e+:inRoe', 'mu+:inRoe') if name in filledLists]

        # nothing to do if no lepton list was filled
        if leptonLists:
            ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                                  trainingMode=0,  # binary
                                  binaryHypoPDGCodes=(11, 211))  # e vs pi
            ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                                  trainingMode=0,  # binary
                                  binaryHypoPDGCodes=(13, 211))  # mu vs pi
            ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                                  trainingMode=1)  # Multiclass
617
618
def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, path=None):
    """
    Samples data for training or tests all categories at event level.

    For every category whose event-level method is usable an MVA expert is
    attached to the path. In 'Sampler' mode a VariablesToNtuple module is
    also attached for every category whose local weight file is missing.

    Returns True if a method was set up for every entry of categories and
    False otherwise.
    """

    from basf2 import create_path
    from basf2 import register_module

    def methodPrefixOf(categoryName):
        # Common name stem of the event-level method of a category
        return ''.join(("FlavorTagger_", getBelleOrBelle2(), "_", weightFiles, 'EventLevel', categoryName, 'FBDT'))

    def localWeightFileOf(methodPrefix):
        # Weight file of a method inside the working directory
        return filesDirectory + '/' + methodPrefix + '_1.root'

    B2INFO('EVENT LEVEL')

    if categories is None:
        categories = []

    readyMethods = 0

    # (extraInfo name, identifier) pairs, grouped by particle list; KaonPion is kept apart
    expertsPerList = {}
    kaonPionExperts = []

    for category in categories:
        particleList = AvailableCategories[category].particleList

        methodPrefix = methodPrefixOf(category)
        identifier = methodPrefix
        extraInfoName = 'isRightCategory(' + category + ')'

        if mode == 'Expert':

            if downloadFlag or useOnlyLocalFlag:
                identifier = localWeightFileOf(methodPrefix)

            if downloadFlag and not os.path.isfile(identifier):
                basf2_mva.download(methodPrefix, identifier)
                if not os.path.isfile(identifier):
                    B2FATAL(f'Flavor Tagger: Weight file {identifier}'
                            ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag and not os.path.isfile(identifier):
                B2FATAL(f'Flavor Tagger: {particleList} Eventlevel was not trained. Weight file '
                        f'{identifier} was not found. Stopped')

            B2INFO(f'flavorTagger: MVAExpert {methodPrefix} ready.')

        elif mode == 'Sampler':

            identifier = localWeightFileOf(methodPrefix)
            if os.path.isfile(identifier):
                B2INFO(f'flavorTagger: MVAExpert {methodPrefix} ready.')

            if 'KaonPion' in categories:
                kaonPionWeightFile = localWeightFileOf(methodPrefixOf('KaonPion'))
                # Slow Pion and Kaon categories are used if Kaon-Pion is lacking for
                # sampling. The others are not needed and skipped
                if not os.path.isfile(kaonPionWeightFile) and category not in ("SlowPion", "Kaon"):
                    continue

        if mode == 'Expert' or (mode == 'Sampler' and os.path.isfile(identifier)):

            B2INFO(f'flavorTagger: Applying MVAExpert {methodPrefix}.')

            if category == 'KaonPion':
                kaonPionExperts.append((extraInfoName, identifier))
            else:
                expertsPerList.setdefault(particleList, []).append((extraInfoName, identifier))

            readyMethods += 1

    # Each particle list has its own Path in order to be skipped if the list is empty
    for particleList, experts in expertsPerList.items():
        expertPath = create_path()
        skipEmptyList = register_module("SkimFilter")
        skipEmptyList.set_name('SkimFilter_EventLevel_' + particleList)
        skipEmptyList.param('particleLists', particleList)
        skipEmptyList.if_true(expertPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(skipEmptyList)

        multipleExperts = register_module('MVAMultipleExperts')
        multipleExperts.set_name('MVAMultipleExperts_EventLevel_' + particleList)
        multipleExperts.param('listNames', [particleList])
        multipleExperts.param('extraInfoNames', [name for name, _ in experts])
        multipleExperts.param('signalFraction', signalFraction)
        multipleExperts.param('identifiers', [weightIdentifier for _, weightIdentifier in experts])
        expertPath.add_module(multipleExperts)

    if 'KaonPion' in categories and kaonPionExperts:
        kaonPionExtraInfo, kaonPionIdentifier = kaonPionExperts[0]

        kaonPionPath = create_path()
        skipEmptyList = register_module("SkimFilter")
        skipEmptyList.set_name('SkimFilter_K+:inRoe')
        skipEmptyList.param('particleLists', 'K+:inRoe')
        skipEmptyList.if_true(kaonPionPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(skipEmptyList)

        kaonPionExpert = register_module("MVAExpert")
        kaonPionExpert.set_name('MVAExpert_KaonPion_K+:inRoe')
        kaonPionExpert.param('listNames', ['K+:inRoe'])
        kaonPionExpert.param('extraInfoName', kaonPionExtraInfo)
        kaonPionExpert.param('signalFraction', signalFraction)
        kaonPionExpert.param('identifier', kaonPionIdentifier)

        kaonPionPath.add_module(kaonPionExpert)

    if mode == 'Sampler':

        for category in categories:
            particleList = AvailableCategories[category].particleList

            methodPrefix = methodPrefixOf(category)
            targetVariable = 'isRightCategory(' + category + ')'

            # Categories with an existing weight file do not need to be sampled
            if os.path.isfile(localWeightFileOf(methodPrefix)):
                continue

            if category == 'KaonPion':
                slowPionWeightFile = localWeightFileOf(methodPrefixOf('SlowPion'))
                if not os.path.isfile(slowPionWeightFile):
                    B2INFO("Flavor Tagger: event level weight file for the Slow Pion category is absent."
                           "It is required to sample the training information for the KaonPion category."
                           "An additional sampling step will be needed after the following training step.")
                    continue

            sampledFile = filesDirectory + '/' + methodPrefix + "sampled" + fileId + '.root'
            B2INFO('flavorTagger: file ' + sampledFile + ' will be saved.')

            ma.applyCuts(particleList, 'isRightCategory(mcAssociated) > 0', path)
            samplerPath = create_path()
            skipEmptyList = register_module("SkimFilter")
            skipEmptyList.set_name('SkimFilter_EventLevel' + category)
            skipEmptyList.param('particleLists', particleList)
            skipEmptyList.if_true(samplerPath, basf2.AfterConditionPath.CONTINUE)
            path.add_module(skipEmptyList)

            ntuple = register_module('VariablesToNtuple')
            ntuple.param('fileName', sampledFile)
            ntuple.param('treeName', methodPrefix + "_tree")
            variablesToBeSaved = getTrainingVariables(category) + [targetVariable, 'ancestorHasWhichFlavor',
                                                                   'isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                                                                   'nMCMatches', 'B0mcErrors']
            if category not in ('KaonPion', 'FSC'):
                variablesToBeSaved += [f'extraInfo(isRightTrack({category}))',
                                       f'hasHighestProbInCat({particleList}, isRightTrack({category}))']
            ntuple.param('variables', variablesToBeSaved)
            ntuple.param('particleList', particleList)
            samplerPath.add_module(ntuple)

    return readyMethods == len(categories)
778
779
def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None):
    """
    Trains all categories at event level.

    A category with an existing weight file is counted as ready and skipped.
    For the others the sampled files are searched and, if there are any, the
    method is trained (and uploaded if uploadFlag is set).

    Returns True only if the weight files of all categories already existed
    when the function was called, False otherwise.
    """

    B2INFO('EVENT LEVEL TEACHER')

    if categories is None:
        categories = []

    alreadyTrained = 0

    for category in categories:
        methodPrefix = ''.join(("FlavorTagger_", getBelleOrBelle2(), "_", weightFiles, 'EventLevel', category, 'FBDT'))
        targetVariable = 'isRightCategory(' + category + ')'
        weightFile = filesDirectory + '/' + methodPrefix + "_1.root"

        if os.path.isfile(weightFile):
            alreadyTrained += 1
            continue

        sampledFiles = glob.glob(filesDirectory + '/' + methodPrefix + 'sampled*.root')
        if not sampledFiles:
            B2INFO(f'flavorTagger: eventLevelTeacher did not find any {methodPrefix}.root'
                   ' file. Please run the flavorTagger in "Sampler" mode afterwards.')
            continue

        B2INFO(f'flavorTagger: MVA Teacher training{methodPrefix} .')
        generalOptions = basf2_mva.GeneralOptions()
        generalOptions.m_datafiles = basf2_mva.vector(*sampledFiles)
        generalOptions.m_treename = methodPrefix + "_tree"
        generalOptions.m_identifier = weightFile
        generalOptions.m_variables = basf2_mva.vector(*getTrainingVariables(category))
        generalOptions.m_target_variable = targetVariable
        generalOptions.m_max_events = maxEventsNumber

        basf2_mva.teacher(generalOptions, getFastBDTCategories())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefix)

    return alreadyTrained == len(categories)
825
826
def combinerLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None,
                  variablesCombinerLevel=None, categoriesCombinationCode=None, path=None):
    """
    Samples the input data or applies the combiner for the selected categories.

    In ``Sampler`` mode a ``VariablesToNtuple`` module writing the combiner input variables
    plus ``qrCombined`` is added to ``path``; a fatal error is issued if a combiner weight
    file already exists. In ``Expert`` mode the weight-file identifiers of the requested
    combiner methods (module-level flags ``TMVAfbdt`` / ``FANNmlp``) are resolved and the
    corresponding ``MVAExpert`` or ``MVAMultipleExperts`` module is added to ``path``.

    @param mode                        ``Sampler`` or ``Expert``; other values add no module.
    @param weightFiles                 Name tag that is part of every weight-file name.
    @param categories                  Category names, only used for logging here.
    @param variablesCombinerLevel      Input variables of the combiner.
    @param categoriesCombinationCode   Code string identifying the combination of categories.
    @param path                        Modules are added to this path.
    """

    B2INFO('COMBINER LEVEL')

    categories = [] if categories is None else categories
    variablesCombinerLevel = [] if variablesCombinerLevel is None else variablesCombinerLevel

    B2INFO("Flavor Tagger: Required Combiner for Categories:")
    for categoryName in categories:
        B2INFO(categoryName)

    B2INFO("Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)

    prefix = f"FlavorTagger_{getBelleOrBelle2()}_{weightFiles}Combiner{categoriesCombinationCode}"

    if mode == 'Sampler':

        # Refuse to sample again once any combiner has been trained
        trainedFiles = [f'{filesDirectory}/{prefix}{label}_1.root' for label in ('FBDT', 'FANN')]
        if any(os.path.isfile(trainedFile) for trainedFile in trainedFiles):
            B2FATAL(f'flavorTagger: File{prefix}FBDT_1.root or {prefix}FANN_1.root found. '
                    'Please run the "Expert" mode or delete the file if a new sampling is desired.')

        B2INFO(f'flavorTagger: Sampling Data on Combiner Level. File{prefix}.root will be saved')

        sampler = basf2.register_module('VariablesToNtuple')
        sampler.param('fileName', f'{filesDirectory}/{prefix}sampled{fileId}.root')
        # The tree always carries the FBDT name; combinerLevelTeacher reads it for both methods
        sampler.param('treeName', prefix + 'FBDT_tree')
        sampler.param('variables', variablesCombinerLevel + ['qrCombined'])
        sampler.param('particleList', "")
        path.add_module(sampler)

    elif mode == 'Expert':

        def resolveIdentifier(label, notTrainedMessage):
            """Returns the identifier of the combiner ``label`` after checking that it is usable."""
            # Database identifier by default, local file if download or local-only is requested
            identifier = prefix + label
            if downloadFlag or useOnlyLocalFlag:
                identifier = f'{filesDirectory}/{prefix}{label}_1.root'

            if downloadFlag and not os.path.isfile(identifier):
                basf2_mva.download(prefix + label, identifier)
                if not os.path.isfile(identifier):
                    B2FATAL('Flavor Tagger: Weight file ' + identifier +
                            ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag and not os.path.isfile(identifier):
                B2FATAL(notTrainedMessage + ' Weight file ' + identifier + ' not found. Stopped')

            B2INFO('flavorTagger: Ready to be used with weightFile ' + prefix + label + '_1.root')
            return identifier

        # Check if weight files are ready
        if TMVAfbdt:
            identifierFBDT = resolveIdentifier(
                'FBDT', 'flavorTagger: Combinerlevel FastBDT was not trained with this combination of categories.')

        if FANNmlp:
            identifierFANN = resolveIdentifier(
                'FANN', 'flavorTagger: Combinerlevel FANNMLP was not trained with this combination of categories. ')

        # At this stage, all necessary weight files should be ready.
        # Call MVAExpert or MVAMultipleExperts module.
        if TMVAfbdt and FANNmlp:
            B2INFO('flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
            combiners = basf2.register_module('MVAMultipleExperts')
            combiners.set_name('MVAMultipleExperts_Combiners')
            combiners.param('listNames', [])
            combiners.param('extraInfoNames', ['qrCombinedFBDT', 'qrCombinedFANN'])
            combiners.param('signalFraction', signalFraction)
            combiners.param('identifiers', [identifierFBDT, identifierFANN])
            path.add_module(combiners)

        elif TMVAfbdt:
            B2INFO('flavorTagger: Apply FBDTMethod ' + prefix + 'FBDT')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombinedFBDT',
                            signalFraction=signalFraction, identifier=identifierFBDT)

        elif FANNmlp:
            B2INFO('flavorTagger: Apply FANNMethod on combiner level')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombinedFANN',
                            signalFraction=signalFraction, identifier=identifierFANN)
927
928
def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    The combiner methods are selected through the module-level flags ``TMVAfbdt`` (FastBDT)
    and ``FANNmlp`` (FANN MLP). For each requested method the weight file
    ``<filesDirectory>/<prefix><FBDT|FANN>_1.root`` is trained from the sampled combiner files
    unless it already exists, and uploaded afterwards if ``uploadFlag`` is set.

    Which weight files exist is determined once, before any training starts. Checking again
    after a training would mistake a weight file produced by this very call for a leftover
    of a previous training.

    A fatal error is issued if no sampled file is found, or if the weight files of all
    requested methods already exist (nothing is left to train).

    @param weightFiles                 Name tag that is part of every weight-file name.
    @param variablesCombinerLevel      Input variables of the combiner.
    @param categoriesCombinationCode   Code string identifying the combination of categories.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if len(sampledFilesList) == 0:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')

    # (file label, name used in the log, factory of the method specific options)
    requestedMethods = []
    if TMVAfbdt:
        requestedMethods.append(('FBDT', 'FastBDT', getFastBDTCombiner))
    if FANNmlp:
        requestedMethods.append(('FANN', 'FANN MLP', getMlpFANNCombiner))

    if not requestedMethods:
        return

    weightFilePaths = {label: filesDirectory + '/' + methodPrefixCombinerLevel + label + '_1.root'
                       for label, _, _ in requestedMethods}

    # Snapshot taken BEFORE training: files created below must not count as "already trained"
    alreadyTrained = {label: os.path.isfile(weightFile) for label, weightFile in weightFilePaths.items()}

    if all(alreadyTrained.values()):
        B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight file(s) ' +
                ' and '.join(methodPrefixCombinerLevel + label + '_1.root' for label in weightFilePaths) +
                ' found. Please use the "Expert" mode')

    for label, methodName, getSpecificOptions in requestedMethods:

        if alreadyTrained[label]:
            B2INFO('flavorTagger: Combinerlevel ' + label +
                   ' was already trained with this combination of categories. Weight file ' +
                   methodPrefixCombinerLevel + label + '_1.root has been found.')
            continue

        B2INFO('flavorTagger: MVA Teacher training a ' + methodName + ' on Combiner Level')

        trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
        trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        # The sampler stores the tree under the FBDT name, whatever method is trained
        trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
        trainingOptionsCombinerLevel.m_identifier = weightFilePaths[label]
        trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
        trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
        trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsCombinerLevel, getSpecificOptions())

        if uploadFlag:
            basf2_mva.upload(weightFilePaths[label], methodPrefixCombinerLevel + label)
1007
1008
def getEventLevelParticleLists(categories=None):
    """
    Builds the event-level configuration tuples of the given categories.

    @param categories   Names of the categories, used as keys of ``AvailableCategories``
                        (``None`` is treated as empty).
    @return             List with one ``(particleList, eventName, variableName)`` tuple per
                        category, in the order given. A fatal error is issued if a tuple
                        would appear twice.
    """

    collected = []

    for name in ([] if categories is None else categories):
        info = AvailableCategories[name]
        entry = (info.particleList, info.eventName, info.variableName)

        if entry in collected:
            B2FATAL('Flavor Tagger: ' + name + ' has been already given')
        else:
            collected.append(entry)

    return collected
1026
1027
def flavorTagger(
    particleLists=None,
    mode='Expert',
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    uploadToDatabaseAfterTraining=False,
    samplerFileId='',
    prefix='MC15ri_light-2207-bengal_0',
    useGNN=False,
    identifierGNN='GFlaT_MC15ri_light_2303_iriomote_0',
    path=None,
):
    """
      Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
      The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
      Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
      This module can be used to sample the training information, to train and/or to test the flavorTagger.

      @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
      @param mode                              The available modes are
                                               ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
                                               Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you
                                               save the variables for training. In the ``Teacher`` mode the FlavorTagger is
                                               trained, for this step you do not reconstruct any particle or do any analysis,
                                               you just run the flavorTagger alone.
      @param weightFiles                       Weight files name. Default=
                                               ``B2nunubarBGx1`` (official weight files). If users want to train
                                               the FlavorTagger themselves, the weightfiles name should correspond to the
                                               analysed CP channel in order to avoid confusions. The default name
                                               ``B2nunubarBGx1`` corresponds to
                                               :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`.
                                               and ``B2JpsiKs_muBGx1`` to
                                               :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                                               BGx1 stands for events simulated with background.
      @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
      @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
                                               ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
      @param categories                        Categories used for flavor tagging. By default all are used.
                                               Duplicate entries are removed with a warning; at least two
                                               different categories are required.
      @param maskName                          Gets ROE particles from a specified ROE mask.
                                               ``FTDefaultMask`` (default): tentative mask definition that will be created
                                               automatically. The definition is as follows:

                                               - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                                               - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
                                                 [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
                                                 [clusterReg==3 and E>0.06]] \
                                                 (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                                               ``all``: all ROE particles are used.
                                               Or one can give any mask name defined before calling this function.
      @param saveCategoriesInfo                Sets to save information of individual categories.
      @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
      @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
                                               the conditions database if not available in workingDirectory.
      @param uploadToDatabaseAfterTraining     [Expert] For librarians only: uploads weight files to localdb after training.
      @param samplerFileId                     Identifier to parallelize
                                               sampling. Only used in ``Sampler`` mode.  If you are training by yourself and
                                               want to parallelize the sampling, you can run several sampling scripts in
                                               parallel. By changing this parameter you will not overwrite an older sample.
      @param prefix                            Prefix of weight files.
                                               ``MC15ri_light-2207-bengal_0`` (default): Weight files trained for MC15ri samples.
                                               ``''``: Weight files trained for MC13 samples.
      @param useGNN                            Use GNN-based Flavor Tagger in addition with FastBDT-based one.
                                               Please specify the weight file with the option ``identifierGNN``.
                                               [Expert] In the sampler mode, training files for GNN-based Flavor Tagger is produced.
      @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
                                               [Expert] Multiple identifiers can be given with list(str).
      @param path                              Modules are added to this path

    """

    # Module-level configuration that is read by the event- and combiner-level functions
    global filesDirectory, FANNmlp, TMVAfbdt, fileId, useOnlyLocalFlag

    if (not isinstance(particleLists, list)):
        particleLists = [particleLists]  # in case user inputs a particle list as string

    if mode != 'Sampler' and mode != 'Teacher' and mode != 'Expert':
        B2FATAL('flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')

    if len(categories) != len(set(categories)):
        # dict.fromkeys removes duplicates but, unlike set(), keeps the order given by the user
        uniqueCategories = list(dict.fromkeys(categories))
        dup = [cat for cat in uniqueCategories if categories.count(cat) > 1]
        # Strings have to be joined with '+': the C++ stream operator '<<' raises a TypeError here
        B2WARNING('Flavor Tagger: There are duplicate elements in the given categories list. '
                  'The following duplicate elements are removed; ' + ', '.join(dup))
        categories = uniqueCategories

    # B2FATAL stops the execution, therefore the list of valid names is part of the same message
    if len(categories) < 2:
        B2FATAL('Flavor Tagger: Invalid amount of categories. At least two are needed. '
                'Possible categories are "Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
                '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                '"Lambda", "FSC", "MaximumPstar" or "KaonPion"')

    for category in categories:
        if category not in AvailableCategories:
            B2FATAL(f'Flavor Tagger: {category} is not a valid category name given. '
                    'Available categories are "Electron", "IntermediateElectron", '
                    '"Muon", "IntermediateMuon", "KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                    '"Lambda", "FSC", "MaximumPstar" or "KaonPion"')

    if mode == 'Expert' and useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # The return value is not needed; presumably this call only verifies that the working directory can be found
    basf2.find_file(workingDirectory)

    # Directory where the weights of the trained Methods are saved
    filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    # Create the output directories for the modes that may have to write weight or sample files
    if mode == 'Sampler' or (mode == 'Expert' and downloadFromDatabaseIfNotFound):
        if not basf2.find_file(workingDirectory + '/FlavorTagging', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging')
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        elif not basf2.find_file(workingDirectory + '/FlavorTagging/TrainedMethods', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')

    if len(combinerMethods) < 1 or len(combinerMethods) > 2:
        B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    fileId = samplerFileId
    useOnlyLocalFlag = useOnlyLocalWeightFiles

    B2INFO('*** FLAVOR TAGGING ***')
    B2INFO(' ')
    B2INFO('    Working directory is: ' + filesDirectory)
    B2INFO(' ')

    setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)

    # An empty prefix selects the legacy PID aliases (MC13 weight files according to the prefix documentation)
    if prefix == '':
        set_FlavorTagger_pid_aliases_legacy()
    else:
        set_FlavorTagger_pid_aliases()

    alias_list_for_GNN = []
    if useGNN:
        alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories)

    setInputVariablesWithMask()
    if prefix != '':
        weightFiles = prefix + '_' + weightFiles

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    categoriesCombinationCode = 'CatCode'
    for category in categories:
        ftCategory = AvailableCategories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')

    # The codes are sorted, so the combination code does not depend on the order of the categories
    for code in sorted(categoriesCombination):
        categoriesCombinationCode = categoriesCombinationCode + f'{int(code):02}'

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and '
            '[[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or '
            '[E>0.06 and clusterReg==3]]')
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    if mode == 'Sampler':
        # Events containing ROE without B-Meson (but not empty) are discarded for training
        ma.signalSideParticleListsFilter(
            particleLists,
            'nROE_Charged(' + maskName + ', 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            deadEndPath)

        FillParticleLists(maskName, categories, roe_path)

        if useGNN:
            # The event level is run in 'Expert' mode here; the GNN training sample is only
            # written if eventLevel reports success
            if eventLevel('Expert', weightFiles, categories, roe_path):

                ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                                 outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
                ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)
                ma.variablesToNtuple('vpho:dummy',
                                     alias_list_for_GNN,
                                     treename='tree',
                                     filename=f'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
                                     signalSideParticleList=particleLists[0],
                                     path=roe_path)

        else:
            if eventLevel(mode, weightFiles, categories, roe_path):
                combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # If trigger returns 1 jump into empty path skipping further modules in roe_path
        # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
        ma.signalSideParticleListsFilter(particleLists, 'nROE_Charged(' + maskName + ', 0) > 0', roe_path, deadEndPath)

        # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        FillParticleLists(maskName, categories, roe_path)

        if eventLevel(mode, weightFiles, categories, roe_path):
            combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

            flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
            flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
            flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
            flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
            flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
            flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('targetProb', False)
            flavorTaggerInfoFiller.param('trackPointers', False)
            roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
            add_default_FlavorTagger_aliases()

        if useGNN:
            ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                             outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
            ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

            if isinstance(identifierGNN, str):
                roe_path.add_module('MVAExpert',
                                    listNames='vpho:dummy',
                                    extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                                    identifier=identifierGNN)

                # Rescale the raw output from [0,1] to [-1,1]
                ma.variableToSignalSideExtraInfo('vpho:dummy', {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                                                 path=roe_path)

            elif isinstance(identifierGNN, list):
                # Remove repeated identifiers while keeping the order given by the user
                identifierGNN = list(dict.fromkeys(identifierGNN))

                extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
                roe_path.add_module('MVAMultipleExperts',
                                    listNames='vpho:dummy',
                                    extraInfoNames=extraInfoNames,
                                    identifiers=identifierGNN)

                extraInfoDict = {}
                for extraInfoName in extraInfoNames:
                    extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                    variables.variables.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

                ma.variableToSignalSideExtraInfo('vpho:dummy', extraInfoDict,
                                                 path=roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Teacher':
        # The combiner is only trained once eventLevelTeacher reports all event-level methods as ready
        if eventLevelTeacher(weightFiles, categories):
            combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1336
1337
if __name__ == '__main__':
    # Executed as a script: hand this module to basf2's pretty printer
    # (presumably prints the documentation of its functions; confirm against basf2.utils)
    from basf2.utils import pretty_print_module as print_module_overview
    print_module_overview(__name__, "flavorTagger")
1341
def isB2BII()
Definition: b2bii.py:14