Belle II Software development
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11# ************* Flavor Tagging ************
12# * This script is needed to train *
13# * and to test the flavor tagger. *
14# ********************************************
15
16from basf2 import B2INFO, B2FATAL, B2WARNING
17import basf2
18import basf2_mva
19import modularAnalysis as ma
20import variables
21from variables import utils
22import os
23import glob
24import b2bii
25import collections
26
27
def getBelleOrBelle2():
    """
    Return the global mode code of the processed data.

    Returns:
        str: 'Belle' when b2bii.isB2BII() is true (Belle data converted with b2bii), otherwise 'Belle2'.
    """
    return 'Belle' if b2bii.isB2BII() else 'Belle2'
36
37
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Configure the interaction with the database.

    Sets the module globals ``downloadFlag`` (download trained weight files that are not found locally) and
    ``uploadFlag`` (upload the weight files after training).

    Parameters:
        downloadFromDatabaseIfNotFound (bool): value stored in ``downloadFlag``.
        uploadToDatabaseAfterTraining (bool): value stored in ``uploadFlag``.
    """
    global downloadFlag, uploadFlag
    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
48
49
# Default list of aliases that should be used to save the flavor tagging information using VariablesToNtuple.
# The combined outputs come first. Each category then contributes its qp, hasTrueTarget and isRightCategory
# aliases, in the category order of AvailableCategories.
flavor_tagging = ['FBDT_qrCombined', 'FANN_qrCombined', 'qrMC', 'mcFlavorOfOtherB', 'qrGNN'] + [
    prefix + category
    for category in ('Electron', 'IntermediateElectron', 'Muon', 'IntermediateMuon',
                     'KinLepton', 'IntermediateKinLepton', 'Kaon', 'SlowPion', 'FastHadron',
                     'Lambda', 'FSC', 'MaximumPstar', 'KaonPion')
    for prefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
65
66
def add_default_FlavorTagger_aliases():
    """
    Add the default aliases of the flavor tagging variables and define the 'flavor_tagging' collection.

    The combined outputs get fixed aliases. Every category of AvailableCategories gets three aliases:
    qp<Category>, hasTrueTarget<Category> and isRightCategory<Category>.
    """
    manager = variables.variables

    for alias, expression in (('FBDT_qrCombined', 'qrOutput(FBDT)'),
                              ('FANN_qrCombined', 'qrOutput(FANN)'),
                              ('qrMC', 'isRelatedRestOfEventB0Flavor'),
                              ('qrGNN', 'extraInfo(qrGNN)')):
        manager.addAlias(alias, expression)

    for category in AvailableCategories:
        manager.addAlias(f'qp{category}', f'qpCategory({category})')
        manager.addAlias(f'hasTrueTarget{category}', f'hasTrueTargets({category})')
        manager.addAlias(f'isRightCategory{category}', f'isTrueFTCategory({category})')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
88
89
def set_FlavorTagger_pid_aliases():
    """
    Add the PID aliases needed by the flavor tagger.

    The TOP, ARICH, ECL and KLM aliases are plain pidPairProbabilityExpert values. The dE/dx aliases use
    CDC and SVD wrapped in ifNANgiveX(..., 0.5) for Belle data, and the CDC alone without NaN replacement
    for Belle II.
    """
    manager = variables.variables

    # (alias, "hypothesis, test hypothesis, detector")
    for alias, arguments in (('eid_TOP', '11, 211, TOP'),
                             ('eid_ARICH', '11, 211, ARICH'),
                             ('eid_ECL', '11, 211, ECL'),
                             ('muid_TOP', '13, 211, TOP'),
                             ('muid_ARICH', '13, 211, ARICH'),
                             ('muid_KLM', '13, 211, KLM'),
                             ('piid_TOP', '211, 321, TOP'),
                             ('piid_ARICH', '211, 321, ARICH'),
                             ('Kid_TOP', '321, 211, TOP'),
                             ('Kid_ARICH', '321, 211, ARICH')):
        manager.addAlias(alias, f'pidPairProbabilityExpert({arguments})')

    dEdxHypotheses = (('eid_dEdx', '11, 211'),
                      ('muid_dEdx', '13, 211'),
                      ('piid_dEdx', '211, 321'),
                      ('pi_vs_edEdxid', '211, 11'),
                      ('Kid_dEdx', '321, 211'))

    if getBelleOrBelle2() == "Belle":
        for alias, hypotheses in dEdxHypotheses:
            manager.addAlias(alias, f'ifNANgiveX(pidPairProbabilityExpert({hypotheses}, CDC, SVD), 0.5)')
    else:
        for alias, hypotheses in dEdxHypotheses:
            manager.addAlias(alias, f'pidPairProbabilityExpert({hypotheses}, CDC)')
120
121
def set_FlavorTagger_pid_aliases_legacy():
    """
    Add the PID aliases needed by the flavor tagger trained for MC13.

    Every alias is a pidPairProbabilityExpert value wrapped in ifNANgiveX(..., 0.5). For Belle data the dE/dx
    aliases use CDC and SVD; for Belle II they use the CDC only.
    """
    def legacyPidExpression(hypotheses, detectors):
        # e.g. 'ifNANgiveX(pidPairProbabilityExpert(11, 211, TOP), 0.5)'
        return 'ifNANgiveX(pidPairProbabilityExpert(' + hypotheses + ', ' + detectors + '), 0.5)'

    for alias, hypotheses, detector in (('eid_TOP', '11, 211', 'TOP'),
                                        ('eid_ARICH', '11, 211', 'ARICH'),
                                        ('eid_ECL', '11, 211', 'ECL'),
                                        ('muid_TOP', '13, 211', 'TOP'),
                                        ('muid_ARICH', '13, 211', 'ARICH'),
                                        ('muid_KLM', '13, 211', 'KLM'),
                                        ('piid_TOP', '211, 321', 'TOP'),
                                        ('piid_ARICH', '211, 321', 'ARICH'),
                                        ('Kid_TOP', '321, 211', 'TOP'),
                                        ('Kid_ARICH', '321, 211', 'ARICH')):
        variables.variables.addAlias(alias, legacyPidExpression(hypotheses, detector))

    if getBelleOrBelle2() == "Belle":
        dEdxDetectors = 'CDC, SVD'
    else:
        # Removed SVD PID for Belle II MC and data as it is absent in release 4.
        dEdxDetectors = 'CDC'

    for alias, hypotheses in (('eid_dEdx', '11, 211'),
                              ('muid_dEdx', '13, 211'),
                              ('piid_dEdx', '211, 321'),
                              ('pi_vs_edEdxid', '211, 11'),
                              ('Kid_dEdx', '321, 211')):
        variables.variables.addAlias(alias, legacyPidExpression(hypotheses, dEdxDetectors))
153
154
def set_GNNFlavorTagger_aliases(categories, usePIDNN):
    """
    Add the aliases used as input and target of the GNN-based flavor tagger.

    Registers 'qrCombined_bit' ((qrCombined+1)/2: 0 for B0bar, 1 for B0). Then, for the ranks 1 to 16 of
    pi+:inRoe ordered by FT_p (getVariableByRank), it registers one alias per category (its QpTrack value)
    and one alias per per-track feature. Each is wrapped in ifNANgiveX(..., 0).

    Parameters:
        categories (list): category names, keys of AvailableCategories.
        usePIDNN (bool): use the PIDNN identification variables (e.g. 'electronIDNN') instead of the
            standard ones (e.g. 'electronID').

    Returns:
        list: names of all registered aliases, starting with 'qrCombined_bit'.
    """
    # will be used for target variable 0:B0bar, 1:B0
    variables.variables.addAlias('qrCombined_bit', '(qrCombined+1)/2')
    alias_list = ['qrCombined_bit']

    pidSuffix = 'IDNN' if usePIDNN else 'ID'
    perTrackVariables = {
        # position
        'dx': 'dx',
        'dy': 'dy',
        'dz': 'dz',
        # mask
        'E': 'E',
        # charge
        'charge': 'charge',
        # features
        'px_c': 'px*charge',
        'py_c': 'py*charge',
        'pz_c': 'pz*charge',
    }
    for species in ('electron', 'muon', 'pion', 'kaon', 'proton', 'deuteron'):
        perTrackVariables[species + 'ID_c'] = species + pidSuffix + '*charge'
    perTrackVariables['electronID_noSVD_noTOP_c'] = 'electronID_noSVD_noTOP*charge'

    def registerRankedAlias(name, expression, rank):
        # Register '<name>_rank<rank>' for the rank-th pi+:inRoe candidate and remember its name.
        aliasWithRank = f'{name}_rank{rank}'
        variables.variables.addAlias(
            aliasWithRank, f'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {expression}, {rank}), 0)')
        alias_list.append(aliasWithRank)

    # 16 charged particles are used at most
    maxRank = 16
    for rank in range(1, maxRank + 1):
        for cat in categories:
            listName = AvailableCategories[cat].particleList
            registerRankedAlias(cat, f'QpTrack({listName}, isRightCategory({cat}), isRightCategory({cat}))', rank)
        for name, expression in perTrackVariables.items():
            registerRankedAlias(name, expression, rank)

    return alias_list
207
208
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for the ROE input variables evaluated with a ROE mask.

    Registers pMissTag_withMask, cosTPTO_withMask, ptTracksRoe_withMask and pt2TracksRoe_withMask,
    each pointing to '<variable>(<maskName>)'.

    Parameters:
        maskName (str): name of the ROE mask passed to the variables.
    """
    # Each alias is registered exactly once. The previous version registered ptTracksRoe_withMask twice.
    for variableName in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(variableName + '_withMask', variableName + '(' + maskName + ')')
218
219
def getFastBDTCategories():
    '''
    Helper function for getting the FastBDT options used to train the categories.
    The options are created inside a function to avoid top-level ROOT imports.
    '''
    options = basf2_mva.FastBDTOptions()
    for attribute, value in (('m_nTrees', 500),
                             ('m_nCuts', 8),
                             ('m_nLevels', 3),
                             ('m_shrinkage', 0.10),
                             ('m_randRatio', 0.5)):
        setattr(options, attribute, value)
    return options
231
232
def getFastBDTCombiner():
    '''
    Helper function for getting the FastBDT options of the combiner.
    The options are created inside a function to avoid top-level ROOT imports.
    '''
    combinerSettings = {
        'm_nTrees': 500,
        'm_nCuts': 8,
        'm_nLevels': 3,
        'm_shrinkage': 0.10,
        'm_randRatio': 0.5,
    }
    options = basf2_mva.FastBDTOptions()
    for attribute, value in combinerSettings.items():
        setattr(options, attribute, value)
    return options
245
246
def getMlpFANNCombiner():
    '''
    Helper function for getting the MLP FANN options of the combiner.
    The options are created inside a function to avoid top-level ROOT imports.
    '''
    combinerSettings = (
        ('m_max_epochs', 10000),
        ('m_hidden_layers_architecture', "3*N"),
        # (sic) keep these exact attribute names
        ('m_hidden_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_output_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
        ('m_error_function', "FANN_ERRORFUNC_LINEAR"),
        ('m_training_method', "FANN_TRAIN_RPROP"),
        ('m_validation_fraction', 0.5),
        ('m_random_seeds', 10),
        ('m_test_rate', 500),
        ('m_number_of_threads', 8),
        ('m_scale_features', True),
        ('m_scale_target', False),
    )
    options = basf2_mva.FANNOptions()
    for attribute, value in combinerSettings:
        setattr(options, attribute, value)
    return options
267
268
# signalFraction: FBDT feature, passed as 'signalFraction' to the MVA experts.
#   -1 gives a smooth output but breaks the calibration.
#   -2 gives a correct calibration but leads to a peaky combiner output.
signalFraction = -2

# Maximal number of events used to train each method (m_max_events).
# 0 takes all the sampled events; the value used in the past was 500000.
maxEventsNumber = 0
276
277
FTCategoryParameters = collections.namedtuple('FTCategoryParameters',
                                              ['particleList', 'trackName', 'eventName', 'variableName', 'code'])

# Definition of all available categories, keyed by the standard category name:
#   particleList -- ROE particle list the category works on
#   trackName    -- trackLevel category name
#   eventName    -- eventLevel category name (equal to the standard category name for every category)
#   variableName -- combinerLevel variable name:
#                   <QpOf|weightedQpOf>(particleList, isRightCategory(eventName), isRightCategory(trackName))
#   code         -- category code, i.e. the position of the row in the table below
AvailableCategories = {
    eventName: FTCategoryParameters(
        particleList, trackName, eventName,
        f'{qpFunction}({particleList}, isRightCategory({eventName}), isRightCategory({trackName}))',
        code)
    for code, (eventName, particleList, trackName, qpFunction) in enumerate((
        ('Electron', 'e+:inRoe', 'Electron', 'QpOf'),
        ('IntermediateElectron', 'e+:inRoe', 'IntermediateElectron', 'QpOf'),
        ('Muon', 'mu+:inRoe', 'Muon', 'QpOf'),
        ('IntermediateMuon', 'mu+:inRoe', 'IntermediateMuon', 'QpOf'),
        ('KinLepton', 'mu+:inRoe', 'KinLepton', 'QpOf'),
        ('IntermediateKinLepton', 'mu+:inRoe', 'IntermediateKinLepton', 'QpOf'),
        ('Kaon', 'K+:inRoe', 'Kaon', 'weightedQpOf'),
        ('SlowPion', 'pi+:inRoe', 'SlowPion', 'QpOf'),
        ('FastHadron', 'pi+:inRoe', 'FastHadron', 'QpOf'),
        ('Lambda', 'Lambda0:inRoe', 'Lambda', 'weightedQpOf'),
        ('FSC', 'pi+:inRoe', 'SlowPion', 'QpOf'),
        ('MaximumPstar', 'pi+:inRoe', 'MaximumPstar', 'QpOf'),
        ('KaonPion', 'K+:inRoe', 'Kaon', 'QpOf'),
    ))
}
337
338
def getTrainingVariables(category=None, usePIDNN=False):
    """
    Helper function to get the training variables of a category.

    NOTE: This function is not called in the Expert mode. It does not need to be consistent with the
    variable lists stored in the weight files.

    Parameters:
        category (str): category name, e.g. 'Electron', 'Kaon' or 'KaonPion'.
        usePIDNN (bool): for Belle II, use the PIDNN identification variables (e.g. 'kaonIDNN')
            instead of the standard ones (e.g. 'kaonID').

    Returns:
        list: names of the training variables. The list is empty for None or for a category not handled here.
    """
    # Evaluated once instead of once per lookup.
    belleOrBelle2 = getBelleOrBelle2()
    isBelle = belleOrBelle2 == "Belle"

    KId = {'Belle': 'ifNANgiveX(atcPIDBelle(3,2), 0.5)', 'Belle2': 'kaonIDNN' if usePIDNN else 'kaonID'}
    muId = {'Belle': 'muIDBelle', 'Belle2': 'muonIDNN' if usePIDNN else 'muonID'}
    eId = {'Belle': 'eIDBelle', 'Belle2': 'electronIDNN' if usePIDNN else 'electronID'}
    piId = 'pionIDNN' if usePIDNN else 'pionID'

    # Named trainingVariables so that the imported 'variables' module is not shadowed.
    trainingVariables = []
    if category in ('Electron', 'IntermediateElectron'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             eId[belleOrBelle2],
                             'eid_TOP',
                             'eid_ARICH',
                             'eid_ECL',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             'chiProb',
                             ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'ImpactXY', 'distance']

    elif category in ('Muon', 'IntermediateMuon'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             muId[belleOrBelle2],
                             'muid_TOP',
                             'muid_ARICH',
                             'muid_KLM',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category in ('KinLepton', 'IntermediateKinLepton'):
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             muId[belleOrBelle2],
                             'muid_TOP',
                             'muid_ARICH',
                             'muid_KLM',
                             eId[belleOrBelle2],
                             'eid_TOP',
                             'eid_ARICH',
                             'eid_ECL',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Kaon':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTheta',
                             KId[belleOrBelle2],
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(EW90)',
                             'cosTPTO',
                             'chiProb',
                             ]
        if isBelle:
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'SlowPion':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             piId,
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             KId[belleOrBelle2],
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             eId[belleOrBelle2],
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'BtagToWBosonVariables(pMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['piid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'FastHadron':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'cosTheta',
                             'p',
                             'pt',
                             piId,
                             'piid_dEdx',
                             'piid_TOP',
                             'piid_ARICH',
                             'pi_vs_edEdxid',
                             KId[belleOrBelle2],
                             'Kid_dEdx',
                             'Kid_TOP',
                             'Kid_ARICH',
                             'NumberOfKShortsInRoe',
                             'ptTracksRoe',
                             eId[belleOrBelle2],
                             'BtagToWBosonVariables(recoilMassSqrd)',
                             'BtagToWBosonVariables(EW90)',
                             'BtagToWBosonVariables(cosThetaMissCMS)',
                             'cosTPTO',
                             ]
        if isBelle:
            trainingVariables += ['BtagToWBosonVariables(pMissCMS)', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Lambda':
        trainingVariables = ['lambdaFlavor',
                             'NumberOfKShortsInRoe',
                             'M',
                             'cosAngleBetweenMomentumAndVertexVector',
                             'lambdaZError',
                             'daughter(0,p)',
                             'daughter(0,useCMSFrame(p))',
                             'daughter(1,p)',
                             'daughter(1,useCMSFrame(p))',
                             'useCMSFrame(p)',
                             'p',
                             'chiProb',
                             ]
        if belleOrBelle2 == "Belle2":
            # protonID always 0 in B2BII check in future
            trainingVariables.append('daughter(1,protonIDNN)' if usePIDNN else 'daughter(1,protonID)')
            # not very powerful in B2BII
            trainingVariables.append('daughter(0,pionIDNN)' if usePIDNN else 'daughter(0,pionID)')
        else:
            trainingVariables.append('distance')

    elif category == 'MaximumPstar':
        trainingVariables = ['useCMSFrame(p)',
                             'useCMSFrame(pt)',
                             'p',
                             'pt',
                             'cosTPTO',
                             ]
        # NOTE(review): unlike the other track categories, ImpactXY/distance are added for Belle II here,
        # not for Belle — kept as is, confirm this is intended.
        if belleOrBelle2 == "Belle2":
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'FSC':
        trainingVariables = ['useCMSFrame(p)',
                             'cosTPTO',
                             KId[belleOrBelle2],
                             'FSCVariables(pFastCMS)',
                             'FSCVariables(cosSlowFast)',
                             'FSCVariables(cosTPTOFast)',
                             'FSCVariables(SlowFastHaveOpositeCharges)',
                             ]

    elif category == 'KaonPion':
        trainingVariables = ['extraInfo(isRightCategory(Kaon))',
                             'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
                             'KaonPionVariables(cosKaonPion)',
                             'KaonPionVariables(HaveOpositeCharges)',
                             KId[belleOrBelle2]
                             ]

    return trainingVariables
554
555
def FillParticleLists(maskName='all', categories=None, path=None):
    """
    Fill the particle lists needed by the given categories.

    Every list is filled only once, from tracks that are in the ROE and pass the ROE mask ``maskName``.
    Lambda0:inRoe is reconstructed from pi-:inRoe and p+:inRoe (1.00 <= M <= 1.23), vertex-fitted with kFit
    and MC matched. K_S0:inRoe is copied from K_S0:mdst for Belle, and reconstructed from pi+:inRoe
    (0.40 <= M <= 0.60) and fitted for Belle II. For Belle II the BDT-based lepton ID is applied to
    e+:inRoe / mu+:inRoe if they were filled.

    Parameters:
        maskName (str): name of the ROE mask used in the track selection.
        categories (list): category names, keys of AvailableCategories.
        path (basf2.Path): path to which the modules are added.
    """
    from vertex import kFit

    filledLists = []

    if categories is None:
        categories = []

    trackCut = 'isInRestOfEvent > 0.5 and passesROEMask(' + maskName + ') > 0.5 and p >= 0'

    for category in categories:
        listName = AvailableCategories[category].particleList
        if listName in filledLists:
            continue

        if listName != 'Lambda0:inRoe':
            # Filling particle list for actual category
            ma.fillParticleList(listName, trackCut, path=path)
            filledLists.append(listName)
            continue

        # Lambda0 candidates are built from ROE pions and protons.
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', trackCut, path=path)
            filledLists.append('pi+:inRoe')
        ma.fillParticleList('p+:inRoe', trackCut, path=path)
        ma.reconstructDecay(listName + ' -> pi-:inRoe p+:inRoe', '1.00<=M<=1.23', False, path=path)
        kFit(listName, 0.01, path=path)
        ma.matchMCTruth(listName, path=path)
        filledLists.append(listName)

    experiment = getBelleOrBelle2()

    # Additional particle list for K_S0
    if experiment == 'Belle':
        ma.cutAndCopyList('K_S0:inRoe', 'K_S0:mdst', 'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
    else:
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', trackCut, path=path)
        ma.reconstructDecay('K_S0:inRoe -> pi+:inRoe pi-:inRoe', '0.40<=M<=0.60', False, path=path)
        kFit('K_S0:inRoe', 0.01, path=path)

    # Apply BDT-based LID to the lepton lists that were filled above.
    if experiment == 'Belle2':
        lidLists = [listName for listName in ('e+:inRoe', 'mu+:inRoe') if listName in filledLists]
        if lidLists:
            # binary trainings: e vs pi, then mu vs pi
            for hypotheses in ((11, 211), (13, 211)):
                ma.applyChargedPidMVA(particleLists=lidLists, path=path,
                                      trainingMode=0,
                                      binaryHypoPDGCodes=hypotheses)
            # multiclass training
            ma.applyChargedPidMVA(particleLists=lidLists, path=path,
                                  trainingMode=1)
619
620
def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, usePIDNN=False, path=None):
    """
    Sample the training data or apply the experts of all categories at event level.

    Parameters:
        mode (str): 'Expert' applies the event level weight file of every category. 'Sampler' applies the
            weight files that already exist in ``filesDirectory`` and writes training ntuples for the
            categories whose weight file is missing.
        weightFiles (str): name used to build the weight file identifiers.
        categories (list): category names, keys of AvailableCategories.
        usePIDNN (bool): passed to getTrainingVariables to choose the variables saved when sampling.
        path (basf2.Path): path to which the modules are added.

    Returns:
        bool: True if an event level expert was set up for every requested category, False otherwise.

    Uses the module globals downloadFlag, useOnlyLocalFlag, filesDirectory and fileId.
    """

    from basf2 import create_path
    from basf2 import register_module

    B2INFO('EVENT LEVEL')

    ReadyMethods = 0

    # (extraInfo name, weight file identifier) pairs grouped by particle list. KaonPion is kept separate.
    identifiersExtraInfosDict = dict()
    identifiersExtraInfosKaonPion = []

    if categories is None:
        categories = []

    for category in categories:
        particleList = AvailableCategories[category].particleList

        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        identifierEventLevel = methodPrefixEventLevel
        targetVariable = 'isRightCategory(' + category + ')'
        extraInfoName = targetVariable

        if mode == 'Expert':

            # Unless downloading or using local files only, the database identifier itself is used.
            if downloadFlag or useOnlyLocalFlag:
                identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'

            if downloadFlag:
                if not os.path.isfile(identifierEventLevel):
                    basf2_mva.download(methodPrefixEventLevel, identifierEventLevel)
                    if not os.path.isfile(identifierEventLevel):
                        B2FATAL('Flavor Tagger: Weight file ' + identifierEventLevel +
                                ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag:
                if not os.path.isfile(identifierEventLevel):
                    B2FATAL('Flavor Tagger: ' + particleList + ' Eventlevel was not trained. Weight file ' +
                            identifierEventLevel + ' was not found. Stopped')

            B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

        elif mode == 'Sampler':

            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            if os.path.isfile(identifierEventLevel):
                B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

            if 'KaonPion' in categories:
                methodPrefixEventLevelKaonPion = "FlavorTagger_" + getBelleOrBelle2() + \
                    "_" + weightFiles + 'EventLevelKaonPionFBDT'
                identifierEventLevelKaonPion = filesDirectory + '/' + methodPrefixEventLevelKaonPion + '_1.root'
                if not os.path.isfile(identifierEventLevelKaonPion):
                    # Slow Pion and Kaon categories are used if Kaon-Pion is lacking for
                    # sampling. The others are not needed and skipped
                    if category != "SlowPion" and category != "Kaon":
                        continue

        if mode == 'Expert' or (mode == 'Sampler' and os.path.isfile(identifierEventLevel)):

            B2INFO('flavorTagger: Applying MVAExpert ' + methodPrefixEventLevel + '.')

            if category == 'KaonPion':
                identifiersExtraInfosKaonPion.append((extraInfoName, identifierEventLevel))
            elif particleList not in identifiersExtraInfosDict:
                identifiersExtraInfosDict[particleList] = [(extraInfoName, identifierEventLevel)]
            else:
                identifiersExtraInfosDict[particleList].append((extraInfoName, identifierEventLevel))

            ReadyMethods += 1

    # Each particle list has its own Path in order to be skipped if the corresponding particle list is empty
    for particleList in identifiersExtraInfosDict:
        eventLevelPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_EventLevel_' + particleList)
        SkipEmptyParticleList.param('particleLists', particleList)
        SkipEmptyParticleList.if_true(eventLevelPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaMultipleExperts = register_module('MVAMultipleExperts')
        mvaMultipleExperts.set_name('MVAMultipleExperts_EventLevel_' + particleList)
        mvaMultipleExperts.param('listNames', [particleList])
        mvaMultipleExperts.param('extraInfoNames', [row[0] for row in identifiersExtraInfosDict[particleList]])
        mvaMultipleExperts.param('signalFraction', signalFraction)
        mvaMultipleExperts.param('identifiers', [row[1] for row in identifiersExtraInfosDict[particleList]])
        eventLevelPath.add_module(mvaMultipleExperts)

    if 'KaonPion' in categories and len(identifiersExtraInfosKaonPion) != 0:
        eventLevelKaonPionPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_' + 'K+:inRoe')
        SkipEmptyParticleList.param('particleLists', 'K+:inRoe')
        SkipEmptyParticleList.if_true(eventLevelKaonPionPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaExpertKaonPion = register_module("MVAExpert")
        mvaExpertKaonPion.set_name('MVAExpert_KaonPion_' + 'K+:inRoe')
        mvaExpertKaonPion.param('listNames', ['K+:inRoe'])
        mvaExpertKaonPion.param('extraInfoName', identifiersExtraInfosKaonPion[0][0])
        mvaExpertKaonPion.param('signalFraction', signalFraction)
        mvaExpertKaonPion.param('identifier', identifiersExtraInfosKaonPion[0][1])

        eventLevelKaonPionPath.add_module(mvaExpertKaonPion)

    if mode == 'Sampler':

        for category in categories:
            particleList = AvailableCategories[category].particleList

            methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            targetVariable = 'isRightCategory(' + category + ')'

            # Only categories without a trained weight file are sampled.
            if not os.path.isfile(identifierEventLevel):

                if category == 'KaonPion':
                    methodPrefixEventLevelSlowPion = "FlavorTagger_" + getBelleOrBelle2() + \
                        "_" + weightFiles + 'EventLevelSlowPionFBDT'
                    identifierEventLevelSlowPion = filesDirectory + '/' + methodPrefixEventLevelSlowPion + '_1.root'
                    if not os.path.isfile(identifierEventLevelSlowPion):
                        # The sentences are separated by spaces; they used to be joined without any.
                        B2INFO("Flavor Tagger: event level weight file for the Slow Pion category is absent. " +
                               "It is required to sample the training information for the KaonPion category. " +
                               "An additional sampling step will be needed after the following training step.")
                        continue

                B2INFO('flavorTagger: file ' + filesDirectory + '/' +
                       methodPrefixEventLevel + "sampled" + fileId + '.root will be saved.')

                ma.applyCuts(particleList, 'isRightCategory(mcAssociated) > 0', path)
                eventLevelpath = create_path()
                SkipEmptyParticleList = register_module("SkimFilter")
                SkipEmptyParticleList.set_name('SkimFilter_EventLevel' + category)
                SkipEmptyParticleList.param('particleLists', particleList)
                SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
                path.add_module(SkipEmptyParticleList)

                ntuple = register_module('VariablesToNtuple')
                ntuple.param('fileName', filesDirectory + '/' + methodPrefixEventLevel + "sampled" + fileId + ".root")
                ntuple.param('treeName', methodPrefixEventLevel + "_tree")
                variablesToBeSaved = getTrainingVariables(category, usePIDNN) + [
                    targetVariable,
                    'ancestorHasWhichFlavor',
                    'isSignal',
                    'mcPDG',
                    'mcErrors',
                    'genMotherPDG',
                    'nMCMatches',
                    'B0mcErrors'
                ]
                if category != 'KaonPion' and category != 'FSC':
                    variablesToBeSaved = variablesToBeSaved + \
                        ['extraInfo(isRightTrack(' + category + '))',
                         'hasHighestProbInCat(' + particleList + ', isRightTrack(' + category + '))']
                ntuple.param('variables', variablesToBeSaved)
                ntuple.param('particleList', particleList)
                eventLevelpath.add_module(ntuple)

    return ReadyMethods == len(categories)
787
788
def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None, usePIDNN=False):
    """
    Train the event level FastBDT of every category that has no weight file yet.

    For each category whose weight file '<filesDirectory>/<prefix>_1.root' is missing, the sampled ntuples
    '<prefix>sampled*.root' are used for training, if any exist. The weight file is uploaded when
    ``uploadFlag`` is set.

    Parameters:
        weightFiles (str): name used to build the weight file identifiers.
        categories (list): category names, keys of AvailableCategories.
        usePIDNN (bool): passed to getTrainingVariables.

    Returns:
        bool: True only if the weight files of all categories already existed before this call. A category
        trained during this call is not counted as ready.
    """

    B2INFO('EVENT LEVEL TEACHER')

    ReadyMethods = 0

    if categories is None:
        categories = []

    for category in categories:
        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        targetVariable = 'isRightCategory(' + category + ')'
        weightFile = filesDirectory + '/' + methodPrefixEventLevel + "_1.root"

        if os.path.isfile(weightFile):
            ReadyMethods += 1
            continue

        sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixEventLevel + 'sampled*.root')
        if not sampledFilesList:
            # Report the pattern that was actually searched for.
            B2INFO('flavorTagger: eventLevelTeacher did not find any ' + methodPrefixEventLevel +
                   'sampled*.root' + ' file. Please run the flavorTagger in "Sampler" mode afterwards.')
            continue

        B2INFO('flavorTagger: MVA Teacher training ' + methodPrefixEventLevel + '.')
        trainingOptionsEventLevel = basf2_mva.GeneralOptions()
        trainingOptionsEventLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        trainingOptionsEventLevel.m_treename = methodPrefixEventLevel + "_tree"
        trainingOptionsEventLevel.m_identifier = weightFile
        trainingOptionsEventLevel.m_variables = basf2_mva.vector(*getTrainingVariables(category, usePIDNN))
        trainingOptionsEventLevel.m_target_variable = targetVariable
        trainingOptionsEventLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsEventLevel, getFastBDTCategories())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefixEventLevel)

    return ReadyMethods == len(categories)
834
835
def _getCombinerIdentifier(methodPrefixCombinerLevel, method, methodLabel):
    """
    Resolves the identifier of one combiner-level MVA weight file and makes sure it is available.

    @param methodPrefixCombinerLevel  Common prefix of the combiner weight files.
    @param method                     Method suffix used in the file name, ``'FBDT'`` or ``'FANN'``.
    @param methodLabel                Method name shown in the "not trained" error message.
    @return The database name ``<prefix><method>``, or the local path
            ``<filesDirectory>/<prefix><method>_1.root`` when downloading or local-only mode is enabled.
    """
    databaseName = methodPrefixCombinerLevel + method
    identifier = databaseName
    if downloadFlag or useOnlyLocalFlag:
        identifier = filesDirectory + '/' + databaseName + '_1.root'

    if downloadFlag and not os.path.isfile(identifier):
        basf2_mva.download(databaseName, identifier)
        if not os.path.isfile(identifier):
            B2FATAL('Flavor Tagger: Weight file ' + identifier +
                    ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

    if useOnlyLocalFlag and not os.path.isfile(identifier):
        B2FATAL('flavorTagger: Combinerlevel ' + methodLabel + ' was not trained with this combination of categories.' +
                ' Weight file ' + identifier + ' not found. Stopped')

    B2INFO('flavorTagger: Ready to be used with weightFile ' + databaseName + '_1.root')
    return identifier


def combinerLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None,
                  variablesCombinerLevel=None, categoriesCombinationCode=None, path=None):
    """
    Samples the input data or tests the combiner according to the selected categories.

    @param mode                       ``Sampler``: writes the combiner-level training ntuple;
                                      ``Expert``: applies the trained combiner MVA(s).
    @param weightFiles                Name (including prefix) of the weight files.
    @param categories                 Names of the selected categories (only logged here).
    @param variablesCombinerLevel     Input variables of the combiner, saved in ``Sampler`` mode.
    @param categoriesCombinationCode  Code of the category combination; part of the weight file names.
    @param path                       Path to which the modules are added.
    """

    B2INFO('COMBINER LEVEL')

    if categories is None:
        categories = []
    if variablesCombinerLevel is None:
        variablesCombinerLevel = []
    if categoriesCombinationCode is None:
        B2FATAL('flavorTagger: combinerLevel requires a categoriesCombinationCode.')

    B2INFO("Flavor Tagger: Required Combiner for Categories:")
    for category in categories:
        B2INFO(category)

    B2INFO("Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    if mode == 'Sampler':

        if os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root') or \
                os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'):
            B2FATAL('flavorTagger: File ' + methodPrefixCombinerLevel + 'FBDT' + "_1.root" + ' or ' + methodPrefixCombinerLevel +
                    'FANN' + '_1.root found. Please run the "Expert" mode or delete the file if a new sampling is desired.')

        sampledFile = filesDirectory + '/' + methodPrefixCombinerLevel + "sampled" + fileId + ".root"
        B2INFO('flavorTagger: Sampling Data on Combiner Level. File ' + sampledFile + ' will be saved')

        ntuple = basf2.register_module('VariablesToNtuple')
        ntuple.param('fileName', sampledFile)
        # One tree (named after the FBDT method) is used for training both combiner methods.
        ntuple.param('treeName', methodPrefixCombinerLevel + 'FBDT' + "_tree")
        ntuple.param('variables', variablesCombinerLevel + ['qrCombined'])
        ntuple.param('particleList', "")
        path.add_module(ntuple)

    if mode == 'Expert':

        # Check if weight files are ready (downloading them if requested).
        if TMVAfbdt:
            identifierFBDT = _getCombinerIdentifier(methodPrefixCombinerLevel, 'FBDT', 'FastBDT')

        if FANNmlp:
            identifierFANN = _getCombinerIdentifier(methodPrefixCombinerLevel, 'FANN', 'FANNMLP')

        # At this stage, all necessary weight files should be ready.
        # Call MVAExpert or MVAMultipleExperts module.
        if TMVAfbdt and not FANNmlp:
            B2INFO('flavorTagger: Apply FBDTMethod ' + methodPrefixCombinerLevel + 'FBDT')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FBDT', signalFraction=signalFraction,
                            identifier=identifierFBDT)

        if FANNmlp and not TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod on combiner level')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FANN', signalFraction=signalFraction,
                            identifier=identifierFANN)

        if FANNmlp and TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
            mvaMultipleExperts = basf2.register_module('MVAMultipleExperts')
            mvaMultipleExperts.set_name('MVAMultipleExperts_Combiners')
            mvaMultipleExperts.param('listNames', [])
            mvaMultipleExperts.param('extraInfoNames', ['qrCombined' + 'FBDT', 'qrCombined' + 'FANN'])
            mvaMultipleExperts.param('signalFraction', signalFraction)
            mvaMultipleExperts.param('identifiers', [identifierFBDT, identifierFANN])
            path.add_module(mvaMultipleExperts)
936
937
def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    Every requested combiner method (global ``TMVAfbdt`` / ``FANNmlp``) whose weight file
    ``<filesDirectory>/<prefix><FBDT|FANN>_1.root`` is missing is trained on the combiner-level
    ``<prefix>sampled*.root`` files and, if uploading was enabled, uploaded to the database.
    Calls B2FATAL if no sampled files exist or if all requested weight files already exist.

    @param weightFiles                Name (including prefix) of the weight files.
    @param variablesCombinerLevel     Input variables of the combiner.
    @param categoriesCombinationCode  Code of the category combination; part of the weight file names.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []
    if categoriesCombinationCode is None:
        B2FATAL('flavorTagger: combinerLevelTeacher requires a categoriesCombinationCode.')

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if len(sampledFilesList) == 0:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')

    # (file suffix, description for the log, getter of the method-specific MVA options)
    requestedMethods = []
    if TMVAfbdt:
        requestedMethods.append(('FBDT', 'a FastBDT', getFastBDTCombiner))
    if FANNmlp:
        requestedMethods.append(('FANN', 'a FANN MLP', getMlpFANNCombiner))

    # Check which weight files exist BEFORE training anything, so that a file written by this
    # call is not mistaken for one that had already been trained.
    alreadyTrained = {
        method: os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + method + '_1.root')
        for method, _, _ in requestedMethods
    }

    if requestedMethods and all(alreadyTrained.values()):
        B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight file(s) ' +
                ' and '.join(methodPrefixCombinerLevel + method + '_1.root' for method in alreadyTrained) +
                ' have been found. Please use the "Expert" mode')

    for method, description, getMethodOptions in requestedMethods:
        weightFile = filesDirectory + '/' + methodPrefixCombinerLevel + method + '_1.root'

        if alreadyTrained[method]:
            B2INFO('flavorTagger: Combinerlevel ' + method + ' was already trained with this combination of categories. '
                   'Weight file ' + methodPrefixCombinerLevel + method + '_1.root has been found.')
            continue

        B2INFO('flavorTagger: MVA Teacher training ' + description + ' on Combiner Level')

        trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
        trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        # The sampler writes a single tree named after the FBDT method; both combiners train on it.
        trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
        trainingOptionsCombinerLevel.m_identifier = weightFile
        trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
        trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
        trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsCombinerLevel, getMethodOptions())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefixCombinerLevel + method)
1016
1017
def getEventLevelParticleLists(categories=None):
    """
    Collects the event-level configuration of the given categories.

    @param categories  Names of categories (keys of ``AvailableCategories``); None means none.
    @return List of ``(particleList, eventName, variableName)`` tuples in category order.
            B2FATAL is called when two categories yield the same tuple.
    """
    collected = []
    for categoryName in categories or []:
        entry = AvailableCategories[categoryName]
        candidate = (entry.particleList, entry.eventName, entry.variableName)
        if candidate in collected:
            B2FATAL('Flavor Tagger: ' + categoryName + ' has been already given')
        else:
            collected.append(candidate)
    return collected
1035
1036
def flavorTagger(
    particleLists=None,
    mode='Expert',
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    uploadToDatabaseAfterTraining=False,
    samplerFileId='',
    prefix='MC16rd_light-2501-betelgeuse',
    useGNN=True,
    identifierGNN='GFlaT_MC16rd_light-2501-betelgeuse_tensorflow',
    usePIDNN=False,
    path=None,
):
    """
    Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
    The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
    Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
    This module can be used to sample the training information, to train and/or to test the flavorTagger.

    @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
                                             A single particle list name given as a string is accepted as well.
    @param mode                              The available modes are
                                             ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
                                             Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you save
                                             the variables for training. In the ``Teacher`` mode the FlavorTagger is
                                             trained, for this step you do not reconstruct any particle or do any analysis,
                                             you just run the flavorTagger alone.
    @param weightFiles                       Weight files name. Default=
                                             ``B2nunubarBGx1`` (official weight files). If the user wants to train the
                                             FlavorTagger themselves, the weightfiles name should correspond to the
                                             analyzed CP channel in order to avoid confusions. The default name
                                             ``B2nunubarBGx1`` corresponds to
                                             :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`
                                             and ``B2JpsiKs_muBGx1`` to
                                             :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                                             BGx1 stands for events simulated with background.
    @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
    @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
                                             ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
    @param categories                        Categories used for flavor tagging. By default all are used.
                                             At least two are required; duplicated names are removed
                                             (keeping the first occurrence).
    @param maskName                          Gets ROE particles from a specified ROE mask.
                                             ``FTDefaultMask`` (default): tentative mask definition that will be created
                                             automatically. The definition is as follows:

                                             - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                                             - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
                                               [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
                                               [clusterReg==3 and E>0.06]] \
                                               (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                                             ``all``: all ROE particles are used.
                                             Or one can give any mask name defined before calling this function.
    @param saveCategoriesInfo                Sets to save information of individual categories.
    @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
    @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
                                             the conditions database if not available in workingDirectory.
    @param uploadToDatabaseAfterTraining     [Expert] For librarians only: uploads weight files to localdb after training.
    @param samplerFileId                     Identifier to parallelize
                                             sampling. Only used in ``Sampler`` mode. If you are training by yourself and
                                             want to parallelize the sampling, you can run several sampling scripts in
                                             parallel. By changing this parameter you will not overwrite an older sample.
    @param prefix                            Prefix of weight files.
                                             ``MC16rd_light-2501-betelgeuse`` (default): Weight files trained for MC16rd samples.
                                             ``MC15ri_light-2207-bengal_0``: Weight files trained for MC15ri samples.
                                             ``''``: Weight files trained for MC13 samples.
    @param useGNN                            Use GNN-based Flavor Tagger in addition with FastBDT-based one.
                                             Please specify the weight file with the option ``identifierGNN``.
                                             [Expert] In the sampler mode,
                                             training files for GNN-based Flavor Tagger are produced.
    @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
                                             [Expert] Multiple identifiers can be given with list(str);
                                             duplicates are removed (keeping the first occurrence).
    @param usePIDNN                          If True, PID probabilities calculated from PID neural network are used
                                             (default is False). Prefix and identifierGNN must be set accordingly.
    @param path                              Modules are added to this path

    """

    if not isinstance(particleLists, list):
        particleLists = [particleLists]  # in case user inputs a particle list as string

    if mode not in ('Sampler', 'Teacher', 'Expert'):
        B2FATAL('flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')

    if len(categories) != len(set(categories)):
        # Counter keeps first-occurrence order, so the reported duplicates are deterministic.
        dup = [cat for cat, count in collections.Counter(categories).items() if count > 1]
        B2WARNING('Flavor Tagger: There are duplicate elements in the given categories list. '
                  'The following duplicate elements are removed; ' + ', '.join(dup))
        # dict.fromkeys keeps the user's ordering; a set would use a hash order that can vary between runs.
        categories = list(dict.fromkeys(categories))

    # The list of valid names is part of the same B2FATAL message: a second B2FATAL call
    # would never be reached if the first one aborts.
    possibleCategories = ('"Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
                          '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                          '"Lambda", "FSC", "MaximumPstar" or "KaonPion"')

    if len(categories) < 2:
        B2FATAL('Flavor Tagger: Invalid amount of categories. At least two are needed. '
                'Possible categories are ' + possibleCategories)

    for category in categories:
        if category not in AvailableCategories:
            B2FATAL('Flavor Tagger: ' + category + ' is not a valid category name given. '
                    'Available categories are ' + possibleCategories)

    if mode == 'Expert' and useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # NOTE(review): presumably reports an error when workingDirectory cannot be found
    # (the existence checks below pass silent=True instead) — confirm.
    basf2.find_file(workingDirectory)

    # Directory where the weights of the trained Methods are saved
    global filesDirectory
    filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    # Sampled ntuples and downloaded weight files are written into filesDirectory: create it if needed.
    if mode == 'Sampler' or (mode == 'Expert' and downloadFromDatabaseIfNotFound):
        if not basf2.find_file(workingDirectory + '/FlavorTagging', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging')
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        elif not basf2.find_file(workingDirectory + '/FlavorTagging/TrainedMethods', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')

    if len(combinerMethods) < 1 or len(combinerMethods) > 2:
        B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global FANNmlp
    global TMVAfbdt

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global fileId
    fileId = samplerFileId

    global useOnlyLocalFlag
    useOnlyLocalFlag = useOnlyLocalWeightFiles

    B2INFO('*** FLAVOR TAGGING ***')
    B2INFO(' ')
    B2INFO('    Working directory is: ' + filesDirectory)
    B2INFO(' ')

    setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)

    # An empty prefix selects the MC13 weight files, which go with the legacy PID aliases.
    if prefix == '':
        set_FlavorTagger_pid_aliases_legacy()
    else:
        set_FlavorTagger_pid_aliases()

    alias_list_for_GNN = []
    if useGNN:
        alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories, usePIDNN)

    setInputVariablesWithMask()
    if prefix != '':
        weightFiles = prefix + '_' + weightFiles

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    for category in categories:
        ftCategory = AvailableCategories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')

    # Sorting makes the combination code independent of the order in which categories are given.
    categoriesCombinationCode = 'CatCode' + ''.join(f'{int(code):02}' for code in sorted(categoriesCombination))

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or '
            '[E>0.06 and clusterReg==3]]')
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    if mode == 'Sampler':
        # Events containing ROE without B-Meson (but not empty) are discarded for training
        ma.signalSideParticleListsFilter(
            particleLists,
            'nROE_Charged(' + maskName + ', 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            deadEndPath)

        FillParticleLists(maskName, categories, roe_path)

        if useGNN:
            # The event level is run in 'Expert' mode here (not 'Sampler') before the GNN training ntuple is written.
            if eventLevel('Expert', weightFiles, categories, usePIDNN, roe_path):

                ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                                 outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
                ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)
                ma.variablesToNtuple('vpho:dummy',
                                     alias_list_for_GNN,
                                     treename='tree',
                                     filename=f'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
                                     signalSideParticleList=particleLists[0],
                                     path=roe_path)

        else:
            if eventLevel(mode, weightFiles, categories, usePIDNN, roe_path):
                combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # If trigger returns 1 jump into empty path skipping further modules in roe_path
        # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
        ma.signalSideParticleListsFilter(particleLists, 'nROE_Charged(' + maskName + ', 0) > 0', roe_path, deadEndPath)

        # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        FillParticleLists(maskName, categories, roe_path)

        if eventLevel(mode, weightFiles, categories, usePIDNN, roe_path):
            combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

            flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
            flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
            flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
            flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
            flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
            flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('targetProb', False)
            flavorTaggerInfoFiller.param('trackPointers', False)
            roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
            add_default_FlavorTagger_aliases()

        if useGNN:
            ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                             outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
            ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

            if isinstance(identifierGNN, str):
                roe_path.add_module('MVAExpert',
                                    listNames='vpho:dummy',
                                    extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                                    identifier=identifierGNN)

                # Map the [0, 1] output onto the [-1, 1] range of qr.
                ma.variableToSignalSideExtraInfo('vpho:dummy', {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                                                 path=roe_path)

            elif isinstance(identifierGNN, list):
                # Remove duplicates while keeping the user's (deterministic) ordering.
                identifierGNN = list(dict.fromkeys(identifierGNN))

                extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
                roe_path.add_module('MVAMultipleExperts',
                                    listNames='vpho:dummy',
                                    extraInfoNames=extraInfoNames,
                                    identifiers=identifierGNN)

                extraInfoDict = {}
                for extraInfoName in extraInfoNames:
                    extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                    variables.variables.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

                ma.variableToSignalSideExtraInfo('vpho:dummy', extraInfoDict,
                                                 path=roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Teacher':
        if eventLevelTeacher(weightFiles, categories, usePIDNN):
            combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1350
1351
if __name__ == '__main__':
    # When executed directly, hand this module and the name of its main function
    # to basf2's pretty printer.
    import basf2.utils as _basf2_utils
    _basf2_utils.pretty_print_module(__name__, "flavorTagger")
1355
def isB2BII()
Definition: b2bii.py:14