Belle II Software prerelease-10-00-00a
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11# ************* Flavor Tagging ************
12# * This script is needed to train *
13# * and to test the flavor tagger. *
14# ********************************************
15
16from basf2 import B2INFO, B2FATAL, B2WARNING
17import basf2
18import basf2_mva
19import modularAnalysis as ma
20import variables
21from variables import utils
22import os
23import glob
24import b2bii
25import collections
26
27
def getBelleOrBelle2():
    """
    Gets the global ModeCode.

    Returns the string 'Belle' when b2bii.isB2BII() is true and 'Belle2' otherwise.
    """
    return 'Belle' if b2bii.isB2BII() else 'Belle2'
36
37
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Sets the interaction with the database: download trained weight files or upload weight files after training.

    The two arguments are stored unchanged in the module-level globals
    ``downloadFlag`` and ``uploadFlag`` respectively.
    """
    global downloadFlag, uploadFlag

    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
48
49
# Default list of aliases that should be used to save the flavor tagging information using VariablesToNtuple.
# The combined-output and MC aliases come first; they are followed, category by category, by the
# three per-category aliases 'qp<Category>', 'hasTrueTarget<Category>' and 'isRightCategory<Category>'.
flavor_tagging = ['FBDT_qrCombined', 'FANN_qrCombined', 'qrMC', 'mcFlavorOfOtherB', 'qrGNN'] + [
    aliasPrefix + categoryName
    for categoryName in ('Electron', 'IntermediateElectron',
                         'Muon', 'IntermediateMuon',
                         'KinLepton', 'IntermediateKinLepton',
                         'Kaon', 'SlowPion', 'FastHadron', 'Lambda',
                         'FSC', 'MaximumPstar', 'KaonPion')
    for aliasPrefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
65
66
def add_default_FlavorTagger_aliases():
    """
    This function adds the default aliases for flavor tagging variables
    and defines the collection of flavor tagging variables.

    The combiner outputs, the MC flavor and the GNN output get fixed aliases; in addition
    every key of AvailableCategories gets the three aliases 'qp<Category>',
    'hasTrueTarget<Category>' and 'isRightCategory<Category>'. Finally the list
    flavor_tagging is registered as the collection 'flavor_tagging'.
    """
    addAlias = variables.variables.addAlias

    # Aliases that do not depend on a category
    for alias, expression in (('FBDT_qrCombined', 'qrOutput(FBDT)'),
                              ('FANN_qrCombined', 'qrOutput(FANN)'),
                              ('qrMC', 'isRelatedRestOfEventB0Flavor'),
                              ('qrGNN', 'extraInfo(qrGNN)')):
        addAlias(alias, expression)

    # Three aliases per available category
    for categoryName in AvailableCategories:
        addAlias(f'qp{categoryName}', f'qpCategory({categoryName})')
        addAlias(f'hasTrueTarget{categoryName}', f'hasTrueTargets({categoryName})')
        addAlias(f'isRightCategory{categoryName}', f'isTrueFTCategory({categoryName})')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
88
89
def set_FlavorTagger_pid_aliases():
    """
    This function adds the pid aliases needed by the flavor tagger.

    Every alias maps to a pidPairProbabilityExpert(...) expression. For Belle the dE/dx aliases
    combine CDC and SVD and are wrapped in ifNANgiveX(..., 0.5); otherwise they use CDC only
    and are not wrapped.
    """
    addAlias = variables.variables.addAlias

    # alias -> 'hypothesis PDG, reference PDG, detector'
    singleDetectorAliases = (
        ('eid_TOP', '11, 211, TOP'),
        ('eid_ARICH', '11, 211, ARICH'),
        ('eid_ECL', '11, 211, ECL'),
        ('muid_TOP', '13, 211, TOP'),
        ('muid_ARICH', '13, 211, ARICH'),
        ('muid_KLM', '13, 211, KLM'),
        ('piid_TOP', '211, 321, TOP'),
        ('piid_ARICH', '211, 321, ARICH'),
        ('Kid_TOP', '321, 211, TOP'),
        ('Kid_ARICH', '321, 211, ARICH'),
    )
    for alias, arguments in singleDetectorAliases:
        addAlias(alias, f'pidPairProbabilityExpert({arguments})')

    # alias -> 'hypothesis PDG, reference PDG' of the dE/dx based aliases
    dEdxAliases = (
        ('eid_dEdx', '11, 211'),
        ('muid_dEdx', '13, 211'),
        ('piid_dEdx', '211, 321'),
        ('pi_vs_edEdxid', '211, 11'),
        ('Kid_dEdx', '321, 211'),
    )
    if getBelleOrBelle2() == "Belle":
        for alias, hypotheses in dEdxAliases:
            addAlias(alias, f'ifNANgiveX(pidPairProbabilityExpert({hypotheses}, CDC, SVD), 0.5)')
    else:
        for alias, hypotheses in dEdxAliases:
            addAlias(alias, f'pidPairProbabilityExpert({hypotheses}, CDC)')
120
121
def set_FlavorTagger_pid_aliases_legacy():
    """
    This function adds the pid aliases needed by the flavor tagger trained for MC13.

    All aliases are wrapped in ifNANgiveX(..., 0.5). The dE/dx aliases use CDC and SVD
    for Belle and CDC only otherwise.
    """

    def nanSafePidPair(arguments):
        # Expression string for one alias, NaN replaced by 0.5
        return 'ifNANgiveX(pidPairProbabilityExpert(' + arguments + '), 0.5)'

    addAlias = variables.variables.addAlias

    for alias, arguments in (('eid_TOP', '11, 211, TOP'),
                             ('eid_ARICH', '11, 211, ARICH'),
                             ('eid_ECL', '11, 211, ECL'),
                             ('muid_TOP', '13, 211, TOP'),
                             ('muid_ARICH', '13, 211, ARICH'),
                             ('muid_KLM', '13, 211, KLM'),
                             ('piid_TOP', '211, 321, TOP'),
                             ('piid_ARICH', '211, 321, ARICH'),
                             ('Kid_TOP', '321, 211, TOP'),
                             ('Kid_ARICH', '321, 211, ARICH')):
        addAlias(alias, nanSafePidPair(arguments))

    if getBelleOrBelle2() == "Belle":
        dEdxDetectors = 'CDC, SVD'
    else:
        # Removed SVD PID for Belle II MC and data as it is absent in release 4.
        dEdxDetectors = 'CDC'

    for alias, hypotheses in (('eid_dEdx', '11, 211'),
                              ('muid_dEdx', '13, 211'),
                              ('piid_dEdx', '211, 321'),
                              ('pi_vs_edEdxid', '211, 11'),
                              ('Kid_dEdx', '321, 211')):
        addAlias(alias, nanSafePidPair(hypotheses + ', ' + dEdxDetectors))
153
154
def set_GNNFlavorTagger_aliases(categories, usePIDNN):
    """
    This function adds aliases for the GNN-based flavor tagger.

    For each rank from 1 to 16 one alias per requested category and one alias per
    per-particle quantity is registered; each of them reads the quantity from the particle
    of that rank in pi+:inRoe (ranked by FT_p) and falls back to 0 for NaN.

    @param categories  names of the categories (keys of AvailableCategories)
    @param usePIDNN    if True the '<species>IDNN' variables are used instead of '<species>ID'
    @return list with the names of all aliases added here, starting with 'qrCombined_bit'
    """
    addAlias = variables.variables.addAlias

    # will be used for target variable 0:B0bar, 1:B0
    addAlias('qrCombined_bit', '(qrCombined+1)/2')
    registeredAliases = ['qrCombined_bit']

    pidSuffix = 'IDNN' if usePIDNN else 'ID'

    # alias stem -> expression; position, mask (E), charge and the charge-weighted features
    perParticleExpressions = {
        'dx': 'dx',
        'dy': 'dy',
        'dz': 'dz',
        'E': 'E',
        'charge': 'charge',
        'px_c': 'px*charge',
        'py_c': 'py*charge',
        'pz_c': 'pz*charge',
    }
    for species in ('electron', 'muon', 'pion', 'kaon', 'proton', 'deuteron'):
        perParticleExpressions[f'{species}ID_c'] = f'{species}{pidSuffix}*charge'
    perParticleExpressions['electronID_noSVD_noTOP_c'] = 'electronID_noSVD_noTOP*charge'

    def addRankedAlias(aliasStem, expression, rank):
        # Registers '<aliasStem>_rank<rank>' and remembers its name
        rankedAlias = f'{aliasStem}_rank{rank}'
        addAlias(rankedAlias, f'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {expression}, {rank}), 0)')
        registeredAliases.append(rankedAlias)

    # 16 charged particles are used at most
    for rank in range(1, 17):

        for cat in categories:
            categoryList = AvailableCategories[cat].particleList
            addRankedAlias(cat, f'QpTrack({categoryList}, isRightCategory({cat}), isRightCategory({cat}))', rank)

        for aliasStem, expression in perParticleExpressions.items():
            addRankedAlias(aliasStem, expression, rank)

    return registeredAliases
207
208
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for input variables with ROE mask.

    For each of the variables pMissTag, cosTPTO, ptTracksRoe and pt2TracksRoe the alias
    '<variable>_withMask' is defined as '<variable>(<maskName>)'.

    @param maskName  name of the ROE mask inserted as the argument of the variables
    """
    # 'ptTracksRoe_withMask' used to be registered a second time with an identical
    # definition; every alias is now registered exactly once.
    for variableName in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(variableName + '_withMask', variableName + '(' + maskName + ')')
218
219
def getFastBDTCategories():
    '''
    Helper function for getting the FastBDT categories.
    It's necessary for removing top-level ROOT imports.

    Returns a fresh basf2_mva.FastBDTOptions object with the settings listed below.
    '''
    options = basf2_mva.FastBDTOptions()
    for optionName, optionValue in (('m_nTrees', 500),
                                    ('m_nCuts', 8),
                                    ('m_nLevels', 3),
                                    ('m_shrinkage', 0.10),
                                    ('m_randRatio', 0.5)):
        setattr(options, optionName, optionValue)
    return options
232
233
def getFastBDTCombiner():
    '''
    Helper function for getting the FastBDT combiner.
    It's necessary for removing top-level ROOT imports.

    Returns a fresh basf2_mva.FastBDTOptions object with the settings listed below.
    '''
    combinerSettings = {
        'm_nTrees': 500,
        'm_nCuts': 8,
        'm_nLevels': 3,
        'm_shrinkage': 0.10,
        'm_randRatio': 0.5,
    }
    options = basf2_mva.FastBDTOptions()
    for optionName, optionValue in combinerSettings.items():
        setattr(options, optionName, optionValue)
    return options
246
247
def getMlpFANNCombiner():
    '''
    Helper function for getting the MLP FANN combiner.
    It's necessary for removing top-level ROOT imports.

    Returns a fresh basf2_mva.FANNOptions object with the settings listed below.
    '''
    options = basf2_mva.FANNOptions()
    for optionName, optionValue in (
            ('m_max_epochs', 10000),
            ('m_hidden_layers_architecture', "3*N"),
            ('m_hidden_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
            ('m_output_activiation_function', "FANN_SIGMOID_SYMMETRIC"),
            ('m_error_function', "FANN_ERRORFUNC_LINEAR"),
            ('m_training_method', "FANN_TRAIN_RPROP"),
            ('m_validation_fraction', 0.5),
            ('m_random_seeds', 10),
            ('m_test_rate', 500),
            ('m_number_of_threads', 8),
            ('m_scale_features', True),
            # the target is not scaled (the alternative would be m_scale_target = True)
            ('m_scale_target', False),
    ):
        setattr(options, optionName, optionValue)
    return options
268
269
# SignalFraction: FBDT feature
# For smooth output set to -1, this will break the calibration.
# For correct calibration set to -2, leads to peaky combiner output.
# The value is passed as the 'signalFraction' parameter of the MVA expert modules at event level.
signalFraction = -2

# Maximal number of events to train each method
# (assigned to m_max_events of the training options in eventLevelTeacher).
maxEventsNumber = 0  # 0 takes all the sampled events. The number in the past was 500000
277
278
# Parameters describing one flavor tagging category:
#   particleList  name of the ParticleList the category works on
#   trackName     trackLevel category name
#   eventName     eventLevel category name
#   variableName  combinerLevel variable name
#   code          category code
FTCategoryParameters = collections.namedtuple(
    'FTCategoryParameters',
    ['particleList', 'trackName', 'eventName', 'variableName', 'code'],
)

# Definition of all available categories, keyed by the standard category name.
# The insertion order is the order of the category codes.
AvailableCategories = {
    'Electron': FTCategoryParameters(
        particleList='e+:inRoe',
        trackName='Electron',
        eventName='Electron',
        variableName='QpOf(e+:inRoe, isRightCategory(Electron), isRightCategory(Electron))',
        code=0,
    ),
    'IntermediateElectron': FTCategoryParameters(
        particleList='e+:inRoe',
        trackName='IntermediateElectron',
        eventName='IntermediateElectron',
        variableName='QpOf(e+:inRoe, isRightCategory(IntermediateElectron), isRightCategory(IntermediateElectron))',
        code=1,
    ),
    'Muon': FTCategoryParameters(
        particleList='mu+:inRoe',
        trackName='Muon',
        eventName='Muon',
        variableName='QpOf(mu+:inRoe, isRightCategory(Muon), isRightCategory(Muon))',
        code=2,
    ),
    'IntermediateMuon': FTCategoryParameters(
        particleList='mu+:inRoe',
        trackName='IntermediateMuon',
        eventName='IntermediateMuon',
        variableName='QpOf(mu+:inRoe, isRightCategory(IntermediateMuon), isRightCategory(IntermediateMuon))',
        code=3,
    ),
    'KinLepton': FTCategoryParameters(
        particleList='mu+:inRoe',
        trackName='KinLepton',
        eventName='KinLepton',
        variableName='QpOf(mu+:inRoe, isRightCategory(KinLepton), isRightCategory(KinLepton))',
        code=4,
    ),
    'IntermediateKinLepton': FTCategoryParameters(
        particleList='mu+:inRoe',
        trackName='IntermediateKinLepton',
        eventName='IntermediateKinLepton',
        variableName='QpOf(mu+:inRoe, isRightCategory(IntermediateKinLepton), isRightCategory(IntermediateKinLepton))',
        code=5,
    ),
    'Kaon': FTCategoryParameters(
        particleList='K+:inRoe',
        trackName='Kaon',
        eventName='Kaon',
        variableName='weightedQpOf(K+:inRoe, isRightCategory(Kaon), isRightCategory(Kaon))',
        code=6,
    ),
    'SlowPion': FTCategoryParameters(
        particleList='pi+:inRoe',
        trackName='SlowPion',
        eventName='SlowPion',
        variableName='QpOf(pi+:inRoe, isRightCategory(SlowPion), isRightCategory(SlowPion))',
        code=7,
    ),
    'FastHadron': FTCategoryParameters(
        particleList='pi+:inRoe',
        trackName='FastHadron',
        eventName='FastHadron',
        variableName='QpOf(pi+:inRoe, isRightCategory(FastHadron), isRightCategory(FastHadron))',
        code=8,
    ),
    'Lambda': FTCategoryParameters(
        particleList='Lambda0:inRoe',
        trackName='Lambda',
        eventName='Lambda',
        variableName='weightedQpOf(Lambda0:inRoe, isRightCategory(Lambda), isRightCategory(Lambda))',
        code=9,
    ),
    'FSC': FTCategoryParameters(
        particleList='pi+:inRoe',
        trackName='SlowPion',
        eventName='FSC',
        variableName='QpOf(pi+:inRoe, isRightCategory(FSC), isRightCategory(SlowPion))',
        code=10,
    ),
    'MaximumPstar': FTCategoryParameters(
        particleList='pi+:inRoe',
        trackName='MaximumPstar',
        eventName='MaximumPstar',
        variableName='QpOf(pi+:inRoe, isRightCategory(MaximumPstar), isRightCategory(MaximumPstar))',
        code=11,
    ),
    'KaonPion': FTCategoryParameters(
        particleList='K+:inRoe',
        trackName='Kaon',
        eventName='KaonPion',
        variableName='QpOf(K+:inRoe, isRightCategory(KaonPion), isRightCategory(Kaon))',
        code=12,
    ),
}
338
339
def getTrainingVariables(category=None, usePIDNN=False):
    """
    Helper function to get training variables.

    NOTE: This function is not called in the Expert mode. The returned list does not need to be
    consistent with the variables list of the weight files.

    @param category  standard category name; any value that is not handled below gives an empty list
    @param usePIDNN  if True the Belle II PID variables '<species>ID' are replaced by '<species>IDNN'
    @return list of variable names; the order inside each category is fixed
    """

    isBelle = getBelleOrBelle2() == 'Belle'
    pidSuffix = 'IDNN' if usePIDNN else 'ID'

    # Global PID variables, which depend on the experiment
    kaonId = 'ifNANgiveX(atcPIDBelle(3,2), 0.5)' if isBelle else 'kaon' + pidSuffix
    muonId = 'muIDBelle' if isBelle else 'muon' + pidSuffix
    electronId = 'eIDBelle' if isBelle else 'electron' + pidSuffix
    pionId = 'pion' + pidSuffix

    trainingVariables = []

    if category in ('Electron', 'IntermediateElectron'):
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'p',
            'pt',
            'cosTheta',
            electronId,
            'eid_TOP',
            'eid_ARICH',
            'eid_ECL',
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(pMissCMS)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(EW90)',
            'cosTPTO',
            'chiProb',
        ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'ImpactXY', 'distance']

    elif category in ('Muon', 'IntermediateMuon'):
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'p',
            'pt',
            'cosTheta',
            muonId,
            'muid_TOP',
            'muid_ARICH',
            'muid_KLM',
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(pMissCMS)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(EW90)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category in ('KinLepton', 'IntermediateKinLepton'):
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'p',
            'pt',
            'cosTheta',
            muonId,
            'muid_TOP',
            'muid_ARICH',
            'muid_KLM',
            electronId,
            'eid_TOP',
            'eid_ARICH',
            'eid_ECL',
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(pMissCMS)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(EW90)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['eid_dEdx', 'muid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Kaon':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'p',
            'pt',
            'cosTheta',
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(pMissCMS)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(EW90)',
            'cosTPTO',
            'chiProb',
        ]
        if isBelle:
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'SlowPion':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'cosTheta',
            'p',
            'pt',
            pionId,
            'piid_TOP',
            'piid_ARICH',
            'pi_vs_edEdxid',
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
            electronId,
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(EW90)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(pMissCMS)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['piid_dEdx', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'FastHadron':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'cosTheta',
            'p',
            'pt',
            pionId,
            'piid_dEdx',
            'piid_TOP',
            'piid_ARICH',
            'pi_vs_edEdxid',
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
            electronId,
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(EW90)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['BtagToWBosonVariables(pMissCMS)', 'ImpactXY', 'distance', 'chiProb']

    elif category == 'Lambda':
        trainingVariables = [
            'lambdaFlavor',
            'NumberOfKShortsInRoe',
            'M',
            'cosAngleBetweenMomentumAndVertexVector',
            'lambdaZError',
            'daughter(0,p)',
            'daughter(0,useCMSFrame(p))',
            'daughter(1,p)',
            'daughter(1,useCMSFrame(p))',
            'useCMSFrame(p)',
            'p',
            'chiProb',
        ]
        if isBelle:
            trainingVariables.append('distance')
        else:
            # protonID always 0 in B2BII check in future
            trainingVariables.append('daughter(1,proton' + pidSuffix + ')')
            # not very powerful in B2BII
            trainingVariables.append('daughter(0,pion' + pidSuffix + ')')

    elif category == 'MaximumPstar':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'p',
            'pt',
            'cosTPTO',
        ]
        # here the impact parameter variables are added for Belle II, not for Belle
        if not isBelle:
            trainingVariables += ['ImpactXY', 'distance']

    elif category == 'FSC':
        trainingVariables = [
            'useCMSFrame(p)',
            'cosTPTO',
            kaonId,
            'FSCVariables(pFastCMS)',
            'FSCVariables(cosSlowFast)',
            'FSCVariables(cosTPTOFast)',
            'FSCVariables(SlowFastHaveOpositeCharges)',
        ]

    elif category == 'KaonPion':
        trainingVariables = [
            'extraInfo(isRightCategory(Kaon))',
            'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
            'KaonPionVariables(cosKaonPion)',
            'KaonPionVariables(HaveOpositeCharges)',
            kaonId,
        ]

    return trainingVariables
555
556
def FillParticleLists(maskName='all', categories=None, path=None):
    """
    Fills the particle Lists for all categories.

    Besides the lists of the requested categories, a K_S0:inRoe list is always created.
    For Belle II the charged PID MVA is additionally applied to the electron and muon
    lists if they were filled.

    @param maskName    name of the ROE mask used in the track selection
    @param categories  names of the categories (keys of AvailableCategories); None means no category
    @param path        basf2 path to which the modules are added
    """

    from vertex import kFit

    if categories is None:
        categories = []

    # Names of the lists which have already been filled
    filledLists = []

    roeTrackCut = 'isInRestOfEvent > 0.5 and passesROEMask(' + maskName + ') > 0.5 and p >= 0'

    for category in categories:
        listName = AvailableCategories[category].particleList

        # Several categories share one list; fill it only once
        if listName in filledLists:
            continue

        # Select particles in ROE for different categories according to mass hypothesis.
        if listName != 'Lambda0:inRoe':
            # Filling particle list for actual category
            ma.fillParticleList(listName, roeTrackCut, path=path)
            filledLists.append(listName)
            continue

        # The Lambda candidates are built from a pion and a proton in the ROE
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
            filledLists.append('pi+:inRoe')

        ma.fillParticleList('p+:inRoe', roeTrackCut, path=path)
        ma.reconstructDecay(listName + ' -> pi-:inRoe p+:inRoe', '1.00<=M<=1.23', False, path=path)
        kFit(listName, 0.01, path=path)
        ma.matchMCTruth(listName, path=path)
        filledLists.append(listName)

    # Additional particleLists for K_S0
    if getBelleOrBelle2() == 'Belle':
        ma.cutAndCopyList('K_S0:inRoe', 'K_S0:mdst', 'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
        # The BDT-based LID below is applied for Belle II only
        return

    if 'pi+:inRoe' not in filledLists:
        ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
    ma.reconstructDecay('K_S0:inRoe -> pi+:inRoe pi-:inRoe', '0.40<=M<=0.60', False, path=path)
    kFit('K_S0:inRoe', 0.01, path=path)

    # Apply BDT-based LID
    leptonLists = [name for name in ('e+:inRoe', 'mu+:inRoe') if name in filledLists]
    if not leptonLists:
        return

    # binary, e vs pi
    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=0,
                          binaryHypoPDGCodes=(11, 211))
    # binary, mu vs pi
    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=0,
                          binaryHypoPDGCodes=(13, 211))
    # Multiclass
    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=1)
620
621
def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, usePIDNN=False, path=None):
    """
    Samples data for training or tests all categories at event level.

    In both modes the MVA experts of all categories with an available weight file are added to
    the path. In 'Sampler' mode a VariablesToNtuple module is additionally added for every
    category whose weight file does not exist yet.

    The function reads the module-level names downloadFlag, useOnlyLocalFlag, filesDirectory and
    fileId. NOTE(review): only downloadFlag is set in the visible part of this file (by
    setInteractionWithDatabase); the others are presumably set elsewhere before this is called.

    @param mode         'Expert' or 'Sampler'; other values add no expert and sample nothing
    @param weightFiles  string inserted in the name of the weight files and of the sampled files
    @param categories   names of the categories (keys of AvailableCategories); None means no category
    @param usePIDNN     forwarded to getTrainingVariables when sampling
    @param path         basf2 path to which the modules are added
    @return True if an expert was set up for every category, False otherwise
    """

    from basf2 import create_path
    from basf2 import register_module

    B2INFO('EVENT LEVEL')

    # Number of categories for which an expert is applied
    ReadyMethods = 0

    # Each category has its own Path in order to be skipped if the corresponding particle list is empty
    # particle list name -> list of (extraInfo name, weight file identifier)
    identifiersExtraInfosDict = dict()
    # KaonPion is kept separately since it gets its own MVAExpert module below
    identifiersExtraInfosKaonPion = []

    if categories is None:
        categories = []

    for category in categories:
        particleList = AvailableCategories[category].particleList

        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        # By default the identifier is the bare method name; it is replaced by a local file path below
        identifierEventLevel = methodPrefixEventLevel
        targetVariable = 'isRightCategory(' + category + ')'
        extraInfoName = targetVariable

        if mode == 'Expert':

            if downloadFlag or useOnlyLocalFlag:
                identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'

            if downloadFlag:
                if not os.path.isfile(identifierEventLevel):
                    basf2_mva.download(methodPrefixEventLevel, identifierEventLevel)
                    # The download must have produced the local file
                    if not os.path.isfile(identifierEventLevel):
                        B2FATAL('Flavor Tagger: Weight file ' + identifierEventLevel +
                                ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag:
                if not os.path.isfile(identifierEventLevel):
                    B2FATAL('Flavor Tagger: ' + particleList + ' Eventlevel was not trained. Weight file ' +
                            identifierEventLevel + ' was not found. Stopped')

            B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

        elif mode == 'Sampler':

            # When sampling, only local weight files are considered
            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            if os.path.isfile(identifierEventLevel):
                B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

                if 'KaonPion' in categories:
                    methodPrefixEventLevelKaonPion = "FlavorTagger_" + getBelleOrBelle2() + \
                        "_" + weightFiles + 'EventLevelKaonPionFBDT'
                    identifierEventLevelKaonPion = filesDirectory + '/' + methodPrefixEventLevelKaonPion + '_1.root'
                    if not os.path.isfile(identifierEventLevelKaonPion):
                        # Slow Pion and Kaon categories are used if Kaon-Pion is lacking for
                        # sampling. The others are not needed and skipped
                        if category != "SlowPion" and category != "Kaon":
                            continue

        if mode == 'Expert' or (mode == 'Sampler' and os.path.isfile(identifierEventLevel)):

            B2INFO('flavorTagger: Applying MVAExpert ' + methodPrefixEventLevel + '.')

            # Collect the experts per particle list so that one module can evaluate all of them
            if category == 'KaonPion':
                identifiersExtraInfosKaonPion.append((extraInfoName, identifierEventLevel))
            elif particleList not in identifiersExtraInfosDict:
                identifiersExtraInfosDict[particleList] = [(extraInfoName, identifierEventLevel)]
            else:
                identifiersExtraInfosDict[particleList].append((extraInfoName, identifierEventLevel))

            ReadyMethods += 1

    # Each category has its own Path in order to be skipped if the corresponding particle list is empty
    for particleList in identifiersExtraInfosDict:
        eventLevelPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_EventLevel_' + particleList)
        SkipEmptyParticleList.param('particleLists', particleList)
        # The conditional path is run when the filter returns true; afterwards the main path continues
        SkipEmptyParticleList.if_true(eventLevelPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        # One module evaluates all the experts which work on this particle list
        mvaMultipleExperts = register_module('MVAMultipleExperts')
        mvaMultipleExperts.set_name('MVAMultipleExperts_EventLevel_' + particleList)
        mvaMultipleExperts.param('listNames', [particleList])
        mvaMultipleExperts.param('extraInfoNames', [row[0] for row in identifiersExtraInfosDict[particleList]])
        mvaMultipleExperts.param('signalFraction', signalFraction)
        mvaMultipleExperts.param('identifiers', [row[1] for row in identifiersExtraInfosDict[particleList]])
        eventLevelPath.add_module(mvaMultipleExperts)

    # The KaonPion expert is added after the experts above and works on the K+:inRoe list
    if 'KaonPion' in categories and len(identifiersExtraInfosKaonPion) != 0:
        eventLevelKaonPionPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_' + 'K+:inRoe')
        SkipEmptyParticleList.param('particleLists', 'K+:inRoe')
        SkipEmptyParticleList.if_true(eventLevelKaonPionPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaExpertKaonPion = register_module("MVAExpert")
        mvaExpertKaonPion.set_name('MVAExpert_KaonPion_' + 'K+:inRoe')
        mvaExpertKaonPion.param('listNames', ['K+:inRoe'])
        mvaExpertKaonPion.param('extraInfoName', identifiersExtraInfosKaonPion[0][0])
        mvaExpertKaonPion.param('signalFraction', signalFraction)
        mvaExpertKaonPion.param('identifier', identifiersExtraInfosKaonPion[0][1])

        eventLevelKaonPionPath.add_module(mvaExpertKaonPion)

    if mode == 'Sampler':

        for category in categories:
            particleList = AvailableCategories[category].particleList

            methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            targetVariable = 'isRightCategory(' + category + ')'

            # Only the categories without a weight file are sampled
            if not os.path.isfile(identifierEventLevel):

                if category == 'KaonPion':
                    methodPrefixEventLevelSlowPion = "FlavorTagger_" + getBelleOrBelle2() + \
                        "_" + weightFiles + 'EventLevelSlowPionFBDT'
                    identifierEventLevelSlowPion = filesDirectory + '/' + methodPrefixEventLevelSlowPion + '_1.root'
                    # Without the SlowPion weight file the KaonPion category is not sampled in this step
                    if not os.path.isfile(identifierEventLevelSlowPion):
                        B2INFO("Flavor Tagger: event level weight file for the Slow Pion category is absent." +
                               "It is required to sample the training information for the KaonPion category." +
                               "An additional sampling step will be needed after the following training step.")
                        continue

                B2INFO('flavorTagger: file ' + filesDirectory + '/' +
                       methodPrefixEventLevel + "sampled" + fileId + '.root will be saved.')

                ma.applyCuts(particleList, 'isRightCategory(mcAssociated) > 0', path)
                eventLevelpath = create_path()
                SkipEmptyParticleList = register_module("SkimFilter")
                SkipEmptyParticleList.set_name('SkimFilter_EventLevel' + category)
                SkipEmptyParticleList.param('particleLists', particleList)
                SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
                path.add_module(SkipEmptyParticleList)

                ntuple = register_module('VariablesToNtuple')
                ntuple.param('fileName', filesDirectory + '/' + methodPrefixEventLevel + "sampled" + fileId + ".root")
                ntuple.param('treeName', methodPrefixEventLevel + "_tree")
                # Training variables, the target and some MC information are saved
                variablesToBeSaved = getTrainingVariables(category, usePIDNN) + [
                    targetVariable,
                    'ancestorHasWhichFlavor',
                    'isSignal',
                    'mcPDG',
                    'mcErrors',
                    'genMotherPDG',
                    'nMCMatches',
                    'B0mcErrors'
                ]
                # The track-level information is not saved for the KaonPion and FSC categories
                if category != 'KaonPion' and category != 'FSC':
                    variablesToBeSaved = variablesToBeSaved + \
                        ['extraInfo(isRightTrack(' + category + '))',
                         'hasHighestProbInCat(' + particleList + ', isRightTrack(' + category + '))']
                ntuple.param('variables', variablesToBeSaved)
                ntuple.param('particleList', particleList)
                eventLevelpath.add_module(ntuple)

    # True only if an expert is applied for every requested category
    if ReadyMethods != len(categories):
        return False
    else:
        return True
788
789
def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None, usePIDNN=False):
    """
    Trains all categories at event level.

    A category is trained only if its weight file does not exist yet and sampled files are
    found in filesDirectory. A category trained in this call is not counted as ready.

    @param weightFiles  string inserted in the name of the weight files and of the sampled files
    @param categories   names of the categories; None means no category
    @param usePIDNN     forwarded to getTrainingVariables
    @return True if the weight files of all categories already existed, False otherwise
    """

    B2INFO('EVENT LEVEL TEACHER')

    if categories is None:
        categories = []

    nReadyMethods = 0

    for category in categories:
        methodName = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        target = 'isRightCategory(' + category + ')'
        weightFilePath = filesDirectory + '/' + methodName + "_1.root"

        # Nothing to do if the category has been trained before
        if os.path.isfile(weightFilePath):
            nReadyMethods += 1
            continue

        sampledFiles = glob.glob(filesDirectory + '/' + methodName + 'sampled*.root')
        if not sampledFiles:
            B2INFO('flavorTagger: eventLevelTeacher did not find any ' + methodName +
                   ".root" + ' file. Please run the flavorTagger in "Sampler" mode afterwards.')
            continue

        B2INFO('flavorTagger: MVA Teacher training' + methodName + ' .')
        options = basf2_mva.GeneralOptions()
        options.m_datafiles = basf2_mva.vector(*sampledFiles)
        options.m_treename = methodName + "_tree"
        options.m_identifier = weightFilePath
        options.m_variables = basf2_mva.vector(*getTrainingVariables(category, usePIDNN))
        options.m_target_variable = target
        options.m_max_events = maxEventsNumber

        basf2_mva.teacher(options, getFastBDTCategories())

        if uploadFlag:
            basf2_mva.upload(weightFilePath, methodName)

    return nReadyMethods == len(categories)
835
836
def _getCombinerWeightFileIdentifier(methodPrefixCombinerLevel, method, methodName):
    """
    Returns the identifier of the combiner weight file of one MVA method and checks that it is available.

    @param methodPrefixCombinerLevel  Prefix of the combiner weight files.
    @param method                     Label of the method in the weight-file name (``FBDT`` or ``FANN``).
    @param methodName                 Name of the method used in the error message.
    @return                           The local file path if ``downloadFlag`` or ``useOnlyLocalFlag`` is set,
                                      otherwise the plain identifier ``methodPrefixCombinerLevel + method``.
    """

    identifier = methodPrefixCombinerLevel + method
    if downloadFlag or useOnlyLocalFlag:
        identifier = filesDirectory + '/' + methodPrefixCombinerLevel + method + '_1.root'

    if downloadFlag:
        if not os.path.isfile(identifier):
            basf2_mva.download(methodPrefixCombinerLevel + method, identifier)
            if not os.path.isfile(identifier):
                B2FATAL('Flavor Tagger: Weight file ' + identifier +
                        ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

    if useOnlyLocalFlag:
        if not os.path.isfile(identifier):
            B2FATAL('flavorTagger: Combinerlevel ' + methodName + ' was not trained with this combination of categories.' +
                    ' Weight file ' + identifier + ' not found. Stopped')

    B2INFO('flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel + method + '_1.root')

    return identifier


def combinerLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None,
                  variablesCombinerLevel=None, categoriesCombinationCode=None, path=None):
    """
    Samples the input data or tests the combiner according to the selected categories.

    In ``Sampler`` mode a ``VariablesToNtuple`` module writing the combiner input variables and ``qrCombined``
    is added to the path. In ``Expert`` mode the weight files of the requested combiner methods are resolved
    and an ``MVAExpert`` (one method) or ``MVAMultipleExperts`` (both methods) module is added to the path.

    @param mode                       ``Sampler`` or ``Expert``; other values only produce the info output.
    @param weightFiles                Name entering the prefix of the weight files.
    @param categories                 List of category names, used for the info output only.
    @param variablesCombinerLevel     Input variables of the combiner.
    @param categoriesCombinationCode  Code of the combination of categories (a string), part of the prefix.
    @param path                       Modules are added to this path.
    """

    B2INFO('COMBINER LEVEL')

    if categories is None:
        categories = []
    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    B2INFO("Flavor Tagger: Required Combiner for Categories:")
    for category in categories:
        B2INFO(category)

    B2INFO("Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    if mode == 'Sampler':

        # Sampling is refused once a combiner has been trained for this combination of categories
        if os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root') or \
                os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'):
            B2FATAL('flavorTagger: File ' + methodPrefixCombinerLevel + 'FBDT' + "_1.root" + ' or ' + methodPrefixCombinerLevel +
                    'FANN' + '_1.root found. Please run the "Expert" mode or delete the file if a new sampling is desired.')

        sampledFileName = methodPrefixCombinerLevel + "sampled" + fileId + ".root"
        B2INFO('flavorTagger: Sampling Data on Combiner Level. File ' + sampledFileName + ' will be saved')

        ntuple = basf2.register_module('VariablesToNtuple')
        ntuple.param('fileName', filesDirectory + '/' + sampledFileName)
        # The tree carries the FBDT label; combinerLevelTeacher reads this tree for both methods
        ntuple.param('treeName', methodPrefixCombinerLevel + 'FBDT' + "_tree")
        ntuple.param('variables', variablesCombinerLevel + ['qrCombined'])
        ntuple.param('particleList', "")
        path.add_module(ntuple)

    elif mode == 'Expert':

        # Check if weight files are ready
        identifierFBDT = None
        identifierFANN = None
        if TMVAfbdt:
            identifierFBDT = _getCombinerWeightFileIdentifier(methodPrefixCombinerLevel, 'FBDT', 'FastBDT')
        if FANNmlp:
            identifierFANN = _getCombinerWeightFileIdentifier(methodPrefixCombinerLevel, 'FANN', 'FANNMLP')

        # At this stage, all necessary weight files should be ready.
        # Call MVAExpert or MVAMultipleExperts module.
        if TMVAfbdt and not FANNmlp:
            B2INFO('flavorTagger: Apply FBDTMethod ' + methodPrefixCombinerLevel + 'FBDT')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FBDT', signalFraction=signalFraction,
                            identifier=identifierFBDT)

        if FANNmlp and not TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod on combiner level')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FANN', signalFraction=signalFraction,
                            identifier=identifierFANN)

        if FANNmlp and TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
            mvaMultipleExperts = basf2.register_module('MVAMultipleExperts')
            mvaMultipleExperts.set_name('MVAMultipleExperts_Combiners')
            mvaMultipleExperts.param('listNames', [])
            mvaMultipleExperts.param('extraInfoNames', ['qrCombined' + 'FBDT', 'qrCombined' + 'FANN'])
            mvaMultipleExperts.param('signalFraction', signalFraction)
            mvaMultipleExperts.param('identifiers', [identifierFBDT, identifierFANN])
            path.add_module(mvaMultipleExperts)
937
938
def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    Every requested combiner method (``TMVAfbdt`` and/or ``FANNmlp``) whose weight file
    ``<filesDirectory>/<prefix><method>_1.root`` does not exist yet is trained on the sampled files
    ``<filesDirectory>/<prefix>sampled*.root``. If the weight files of all requested methods already
    exist, the function stops with a fatal error pointing to the ``Expert`` mode.

    @param weightFiles                Name entering the prefix of the weight files.
    @param variablesCombinerLevel     Input variables of the combiner.
    @param categoriesCombinationCode  Code of the combination of categories (a string), part of the prefix.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if len(sampledFilesList) == 0:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')

    # Requested methods as (weight-file label, wording for the log, getter of the method-specific options)
    requestedCombiners = []
    if TMVAfbdt:
        requestedCombiners.append(('FBDT', 'a FastBDT', getFastBDTCombiner))
    if FANNmlp:
        requestedCombiners.append(('FANN', 'a FANN MLP', getMlpFANNCombiner))

    # The availability of the weight files is determined once, before any training, so that a weight
    # file produced in this call is not mistaken for a previously trained one.
    alreadyTrained = [label for label, _, _ in requestedCombiners
                      if os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + label + '_1.root')]

    if requestedCombiners and len(alreadyTrained) == len(requestedCombiners):
        B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight file(s) ' +
                ' and '.join(methodPrefixCombinerLevel + label + '_1.root' for label in alreadyTrained) +
                ' found. Please use the "Expert" mode')

    for label, description, getSpecificOptions in requestedCombiners:

        weightFile = filesDirectory + '/' + methodPrefixCombinerLevel + label + '_1.root'

        if label in alreadyTrained:
            B2INFO('flavorTagger: Combinerlevel ' + label + ' was already trained with this combination of categories. ' +
                   'Weight file ' + methodPrefixCombinerLevel + label + '_1.root has been found.')
            continue

        B2INFO('flavorTagger: MVA Teacher training ' + description + ' on Combiner Level')

        trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
        trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        # The FBDT tree name is used for both methods, as in the original code
        trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
        trainingOptionsCombinerLevel.m_identifier = weightFile
        trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
        trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
        trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsCombinerLevel, getSpecificOptions())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefixCombinerLevel + label)
1017
1018
def getEventLevelParticleLists(categories=None):
    """
    Collects, for the given categories, the tuples (particleList, eventName, variableName)
    taken from ``AvailableCategories``. A tuple occurring a second time is a fatal error.

    @param categories  List of category names (``None`` is treated as an empty list).
    @return            List of the collected tuples, in the order of the given categories.
    """

    requested = [] if categories is None else categories
    collected = []

    for name in requested:
        info = AvailableCategories[name]
        entry = (info.particleList, info.eventName, info.variableName)

        if entry in collected:
            B2FATAL('Flavor Tagger: ' + name + ' has been already given')
        else:
            collected.append(entry)

    return collected
1036
1037
def flavorTagger(
    particleLists=None,
    mode='Expert',
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    uploadToDatabaseAfterTraining=False,
    samplerFileId='',
    prefix='MC16rd_light-2501-betelgeuse',
    useGNN=True,
    identifierGNN='GFlaT_MC16rd_light-2501-betelgeuse_tensorflow',
    usePIDNN=False,
    path=None,
):
    """
    Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
    The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
    Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
    This module can be used to sample the training information, to train and/or to test the flavorTagger.

    @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
    @param mode                              The available modes are
        ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
        Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you save
        the variables for training. In the ``Teacher`` mode the FlavorTagger is
        trained, for this step you do not reconstruct any particle or do any analysis,
        you just run the flavorTagger alone.
    @param weightFiles                       Weight files name. Default=
        ``B2nunubarBGx1`` (official weight files). If the user wants to train the
        FlavorTagger themselves, the weightfiles name should correspond to the
        analyzed CP channel in order to avoid confusions. The default name
        ``B2nunubarBGx1`` corresponds to
        :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`.
        and ``B2JpsiKs_muBGx1`` to
        :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
        BGx1 stands for events simulated with background.
    @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
    @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
        ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
    @param categories                        Categories used for flavor tagging. By default all are used.
    @param maskName                          Gets ROE particles from a specified ROE mask.
        ``FTDefaultMask`` (default): tentative mask definition that will be created
        automatically. The definition is as follows:

        - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
        - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
        [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
        [clusterReg==3 and E>0.06]] \
        (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

        ``all``: all ROE particles are used.
        Or one can give any mask name defined before calling this function.
    @param saveCategoriesInfo                Sets to save information of individual categories.
    @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
    @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
        the conditions database if not available in workingDirectory.
    @param uploadToDatabaseAfterTraining     [Expert] For librarians only: uploads weight files to localdb after training.
    @param samplerFileId                     Identifier to parallelize
        sampling. Only used in ``Sampler`` mode. If you are training by yourself and
        want to parallelize the sampling, you can run several sampling scripts in
        parallel. By changing this parameter you will not overwrite an older sample.
    @param prefix                            Prefix of weight files.
        ``MC16rd_light-2501-betelgeuse`` (default): Weight files trained for MC16rd samples.
        ``MC15ri_light-2207-bengal_0``: Weight files trained for MC15ri samples.
        ``''``: Weight files trained for MC13 samples.
    @param useGNN                            Use GNN-based Flavor Tagger in addition with FastBDT-based one.
        Please specify the weight file with the option ``identifierGNN``.
        [Expert] In the sampler mode,
        training files for GNN-based Flavor Tagger are produced.
    @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
        [Expert] Multiple identifiers can be given with list(str).
    @param usePIDNN                          If True, PID probabilities calculated from PID neural network are used
        (default is False). Prefix and identifierGNN must be set accordingly.
    @param path                              Modules are added to this path

    """

    if (not isinstance(particleLists, list)):
        particleLists = [particleLists]  # in case user inputs a particle list as string

    if mode != 'Sampler' and mode != 'Teacher' and mode != 'Expert':
        B2FATAL('flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')

    if len(categories) != len(set(categories)):
        dup = sorted(cat for cat in set(categories) if categories.count(cat) > 1)
        # Plain string concatenation: the stream operator "<<" is not defined for Python strings
        B2WARNING('Flavor Tagger: There are duplicate elements in the given categories list. ' +
                  'The following duplicate elements are removed; ' + ', '.join(dup))
        # Remove the duplicates while keeping the order given by the user (a new list is created,
        # the default argument is never modified)
        categories = list(dict.fromkeys(categories))

    # B2FATAL does not return, therefore each complaint and the list of valid names go into one message
    if len(categories) < 2:
        B2FATAL(
            'Flavor Tagger: Invalid amount of categories. At least two are needed. '
            'Possible categories are "Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
            '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
            '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')

    for category in categories:
        if category not in AvailableCategories:
            B2FATAL('Flavor Tagger: ' + category + ' is not a valid category name given. '
                    'Available categories are "Electron", "IntermediateElectron", '
                    '"Muon", "IntermediateMuon", "KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                    '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')

    if mode == 'Expert' and useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # Directory where the weights of the trained Methods are saved
    # workingDirectory = os.environ['BELLE2_LOCAL_DIR'] + '/analysis/data'

    basf2.find_file(workingDirectory)

    global filesDirectory
    filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    if mode == 'Sampler' or (mode == 'Expert' and downloadFromDatabaseIfNotFound):
        if not basf2.find_file(workingDirectory + '/FlavorTagging', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging')
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        elif not basf2.find_file(workingDirectory + '/FlavorTagging/TrainedMethods', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    if len(combinerMethods) < 1 or len(combinerMethods) > 2:
        B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global FANNmlp
    global TMVAfbdt

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global fileId
    fileId = samplerFileId

    global useOnlyLocalFlag
    useOnlyLocalFlag = useOnlyLocalWeightFiles

    B2INFO('*** FLAVOR TAGGING ***')
    B2INFO(' ')
    B2INFO(' Working directory is: ' + filesDirectory)
    B2INFO(' ')

    setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)

    if prefix == '':
        set_FlavorTagger_pid_aliases_legacy()
    else:
        set_FlavorTagger_pid_aliases()

    alias_list_for_GNN = []
    if useGNN:
        alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories, usePIDNN)

    setInputVariablesWithMask()
    if prefix != '':
        weightFiles = prefix + '_' + weightFiles

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    categoriesCombinationCode = 'CatCode'
    for category in categories:
        ftCategory = AvailableCategories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')

    # The code is built from the sorted category codes, each written with two digits
    for code in sorted(categoriesCombination):
        categoriesCombinationCode = categoriesCombinationCode + f'{int(code):02}'

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or \
                            [E>0.06 and clusterReg==3]]')
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    if mode == 'Sampler':
        # Events containing ROE without B-Meson (but not empty) are discarded for training
        ma.signalSideParticleListsFilter(
            particleLists,
            'nROE_Charged(' + maskName + ', 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            deadEndPath)

        FillParticleLists(maskName, categories, roe_path)

        if useGNN:
            # The event level runs in Expert mode here; the GNN training sample is written
            # only if eventLevel reports success
            if eventLevel('Expert', weightFiles, categories, usePIDNN, roe_path):

                ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                                 outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
                ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)
                ma.variablesToNtuple('vpho:dummy',
                                     alias_list_for_GNN,
                                     treename='tree',
                                     filename=f'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
                                     signalSideParticleList=particleLists[0],
                                     path=roe_path)

        else:
            if eventLevel(mode, weightFiles, categories, usePIDNN, roe_path):
                combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # If trigger returns 1 jump into empty path skipping further modules in roe_path
        # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
        ma.signalSideParticleListsFilter(particleLists, 'nROE_Charged(' + maskName + ', 0) > 0', roe_path, deadEndPath)

        # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        FillParticleLists(maskName, categories, roe_path)

        if eventLevel(mode, weightFiles, categories, usePIDNN, roe_path):
            combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

            flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
            flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
            flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
            flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
            flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
            flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('targetProb', False)
            flavorTaggerInfoFiller.param('trackPointers', False)
            roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
            add_default_FlavorTagger_aliases()

        if useGNN:
            ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                             outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
            ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

            if isinstance(identifierGNN, str):
                roe_path.add_module('MVAExpert',
                                    listNames='vpho:dummy',
                                    extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                                    identifier=identifierGNN)

                ma.variableToSignalSideExtraInfo('vpho:dummy', {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                                                 path=roe_path)

            elif isinstance(identifierGNN, list):
                identifierGNN = list(set(identifierGNN))

                extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
                roe_path.add_module('MVAMultipleExperts',
                                    listNames='vpho:dummy',
                                    extraInfoNames=extraInfoNames,
                                    identifiers=identifierGNN)

                extraInfoDict = {}
                for extraInfoName in extraInfoNames:
                    extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                    variables.variables.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

                ma.variableToSignalSideExtraInfo('vpho:dummy', extraInfoDict,
                                                 path=roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Teacher':
        # The combiner is trained only if eventLevelTeacher reports that all event-level weight files are ready
        if eventLevelTeacher(weightFiles, categories, usePIDNN):
            combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1351
1352
if __name__ == '__main__':
    # Executed only when this file is run as a script: pass the module to the pretty printer of basf2.utils
    import basf2.utils
    basf2.utils.pretty_print_module(__name__, "flavorTagger")
isB2BII()
Definition b2bii.py:14