Belle II Software light-2406-ragdoll
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11# ************* Flavor Tagging ************
12# * This script is needed to train *
13# * and to test the flavor tagger. *
14# ********************************************
15
16from basf2 import B2INFO, B2FATAL, B2WARNING
17import basf2
18import basf2_mva
19import inspect
20import modularAnalysis as ma
21import variables
22from variables import utils
23import os
24import glob
25import b2bii
26import collections
27
28
def getBelleOrBelle2():
    """
    Return the global mode code.

    :return: ``'Belle'`` when ``b2bii.isB2BII()`` is true, ``'Belle2'`` otherwise.
    """
    return 'Belle' if b2bii.isB2BII() else 'Belle2'
37
38
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Configure the interaction with the database.

    The two switches are stored in the module-level globals ``downloadFlag``
    (download trained weight files when they are not found locally) and
    ``uploadFlag`` (upload weight files after training).

    :param downloadFromDatabaseIfNotFound: value stored in ``downloadFlag``
    :param uploadToDatabaseAfterTraining: value stored in ``uploadFlag``
    """
    global downloadFlag, uploadFlag
    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
49
50
# Default list of aliases that should be used to save the flavor tagging information using VariablesToNtuple:
# the combined/MC/GNN outputs followed by the three per-category aliases (qp, hasTrueTarget,
# isRightCategory) of every category, in the same category order as AvailableCategories.
flavor_tagging = ['FBDT_qrCombined', 'FANN_qrCombined', 'qrMC', 'mcFlavorOfOtherB', 'qrGNN'] + [
    aliasPrefix + categoryName
    for categoryName in ('Electron', 'IntermediateElectron', 'Muon', 'IntermediateMuon',
                         'KinLepton', 'IntermediateKinLepton', 'Kaon', 'SlowPion', 'FastHadron',
                         'Lambda', 'FSC', 'MaximumPstar', 'KaonPion')
    for aliasPrefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
66
67
def add_default_FlavorTagger_aliases():
    """
    Add the default aliases for the flavor tagging variables and register the
    ``flavor_tagging`` variable collection.

    Besides ``FBDT_qrCombined``, ``FANN_qrCombined``, ``qrMC`` and ``qrGNN``, three
    aliases are created for every key of ``AvailableCategories``:
    ``qp<Category>`` -> ``qpCategory(<Category>)``,
    ``hasTrueTarget<Category>`` -> ``hasTrueTargets(<Category>)`` and
    ``isRightCategory<Category>`` -> ``isTrueFTCategory(<Category>)``.
    """
    addAlias = variables.variables.addAlias

    addAlias('FBDT_qrCombined', 'qrOutput(FBDT)')
    addAlias('FANN_qrCombined', 'qrOutput(FANN)')
    addAlias('qrMC', 'isRelatedRestOfEventB0Flavor')
    addAlias('qrGNN', 'extraInfo(qrGNN)')

    for categoryName in AvailableCategories:
        addAlias(f'qp{categoryName}', f'qpCategory({categoryName})')
        addAlias(f'hasTrueTarget{categoryName}', f'hasTrueTargets({categoryName})')
        addAlias(f'isRightCategory{categoryName}', f'isTrueFTCategory({categoryName})')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
89
90
def set_FlavorTagger_pid_aliases():
    """
    Add the PID aliases needed by the flavor tagger.

    Each alias is a ``pidPairProbabilityExpert(<pdg>, <otherPdg>, <detectors>)`` expression.
    For the dE/dx aliases, Belle (b2bii) uses CDC and SVD wrapped in
    ``ifNANgiveX(..., 0.5)``, while Belle II uses the bare CDC expression.
    """
    # (alias, hypothesis PDG code, competing PDG code, detector)
    singleDetectorAliases = (
        ('eid_TOP', 11, 211, 'TOP'),
        ('eid_ARICH', 11, 211, 'ARICH'),
        ('eid_ECL', 11, 211, 'ECL'),
        ('muid_TOP', 13, 211, 'TOP'),
        ('muid_ARICH', 13, 211, 'ARICH'),
        ('muid_KLM', 13, 211, 'KLM'),
        ('piid_TOP', 211, 321, 'TOP'),
        ('piid_ARICH', 211, 321, 'ARICH'),
        ('Kid_TOP', 321, 211, 'TOP'),
        ('Kid_ARICH', 321, 211, 'ARICH'),
    )
    for alias, pdg, otherPdg, detector in singleDetectorAliases:
        variables.variables.addAlias(alias, f'pidPairProbabilityExpert({pdg}, {otherPdg}, {detector})')

    # (alias, hypothesis PDG code, competing PDG code) for the dE/dx-based aliases
    dEdxAliases = (
        ('eid_dEdx', 11, 211),
        ('muid_dEdx', 13, 211),
        ('piid_dEdx', 211, 321),
        ('pi_vs_edEdxid', 211, 11),
        ('Kid_dEdx', 321, 211),
    )
    isBelle = getBelleOrBelle2() == "Belle"
    for alias, pdg, otherPdg in dEdxAliases:
        if isBelle:
            expression = f'ifNANgiveX(pidPairProbabilityExpert({pdg}, {otherPdg}, CDC, SVD), 0.5)'
        else:
            expression = f'pidPairProbabilityExpert({pdg}, {otherPdg}, CDC)'
        variables.variables.addAlias(alias, expression)
121
122
def set_FlavorTagger_pid_aliases_legacy():
    """
    Add the PID aliases needed by the flavor tagger trained for MC13.

    Unlike the current aliases, every expression here is wrapped in
    ``ifNANgiveX(..., 0.5)``. The dE/dx aliases combine CDC and SVD for Belle (b2bii)
    and use the CDC only for Belle II.
    """
    def nanSafePairProbability(pdg, otherPdg, detectors):
        # NaN results are replaced by 0.5
        return f'ifNANgiveX(pidPairProbabilityExpert({pdg}, {otherPdg}, {detectors}), 0.5)'

    # (alias, hypothesis PDG code, competing PDG code, detector)
    for alias, pdg, otherPdg, detector in (
        ('eid_TOP', 11, 211, 'TOP'),
        ('eid_ARICH', 11, 211, 'ARICH'),
        ('eid_ECL', 11, 211, 'ECL'),
        ('muid_TOP', 13, 211, 'TOP'),
        ('muid_ARICH', 13, 211, 'ARICH'),
        ('muid_KLM', 13, 211, 'KLM'),
        ('piid_TOP', 211, 321, 'TOP'),
        ('piid_ARICH', 211, 321, 'ARICH'),
        ('Kid_TOP', 321, 211, 'TOP'),
        ('Kid_ARICH', 321, 211, 'ARICH'),
    ):
        variables.variables.addAlias(alias, nanSafePairProbability(pdg, otherPdg, detector))

    # SVD PID is removed for Belle II MC and data as it is absent in release 4.
    dEdxDetectors = 'CDC, SVD' if getBelleOrBelle2() == "Belle" else 'CDC'
    for alias, pdg, otherPdg in (
        ('eid_dEdx', 11, 211),
        ('muid_dEdx', 13, 211),
        ('piid_dEdx', 211, 321),
        ('pi_vs_edEdxid', 211, 11),
        ('Kid_dEdx', 321, 211),
    ):
        variables.variables.addAlias(alias, nanSafePairProbability(pdg, otherPdg, dEdxDetectors))
154
155
def set_GNNFlavorTagger_aliases(categories):
    """
    Add the aliases used by the GNN-based flavor tagger.

    For every rank 1..16, one alias per requested category (its ``QpTrack`` value) and
    one alias per per-track quantity are defined. Each is read via
    ``getVariableByRank(pi+:inRoe, FT_p, <variable>, <rank>)``, with NaN replaced by 0.

    :param categories: category names (keys of ``AvailableCategories``)
    :return: list of all alias names that were defined, starting with ``qrCombined_bit``
    """
    # Target variable: 0 for B0bar, 1 for B0.
    variables.variables.addAlias('qrCombined_bit', '(qrCombined+1)/2')
    aliasNames = ['qrCombined_bit']

    # Per-track quantities (alias stem -> variable): position (dx, dy, dz), mask (E),
    # charge, then the features multiplied by the charge.
    chargeWeighted = ('px', 'py', 'pz',
                      'electronID', 'muonID', 'pionID', 'kaonID', 'protonID', 'deuteronID',
                      'electronID_noSVD_noTOP')
    trackQuantities = {name: name for name in ('dx', 'dy', 'dz', 'E', 'charge')}
    trackQuantities.update({f'{name}_c': f'{name}*charge' for name in chargeWeighted})

    def addRanked(stem, expression, rank):
        # Define '<stem>_rank<rank>' and remember its name.
        alias = f'{stem}_rank{rank}'
        variables.variables.addAlias(
            alias, f'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {expression}, {rank}), 0)')
        aliasNames.append(alias)

    # At most 16 charged particles are used.
    for rank in range(1, 17):
        for cat in categories:
            listName = AvailableCategories[cat].particleList
            addRanked(cat, f'QpTrack({listName}, isRightCategory({cat}), isRightCategory({cat}))', rank)
        for stem, expression in trackQuantities.items():
            addRanked(stem, expression, rank)

    return aliasNames
208
209
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for the input variables evaluated with a ROE mask.

    Defines ``<variable>_withMask`` as ``<variable>(<maskName>)`` for ``pMissTag``,
    ``cosTPTO``, ``ptTracksRoe`` and ``pt2TracksRoe``. Each alias is defined exactly
    once; the previous version redefined ``ptTracksRoe_withMask`` a second time with
    the same expression.

    :param maskName: name of the ROE mask passed to every variable (default ``'all'``)
    """
    for variableName in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(variableName + '_withMask', variableName + '(' + maskName + ')')
219
220
def getFastBDTCategories():
    '''
    Return the FastBDT options used to train the event-level category BDTs.

    The options object is built inside a function so that no ``basf2_mva`` (ROOT)
    object has to be created at module import time.
    '''
    options = basf2_mva.FastBDTOptions()
    options.m_nTrees = 500
    options.m_nCuts = 8
    options.m_nLevels = 3
    options.m_shrinkage = 0.10
    options.m_randRatio = 0.5
    return options
232
233
def getFastBDTCombiner():
    '''
    Return the FastBDT options used for the combiner.

    The options object is built inside a function so that no ``basf2_mva`` (ROOT)
    object has to be created at module import time.
    '''
    options = basf2_mva.FastBDTOptions()
    options.m_nTrees = 500
    options.m_nCuts = 8
    options.m_nLevels = 3
    options.m_shrinkage = 0.10
    options.m_randRatio = 0.5
    return options
246
247
def getMlpFANNCombiner():
    '''
    Return the FANN multilayer-perceptron options used for the combiner.

    The options object is built inside a function so that no ``basf2_mva`` (ROOT)
    object has to be created at module import time.
    '''
    options = basf2_mva.FANNOptions()
    options.m_max_epochs = 10000
    options.m_hidden_layers_architecture = "3*N"
    # NOTE(review): 'activiation' is the attribute spelling used here; do not "fix" it
    # without checking the basf2_mva.FANNOptions attribute names.
    options.m_hidden_activiation_function = "FANN_SIGMOID_SYMMETRIC"
    options.m_output_activiation_function = "FANN_SIGMOID_SYMMETRIC"
    options.m_error_function = "FANN_ERRORFUNC_LINEAR"
    options.m_training_method = "FANN_TRAIN_RPROP"
    options.m_validation_fraction = 0.5
    options.m_random_seeds = 10
    options.m_test_rate = 500
    options.m_number_of_threads = 8
    options.m_scale_features = True
    options.m_scale_target = False
    return options
268
269
# Signal fraction handed to the MVA experts of this module (a FastBDT feature):
#   -1 gives a smooth output but breaks the calibration,
#   -2 keeps the calibration correct at the price of a peaky combiner output.
signalFraction = -2

# Maximal number of events used to train each method; 0 means that all sampled
# events are used (the value used in the past was 500000).
maxEventsNumber = 0
277
278
# Parameters of one flavor tagger category:
#   particleList -- ParticleList the category's candidates are taken from
#   trackName    -- name of the track-level category
#   eventName    -- name of the event-level category
#   variableName -- combiner-level variable of the category
#   code         -- integer category code
FTCategoryParameters = collections.namedtuple(
    'FTCategoryParameters',
    ('particleList', 'trackName', 'eventName', 'variableName', 'code'),
)

# All available categories, keyed by their standard category name.
# Each row: name, ParticleList, track-level name, event-level name, combiner-level variable, code.
AvailableCategories = {
    row[0]: FTCategoryParameters(*row[1:])
    for row in (
        ('Electron', 'e+:inRoe', 'Electron', 'Electron',
         'QpOf(e+:inRoe, isRightCategory(Electron), isRightCategory(Electron))', 0),
        ('IntermediateElectron', 'e+:inRoe', 'IntermediateElectron', 'IntermediateElectron',
         'QpOf(e+:inRoe, isRightCategory(IntermediateElectron), isRightCategory(IntermediateElectron))', 1),
        ('Muon', 'mu+:inRoe', 'Muon', 'Muon',
         'QpOf(mu+:inRoe, isRightCategory(Muon), isRightCategory(Muon))', 2),
        ('IntermediateMuon', 'mu+:inRoe', 'IntermediateMuon', 'IntermediateMuon',
         'QpOf(mu+:inRoe, isRightCategory(IntermediateMuon), isRightCategory(IntermediateMuon))', 3),
        ('KinLepton', 'mu+:inRoe', 'KinLepton', 'KinLepton',
         'QpOf(mu+:inRoe, isRightCategory(KinLepton), isRightCategory(KinLepton))', 4),
        ('IntermediateKinLepton', 'mu+:inRoe', 'IntermediateKinLepton', 'IntermediateKinLepton',
         'QpOf(mu+:inRoe, isRightCategory(IntermediateKinLepton), isRightCategory(IntermediateKinLepton))', 5),
        ('Kaon', 'K+:inRoe', 'Kaon', 'Kaon',
         'weightedQpOf(K+:inRoe, isRightCategory(Kaon), isRightCategory(Kaon))', 6),
        ('SlowPion', 'pi+:inRoe', 'SlowPion', 'SlowPion',
         'QpOf(pi+:inRoe, isRightCategory(SlowPion), isRightCategory(SlowPion))', 7),
        ('FastHadron', 'pi+:inRoe', 'FastHadron', 'FastHadron',
         'QpOf(pi+:inRoe, isRightCategory(FastHadron), isRightCategory(FastHadron))', 8),
        ('Lambda', 'Lambda0:inRoe', 'Lambda', 'Lambda',
         'weightedQpOf(Lambda0:inRoe, isRightCategory(Lambda), isRightCategory(Lambda))', 9),
        ('FSC', 'pi+:inRoe', 'SlowPion', 'FSC',
         'QpOf(pi+:inRoe, isRightCategory(FSC), isRightCategory(SlowPion))', 10),
        ('MaximumPstar', 'pi+:inRoe', 'MaximumPstar', 'MaximumPstar',
         'QpOf(pi+:inRoe, isRightCategory(MaximumPstar), isRightCategory(MaximumPstar))', 11),
        ('KaonPion', 'K+:inRoe', 'Kaon', 'KaonPion',
         'QpOf(K+:inRoe, isRightCategory(KaonPion), isRightCategory(Kaon))', 12),
    )
}
338
339
def getTrainingVariables(category=None):
    """
    Helper function to get the event-level training variables of a category.

    :param category: name of a flavor tagger category (a key of ``AvailableCategories``).
                     ``None`` or an unknown name yields an empty list.
    :return: list of variable names used as FastBDT inputs for ``category``. Several
             entries differ between Belle (b2bii) and Belle II processing.

    NOTE: This function is not called in the Expert mode, so it does not need to be
    consistent with the variable lists stored in existing weight files.
    """
    belleOrBelle2 = getBelleOrBelle2()
    isBelle = belleOrBelle2 == "Belle"

    # Global PID variable names differ between Belle (b2bii) and Belle II.
    kaonIdVar = {'Belle': 'ifNANgiveX(atcPIDBelle(3,2), 0.5)', 'Belle2': 'kaonID'}[belleOrBelle2]
    muonIdVar = {'Belle': 'muIDBelle', 'Belle2': 'muonID'}[belleOrBelle2]
    electronIdVar = {'Belle': 'eIDBelle', 'Belle2': 'electronID'}[belleOrBelle2]

    # Deliberately not called "variables": that would shadow the imported variables module.
    trainingVariables = []

    if category in ('Electron', 'IntermediateElectron'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             electronIdVar,
                             'eid_TOP',
                             'eid_ARICH',
                             'eid_ECL',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             'chiProb',
                             ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'ImpactXY', 'distance']

    elif category in ('Muon', 'IntermediateMuon'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             muonIdVar,
                             'muid_TOP',
                             'muid_ARICH',
                             'muid_KLM',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category in ('KinLepton', 'IntermediateKinLepton'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             muonIdVar,
                             'muid_TOP',
                             'muid_ARICH',
                             'muid_KLM',
                             electronIdVar,
                             'eid_TOP',
                             'eid_ARICH',
                             'eid_ECL',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Kaon':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             kaonIdVar,
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             'chiProb',
                             ]
        if isBelle:
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'SlowPion':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             'pionID',
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             kaonIdVar,
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             electronIdVar,
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['piid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'FastHadron':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             'pionID',
                             'piid_dEdx',
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             kaonIdVar,
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             electronIdVar,
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['BtagToWBosonVariables(pMissCMS)', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Lambda':
        trainingVariables = ['lambdaFlavor',
                             'NumberOfKShortsInRoe',
                             'M',
                             'cosAngleBetweenMomentumAndVertexVector',
                             'lambdaZError',
                             'daughter(0,p)',
                             'daughter(0,useCMSFrame(p))',
                             'daughter(1,p)',
                             'daughter(1,useCMSFrame(p))',
                             'useCMSFrame(p)',
                             'p',
                             'chiProb',
                             ]
        if belleOrBelle2 == "Belle2":
            trainingVariables.append('daughter(1,protonID)')  # protonID always 0 in B2BII check in future
            trainingVariables.append('daughter(0,pionID)')  # not very powerful in B2BII
        else:
            trainingVariables.append('distance')

    elif category == 'MaximumPstar':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTPTO',
                             ]
        # NOTE(review): unlike the other categories, ImpactXY/distance are added for
        # Belle II here, not for Belle -- confirm this is intended.
        if belleOrBelle2 == "Belle2":
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'FSC':
        trainingVariables = ['useCMSFrame(p)',
                             'cosTPTO',
                             kaonIdVar,
                             'FSCVariables(pFastCMS)',
                             'FSCVariables(cosSlowFast)',
                             'FSCVariables(cosTPTOFast)',
                             'FSCVariables(SlowFastHaveOpositeCharges)',
                             ]

    elif category == 'KaonPion':
        trainingVariables = ['extraInfo(isRightCategory(Kaon))',
                             'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
                             'KaonPionVariables(cosKaonPion)',
                             'KaonPionVariables(HaveOpositeCharges)',
                             kaonIdVar,
                             ]

    return trainingVariables
553
554
def FillParticleLists(maskName='all', categories=None, path=None):
    """
    Fill the rest-of-event particle lists needed by the given categories.

    Every distinct ``particleList`` of the requested categories is filled once with
    tracks that are in the rest of event and pass the ROE mask ``maskName``.
    ``Lambda0:inRoe`` is instead reconstructed from ``pi-:inRoe p+:inRoe`` and fitted
    with ``kFit``. ``K_S0:inRoe`` is always created: copied from ``K_S0:mdst`` for Belle,
    reconstructed from ``pi+:inRoe pi-:inRoe`` for Belle II. For Belle II, the BDT-based
    charged PID is then applied to whichever of ``e+:inRoe``/``mu+:inRoe`` were filled.

    :param maskName: ROE mask used in the track selection
    :param categories: category names (keys of ``AvailableCategories``); ``None`` means none
    :param path: basf2 path the modules are added to
    """
    from vertex import kFit

    categories = [] if categories is None else categories
    trackCut = 'isInRestOfEvent > 0.5 and passesROEMask(' + maskName + ') > 0.5 and p >= 0'
    filledLists = set()

    for category in categories:
        listName = AvailableCategories[category].particleList
        if listName in filledLists:
            continue

        if listName != 'Lambda0:inRoe':
            # Plain track list with the mass hypothesis of this category.
            ma.fillParticleList(listName, trackCut, path=path)
            filledLists.add(listName)
            continue

        # Lambda0 candidates are combined from ROE pions and protons.
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', trackCut, path=path)
            filledLists.add('pi+:inRoe')
        ma.fillParticleList('p+:inRoe', trackCut, path=path)
        # NOTE(review): the positional False is the third argument of reconstructDecay;
        # confirm which parameter it binds to.
        ma.reconstructDecay(listName + ' -> pi-:inRoe p+:inRoe', '1.00<=M<=1.23', False, path=path)
        kFit(listName, 0.01, path=path)
        ma.matchMCTruth(listName, path=path)
        filledLists.add(listName)

    experiment = getBelleOrBelle2()

    # Additional particle list for K_S0.
    if experiment == 'Belle':
        ma.cutAndCopyList('K_S0:inRoe', 'K_S0:mdst', 'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
    else:
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', trackCut, path=path)
        ma.reconstructDecay('K_S0:inRoe -> pi+:inRoe pi-:inRoe', '0.40<=M<=0.60', False, path=path)
        kFit('K_S0:inRoe', 0.01, path=path)

    # BDT-based lepton ID, only for Belle II and only on lepton lists that were filled.
    if experiment == 'Belle2':
        lidLists = [name for name in ('e+:inRoe', 'mu+:inRoe') if name in filledLists]
        if lidLists:
            ma.applyChargedPidMVA(particleLists=lidLists, path=path,
                                  trainingMode=0,  # binary
                                  binaryHypoPDGCodes=(11, 211))  # e vs pi
            ma.applyChargedPidMVA(particleLists=lidLists, path=path,
                                  trainingMode=0,  # binary
                                  binaryHypoPDGCodes=(13, 211))  # mu vs pi
            ma.applyChargedPidMVA(particleLists=lidLists, path=path,
                                  trainingMode=1)  # Multiclass
618
619
def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, path=None):
    """
    Samples data for training or applies the trained MVAs of all categories at event level.

    An MVA expert writing the extraInfo ``isRightCategory(<category>)`` is added for a
    category in ``'Expert'`` mode, and in ``'Sampler'`` mode when the category's
    event-level weight file already exists in ``filesDirectory``. In ``'Sampler'`` mode,
    every category without a weight file gets a VariablesToNtuple module that writes
    its training sample.

    Reads the module-level globals ``downloadFlag``, ``useOnlyLocalFlag``,
    ``filesDirectory`` and ``fileId``, which are set outside this function.

    :param mode: ``'Expert'`` or ``'Sampler'``
    :param weightFiles: weight-file name stem
    :param categories: category names (keys of ``AvailableCategories``)
    :param path: basf2 path the modules are added to
    :return: True if an event-level MVA was applied for every requested category, False otherwise
    """

    from basf2 import create_path
    from basf2 import register_module

    B2INFO('EVENT LEVEL')

    ReadyMethods = 0

    # (extraInfo name, weight-file identifier) of every ready category, grouped by particle list.
    # KaonPion is kept apart because it is applied with a single MVAExpert on K+:inRoe.
    identifiersExtraInfosDict = dict()
    identifiersExtraInfosKaonPion = []

    if categories is None:
        categories = []

    # When sampling for KaonPion while its event-level weight file is still missing, only the
    # SlowPion and Kaon categories are needed (their outputs are KaonPion training inputs).
    # The file check does not depend on the category, so it is done once here.
    kaonPionWeightFileMissing = False
    if mode == 'Sampler' and 'KaonPion' in categories:
        methodPrefixEventLevelKaonPion = "FlavorTagger_" + getBelleOrBelle2() + \
            "_" + weightFiles + 'EventLevelKaonPionFBDT'
        identifierEventLevelKaonPion = filesDirectory + '/' + methodPrefixEventLevelKaonPion + '_1.root'
        kaonPionWeightFileMissing = not os.path.isfile(identifierEventLevelKaonPion)

    for category in categories:
        particleList = AvailableCategories[category].particleList

        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        identifierEventLevel = methodPrefixEventLevel
        targetVariable = 'isRightCategory(' + category + ')'
        extraInfoName = targetVariable

        if mode == 'Expert':

            # With database download or local-only running, the identifier is a local file path.
            if downloadFlag or useOnlyLocalFlag:
                identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'

            if downloadFlag:
                if not os.path.isfile(identifierEventLevel):
                    basf2_mva.download(methodPrefixEventLevel, identifierEventLevel)
                    if not os.path.isfile(identifierEventLevel):
                        B2FATAL('Flavor Tagger: Weight file ' + identifierEventLevel +
                                ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag:
                if not os.path.isfile(identifierEventLevel):
                    B2FATAL('Flavor Tagger: ' + particleList + ' Eventlevel was not trained. Weight file ' +
                            identifierEventLevel + ' was not found. Stopped')

            B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

        elif mode == 'Sampler':

            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            if os.path.isfile(identifierEventLevel):
                B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

            if kaonPionWeightFileMissing and category != "SlowPion" and category != "Kaon":
                continue

        if mode == 'Expert' or (mode == 'Sampler' and os.path.isfile(identifierEventLevel)):

            B2INFO('flavorTagger: Applying MVAExpert ' + methodPrefixEventLevel + '.')

            if category == 'KaonPion':
                identifiersExtraInfosKaonPion.append((extraInfoName, identifierEventLevel))
            else:
                identifiersExtraInfosDict.setdefault(particleList, []).append((extraInfoName, identifierEventLevel))

            ReadyMethods += 1

    # Each particle list has its own path in order to be skipped if the list is empty.
    for particleList, extraInfosAndIdentifiers in identifiersExtraInfosDict.items():
        eventLevelPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_EventLevel_' + particleList)
        SkipEmptyParticleList.param('particleLists', particleList)
        SkipEmptyParticleList.if_true(eventLevelPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaMultipleExperts = register_module('MVAMultipleExperts')
        mvaMultipleExperts.set_name('MVAMultipleExperts_EventLevel_' + particleList)
        mvaMultipleExperts.param('listNames', [particleList])
        mvaMultipleExperts.param('extraInfoNames', [row[0] for row in extraInfosAndIdentifiers])
        mvaMultipleExperts.param('signalFraction', signalFraction)
        mvaMultipleExperts.param('identifiers', [row[1] for row in extraInfosAndIdentifiers])
        eventLevelPath.add_module(mvaMultipleExperts)

    if 'KaonPion' in categories and identifiersExtraInfosKaonPion:
        eventLevelKaonPionPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_' + 'K+:inRoe')
        SkipEmptyParticleList.param('particleLists', 'K+:inRoe')
        SkipEmptyParticleList.if_true(eventLevelKaonPionPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaExpertKaonPion = register_module("MVAExpert")
        mvaExpertKaonPion.set_name('MVAExpert_KaonPion_' + 'K+:inRoe')
        mvaExpertKaonPion.param('listNames', ['K+:inRoe'])
        mvaExpertKaonPion.param('extraInfoName', identifiersExtraInfosKaonPion[0][0])
        mvaExpertKaonPion.param('signalFraction', signalFraction)
        mvaExpertKaonPion.param('identifier', identifiersExtraInfosKaonPion[0][1])

        eventLevelKaonPionPath.add_module(mvaExpertKaonPion)

    if mode == 'Sampler':

        for category in categories:
            particleList = AvailableCategories[category].particleList

            methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            targetVariable = 'isRightCategory(' + category + ')'

            # Only categories without a trained weight file are sampled.
            if os.path.isfile(identifierEventLevel):
                continue

            if category == 'KaonPion':
                methodPrefixEventLevelSlowPion = "FlavorTagger_" + getBelleOrBelle2() + \
                    "_" + weightFiles + 'EventLevelSlowPionFBDT'
                identifierEventLevelSlowPion = filesDirectory + '/' + methodPrefixEventLevelSlowPion + '_1.root'
                if not os.path.isfile(identifierEventLevelSlowPion):
                    B2INFO("Flavor Tagger: event level weight file for the Slow Pion category is absent. " +
                           "It is required to sample the training information for the KaonPion category. " +
                           "An additional sampling step will be needed after the following training step.")
                    continue

            B2INFO('flavorTagger: file ' + filesDirectory + '/' +
                   methodPrefixEventLevel + "sampled" + fileId + '.root will be saved.')

            ma.applyCuts(particleList, 'isRightCategory(mcAssociated) > 0', path)
            eventLevelpath = create_path()
            SkipEmptyParticleList = register_module("SkimFilter")
            SkipEmptyParticleList.set_name('SkimFilter_EventLevel' + category)
            SkipEmptyParticleList.param('particleLists', particleList)
            SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
            path.add_module(SkipEmptyParticleList)

            ntuple = register_module('VariablesToNtuple')
            ntuple.param('fileName', filesDirectory + '/' + methodPrefixEventLevel + "sampled" + fileId + ".root")
            ntuple.param('treeName', methodPrefixEventLevel + "_tree")
            variablesToBeSaved = getTrainingVariables(category) + [targetVariable, 'ancestorHasWhichFlavor',
                                                                   'isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                                                                   'nMCMatches', 'B0mcErrors']
            if category != 'KaonPion' and category != 'FSC':
                variablesToBeSaved = variablesToBeSaved + \
                    ['extraInfo(isRightTrack(' + category + '))',
                     'hasHighestProbInCat(' + particleList + ', isRightTrack(' + category + '))']
            ntuple.param('variables', variablesToBeSaved)
            ntuple.param('particleList', particleList)
            eventLevelpath.add_module(ntuple)

    return ReadyMethods == len(categories)
779
780
def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None):
    """
    Trains the event-level FastBDT of every category that has no weight file yet.

    For each such category, all ``<prefix>sampled*.root`` files found in ``filesDirectory``
    (a module-level global set outside this function) are used to train the weight file
    ``<prefix>_1.root``. If the global ``uploadFlag`` is set, the new weight file is
    uploaded to the database.

    :param weightFiles: weight-file name stem
    :param categories: category names to train
    :return: True only if the weight files of all categories already existed before this
             call; categories trained during this call are not counted as ready.
    """

    B2INFO('EVENT LEVEL TEACHER')

    ReadyMethods = 0

    if categories is None:
        categories = []

    for category in categories:
        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        targetVariable = 'isRightCategory(' + category + ')'
        weightFile = filesDirectory + '/' + methodPrefixEventLevel + "_1.root"

        if os.path.isfile(weightFile):
            ReadyMethods += 1
            continue

        sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixEventLevel + 'sampled*.root')
        if not sampledFilesList:
            # Report the pattern that was actually searched for.
            B2INFO('flavorTagger: eventLevelTeacher did not find any ' + methodPrefixEventLevel +
                   'sampled*.root' + ' file. Please run the flavorTagger in "Sampler" mode afterwards.')
            continue

        B2INFO('flavorTagger: MVA Teacher training ' + methodPrefixEventLevel + '.')
        trainingOptionsEventLevel = basf2_mva.GeneralOptions()
        trainingOptionsEventLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        trainingOptionsEventLevel.m_treename = methodPrefixEventLevel + "_tree"
        trainingOptionsEventLevel.m_identifier = weightFile
        trainingOptionsEventLevel.m_variables = basf2_mva.vector(*getTrainingVariables(category))
        trainingOptionsEventLevel.m_target_variable = targetVariable
        trainingOptionsEventLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsEventLevel, getFastBDTCategories())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefixEventLevel)

    return ReadyMethods == len(categories)
830 """
831 Samples the input data or tests the combiner according to the selected categories.
832 """
833
834 B2INFO('COMBINER LEVEL')
835
836 if categories is None:
837 categories = []
838 if variablesCombinerLevel is None:
839 variablesCombinerLevel = []
840
841 B2INFO("Flavor Tagger: Required Combiner for Categories:")
842 for category in categories:
843 B2INFO(category)
844
845 B2INFO("Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)
846
847 methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
848 + categoriesCombinationCode
849
850 if mode == 'Sampler':
851
852 if os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root') or \
853 os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'):
854 B2FATAL('flavorTagger: File' + methodPrefixCombinerLevel + 'FBDT' + "_1.root" + ' or ' + methodPrefixCombinerLevel +
855 'FANN' + '_1.root found. Please run the "Expert" mode or delete the file if a new sampling is desired.')
856
857 B2INFO('flavorTagger: Sampling Data on Combiner Level. File' +
858 methodPrefixCombinerLevel + ".root" + ' will be saved')
859
860 ntuple = basf2.register_module('VariablesToNtuple')
861 ntuple.param('fileName', filesDirectory + '/' + methodPrefixCombinerLevel + "sampled" + fileId + ".root")
862 ntuple.param('treeName', methodPrefixCombinerLevel + 'FBDT' + "_tree")
863 ntuple.param('variables', variablesCombinerLevel + ['qrCombined'])
864 ntuple.param('particleList', "")
865 path.add_module(ntuple)
866
867 if mode == 'Expert':
868
869 # Check if weight files are ready
870 if TMVAfbdt:
871 identifierFBDT = methodPrefixCombinerLevel + 'FBDT'
872 if downloadFlag or useOnlyLocalFlag:
873 identifierFBDT = filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root'
874
875 if downloadFlag:
876 if not os.path.isfile(identifierFBDT):
877 basf2_mva.download(methodPrefixCombinerLevel + 'FBDT', identifierFBDT)
878 if not os.path.isfile(identifierFBDT):
879 B2FATAL('Flavor Tagger: Weight file ' + identifierFBDT +
880 ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')
881
882 if useOnlyLocalFlag:
883 if not os.path.isfile(identifierFBDT):
884 B2FATAL('flavorTagger: Combinerlevel FastBDT was not trained with this combination of categories.' +
885 ' Weight file ' + identifierFBDT + ' not found. Stopped')
886
887 B2INFO('flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel + 'FBDT' + '_1.root')
888
889 if FANNmlp:
890 identifierFANN = methodPrefixCombinerLevel + 'FANN'
891 if downloadFlag or useOnlyLocalFlag:
892 identifierFANN = filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'
893
894 if downloadFlag:
895 if not os.path.isfile(identifierFANN):
896 basf2_mva.download(methodPrefixCombinerLevel + 'FANN', identifierFANN)
897 if not os.path.isfile(identifierFANN):
898 B2FATAL('Flavor Tagger: Weight file ' + identifierFANN +
899 ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')
900 if useOnlyLocalFlag:
901 if not os.path.isfile(identifierFANN):
902 B2FATAL('flavorTagger: Combinerlevel FANNMLP was not trained with this combination of categories. ' +
903 ' Weight file ' + identifierFANN + ' not found. Stopped')
904
905 B2INFO('flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel + 'FANN' + '_1.root')
906
907 # At this stage, all necessary weight files should be ready.
908 # Call MVAExpert or MVAMultipleExperts module.
909 if TMVAfbdt and not FANNmlp:
910 B2INFO('flavorTagger: Apply FBDTMethod ' + methodPrefixCombinerLevel + 'FBDT')
911 path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FBDT', signalFraction=signalFraction,
912 identifier=identifierFBDT)
913
914 if FANNmlp and not TMVAfbdt:
915 B2INFO('flavorTagger: Apply FANNMethod on combiner level')
916 path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FANN', signalFraction=signalFraction,
917 identifier=identifierFANN)
918
919 if FANNmlp and TMVAfbdt:
920 B2INFO('flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
921 mvaMultipleExperts = basf2.register_module('MVAMultipleExperts')
922 mvaMultipleExperts.set_name('MVAMultipleExperts_Combiners')
923 mvaMultipleExperts.param('listNames', [])
924 mvaMultipleExperts.param('extraInfoNames', ['qrCombined' + 'FBDT', 'qrCombined' + 'FANN'])
925 mvaMultipleExperts.param('signalFraction', signalFraction)
926 mvaMultipleExperts.param('identifiers', [identifierFBDT, identifierFANN])
927 path.add_module(mvaMultipleExperts)
928
929
def _trainCombiner(sampledFilesList, methodPrefixCombinerLevel, method, variablesCombinerLevel, methodOptions):
    """
    Trains one combiner MVA on the sampled combiner-level files.

    Writes ``<filesDirectory>/<methodPrefixCombinerLevel><method>_1.root`` and uploads it if
    uploading was requested (``uploadFlag``).

    @param sampledFilesList           Sampled ROOT files used as training data.
    @param methodPrefixCombinerLevel  Common prefix of the combiner file names.
    @param method                     Method suffix of the weight file name: 'FBDT' or 'FANN'.
    @param variablesCombinerLevel     Input variables of the combiner.
    @param methodOptions              Method-specific options passed to ``basf2_mva.teacher``.
    """
    weightFile = filesDirectory + '/' + methodPrefixCombinerLevel + method + "_1.root"

    trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
    trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
    # The combiner sampler writes a single tree named after the FBDT; it is used for both methods.
    trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
    trainingOptionsCombinerLevel.m_identifier = weightFile
    trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
    trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
    trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

    basf2_mva.teacher(trainingOptionsCombinerLevel, methodOptions)

    if uploadFlag:
        basf2_mva.upload(weightFile, methodPrefixCombinerLevel + method)


def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    Trains the FBDT combiner (if ``TMVAfbdt``) and/or the FANN combiner (if ``FANNmlp``) whose weight
    file does not exist yet in ``filesDirectory``. Stops with B2FATAL if no sampled combiner files
    exist, or if every selected combiner has already been trained.

    @param weightFiles                Weight files name (including the prefix, if any).
    @param variablesCombinerLevel     Input variables of the combiner.
    @param categoriesCombinationCode  Code of the category combination; part of the weight file name.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if len(sampledFilesList) == 0:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')
        return

    fbdtWeightFile = methodPrefixCombinerLevel + 'FBDT' + '_1.root'
    fannWeightFile = methodPrefixCombinerLevel + 'FANN' + '_1.root'

    # Decide what has to be trained BEFORE training anything. Checking afterwards would see the
    # freshly trained FBDT file and wrongly conclude that nothing was left to do.
    trainFBDT = TMVAfbdt and not os.path.isfile(filesDirectory + '/' + fbdtWeightFile)
    trainFANN = FANNmlp and not os.path.isfile(filesDirectory + '/' + fannWeightFile)

    if not trainFBDT and not trainFANN:
        foundFiles = [weightFile for weightFile, selected in ((fbdtWeightFile, TMVAfbdt), (fannWeightFile, FANNmlp))
                      if selected]
        B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight file(s) ' +
                ' and '.join(foundFiles) + ' found. Please use the "Expert" mode')
        return

    if TMVAfbdt:
        if trainFBDT:
            B2INFO('flavorTagger: MVA Teacher training a FastBDT on Combiner Level')
            _trainCombiner(sampledFilesList, methodPrefixCombinerLevel, 'FBDT', variablesCombinerLevel,
                           getFastBDTCombiner())
        else:
            # Only the FANN still needs training.
            B2INFO('flavorTagger: Combinerlevel FBDT was already trained with this combination of categories. Weight file ' +
                   fbdtWeightFile + ' has been found.')

    if FANNmlp:
        if trainFANN:
            B2INFO('flavorTagger: MVA Teacher training a FANN MLP on Combiner Level')
            _trainCombiner(sampledFilesList, methodPrefixCombinerLevel, 'FANN', variablesCombinerLevel,
                           getMlpFANNCombiner())
        else:
            # Only the FBDT needed training.
            B2INFO('flavorTagger: Combinerlevel FANN was already trained with this combination of categories. Weight file ' +
                   fannWeightFile + ' has been found.')
1008
1009
def getEventLevelParticleLists(categories=None):
    """
    Collects the event-level configuration tuples of the given categories.

    @param categories  Names of the categories, looked up in ``AvailableCategories``. ``None`` means no category.
    @return List of ``(particleList, eventName, variableName)`` tuples, in the order of ``categories``.
            Stops with B2FATAL if two categories yield the same tuple.
    """
    collected = []
    if categories is None:
        return collected

    for name in categories:
        info = AvailableCategories[name]
        entry = (info.particleList, info.eventName, info.variableName)
        if entry in collected:
            B2FATAL('Flavor Tagger: ' + name + ' has been already given')
        else:
            collected.append(entry)

    return collected
1027
1028
def flavorTagger(
    particleLists=None,
    mode='Expert',
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    uploadToDatabaseAfterTraining=False,
    samplerFileId='',
    prefix='MC15ri_light-2207-bengal_0',
    useGNN=False,
    identifierGNN='GFlaT_MC15ri_light_2303_iriomote_0',
    path=None,
):
    """
    Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
    The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
    Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
    This module can be used to sample the training information, to train and/or to test the flavorTagger.

    @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
                                             A single particle list name given as a string is also accepted.
    @param mode                              The available modes are
                                             ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
                                             Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you save
                                             the variables for training. In the ``Teacher`` mode the FlavorTagger is
                                             trained; for this step you do not reconstruct any particle or do any analysis,
                                             you just run the flavorTagger alone.
    @param weightFiles                       Weight files name. Default=
                                             ``B2nunubarBGx1`` (official weight files). If users want to train
                                             the FlavorTagger themselves, the weight files name should correspond to the
                                             analysed CP channel in order to avoid confusion. The default name
                                             ``B2nunubarBGx1`` corresponds to
                                             :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`,
                                             and ``B2JpsiKs_muBGx1`` to
                                             :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                                             BGx1 stands for events simulated with background.
    @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
    @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
                                             ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
    @param categories                        Categories used for flavor tagging. By default all are used.
                                             At least two are needed; duplicate entries are removed with a warning.
    @param maskName                          Gets ROE particles from a specified ROE mask.
                                             ``FTDefaultMask`` (default): tentative mask definition that will be created
                                             automatically. The definition is as follows:

                                             - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                                             - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and
                                               [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or
                                               [clusterReg==3 and E>0.06]]
                                               (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                                             ``all``: all ROE particles are used.
                                             Or one can give any mask name defined before calling this function.
    @param saveCategoriesInfo                Sets to save information of individual categories.
    @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
    @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
                                             the conditions database if not available in workingDirectory.
    @param uploadToDatabaseAfterTraining     [Expert] For librarians only: uploads weight files to localdb after training.
    @param samplerFileId                     Identifier to parallelize
                                             sampling. Only used in ``Sampler`` mode. If you are training by yourself and
                                             want to parallelize the sampling, you can run several sampling scripts in
                                             parallel. By changing this parameter you will not overwrite an older sample.
    @param prefix                            Prefix of weight files.
                                             ``MC15ri_light-2207-bengal_0`` (default): Weight files trained for MC15ri samples.
                                             ``''``: Weight files trained for MC13 samples.
    @param useGNN                            Use GNN-based Flavor Tagger in addition with FastBDT-based one.
                                             Please specify the weight file with the option ``identifierGNN``.
                                             [Expert] In the sampler mode, training files for GNN-based Flavor Tagger are produced.
    @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
                                             [Expert] Multiple identifiers can be given with list(str).
    @param path                              Modules are added to this path

    """
    # NOTE: the list defaults of combinerMethods and categories are never mutated in here
    # (categories is only rebound), so sharing them between calls is harmless.

    if not isinstance(particleLists, list):
        particleLists = [particleLists]  # in case user inputs a particle list as string

    if mode not in ('Sampler', 'Teacher', 'Expert'):
        B2FATAL('flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')

    # Human-readable list of the valid category names, used in the error messages below.
    possibleCategories = ('"Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
                          '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                          '"Lambda", "FSC", "MaximumPstar" or "KaonPion"')

    if len(categories) != len(set(categories)):
        # dict.fromkeys keeps the first occurrence and the user's order. set() would give an
        # order that changes from run to run with string hash randomization.
        dup = [cat for cat in dict.fromkeys(categories) if categories.count(cat) > 1]
        B2WARNING('Flavor Tagger: There are duplicate elements in the given categories list. '
                  'The following duplicate elements are removed; ' + ', '.join(dup))
        categories = list(dict.fromkeys(categories))

    if len(categories) < 2:
        B2FATAL('Flavor Tagger: Invalid amount of categories. At least two are needed. '
                'Possible categories are ' + possibleCategories)

    for category in categories:
        if category not in AvailableCategories:
            B2FATAL('Flavor Tagger: ' + category + ' is not a valid category name given. '
                    'Available categories are ' + possibleCategories)

    if mode == 'Expert' and useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # The return value is unused; presumably called so that a missing workingDirectory is reported (TODO confirm).
    basf2.find_file(workingDirectory)

    # Directory where the weights of the trained Methods are saved
    global filesDirectory
    filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    if mode == 'Sampler' or (mode == 'Expert' and downloadFromDatabaseIfNotFound):
        if not basf2.find_file(workingDirectory + '/FlavorTagging', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging')
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        elif not basf2.find_file(workingDirectory + '/FlavorTagging/TrainedMethods', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')

    if len(combinerMethods) < 1 or len(combinerMethods) > 2:
        B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global FANNmlp
    global TMVAfbdt

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global fileId
    fileId = samplerFileId

    global useOnlyLocalFlag
    useOnlyLocalFlag = useOnlyLocalWeightFiles

    B2INFO('*** FLAVOR TAGGING ***')
    B2INFO(' ')
    B2INFO(' Working directory is: ' + filesDirectory)
    B2INFO(' ')

    setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)

    if prefix == '':
        set_FlavorTagger_pid_aliases_legacy()
    else:
        set_FlavorTagger_pid_aliases()

    alias_list_for_GNN = []
    if useGNN:
        alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories)

    setInputVariablesWithMask()
    if prefix != '':
        weightFiles = prefix + '_' + weightFiles

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    for category in categories:
        ftCategory = AvailableCategories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')

    # Sorted, zero-padded codes make the combiner weight file name independent of the category order.
    categoriesCombinationCode = 'CatCode' + ''.join(f'{int(code):02}' for code in sorted(categoriesCombination))

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or '
            '[E>0.06 and clusterReg==3]]')
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    if mode == 'Sampler':
        # Events containing ROE without B-Meson (but not empty) are discarded for training
        ma.signalSideParticleListsFilter(
            particleLists,
            'nROE_Charged(' + maskName + ', 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            deadEndPath)

        FillParticleLists(maskName, categories, roe_path)

        if useGNN:
            if eventLevel('Expert', weightFiles, categories, roe_path):

                ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                                 outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
                ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)
                ma.variablesToNtuple('vpho:dummy',
                                     alias_list_for_GNN,
                                     treename='tree',
                                     filename=f'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
                                     signalSideParticleList=particleLists[0],
                                     path=roe_path)

        else:
            if eventLevel(mode, weightFiles, categories, roe_path):
                combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # If trigger returns 1 jump into empty path skipping further modules in roe_path
        # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
        ma.signalSideParticleListsFilter(particleLists, 'nROE_Charged(' + maskName + ', 0) > 0', roe_path, deadEndPath)

        # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        FillParticleLists(maskName, categories, roe_path)

        if eventLevel(mode, weightFiles, categories, roe_path):
            combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
        flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
        flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
        flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
        flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
        flavorTaggerInfoFiller.param('targetProb', False)
        flavorTaggerInfoFiller.param('trackPointers', False)
        roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
        add_default_FlavorTagger_aliases()

        if useGNN:
            ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                             outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
            ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

            if isinstance(identifierGNN, str):
                roe_path.add_module('MVAExpert',
                                    listNames='vpho:dummy',
                                    extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                                    identifier=identifierGNN)

                # Map the [0,1] output to qr in [-1,1].
                ma.variableToSignalSideExtraInfo('vpho:dummy', {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                                                 path=roe_path)

            elif isinstance(identifierGNN, list):
                # Remove duplicates while keeping the user's order (deterministic extraInfo names order).
                identifierGNN = list(dict.fromkeys(identifierGNN))

                extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
                roe_path.add_module('MVAMultipleExperts',
                                    listNames='vpho:dummy',
                                    extraInfoNames=extraInfoNames,
                                    identifiers=identifierGNN)

                extraInfoDict = {}
                for extraInfoName in extraInfoNames:
                    extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                    variables.variables.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

                ma.variableToSignalSideExtraInfo('vpho:dummy', extraInfoDict,
                                                 path=roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Teacher':
        if eventLevelTeacher(weightFiles, categories):
            combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1337
1338
if __name__ == '__main__':

    desc_list = []

    function = globals()["flavorTagger"]
    # inspect.formatargspec was removed in Python 3.11. str(inspect.signature(...)) yields the same
    # "(name=default, ...)" text.
    signature = str(inspect.signature(function))
    # __doc__ is None when Python runs with -OO.
    desc_list.append((function.__name__, signature + '\n' + (function.__doc__ or '')))

    from terminal_utils import Pager
    from basf2.utils import pretty_print_description_list
    with Pager('Flavor Tagger function accepts the following arguments:'):
        pretty_print_description_list(desc_list)
1351
def isB2BII()
Definition: b2bii.py:14