Belle II Software  release-08-01-10
flavorTagger.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # ************* Flavor Tagging ************
12 # * This script is needed to train *
13 # * and to test the flavor tagger. *
14 # ********************************************
15 
16 from basf2 import B2INFO, B2FATAL, B2WARNING
17 import basf2
18 import basf2_mva
19 import inspect
20 import modularAnalysis as ma
21 import variables
22 from variables import utils
23 import os
24 import glob
25 import b2bii
26 import collections
27 
28 
def getBelleOrBelle2():
    """
    Gets the global ModeCode.

    Returns 'Belle' when b2bii.isB2BII() is true and 'Belle2' otherwise.
    """
    return 'Belle' if b2bii.isB2BII() else 'Belle2'
37 
38 
def setInteractionWithDatabase(downloadFromDatabaseIfNotFound=False, uploadToDatabaseAfterTraining=False):
    """
    Sets the interaction with the database: download trained weight files or upload weight files after training.

    The two arguments are stored unchanged in the module-level globals
    ``downloadFlag`` and ``uploadFlag``.
    """
    global downloadFlag, uploadFlag

    downloadFlag, uploadFlag = downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining
49 
50 
# Default list of aliases that should be used to save the flavor tagging information using VariablesToNtuple.
# The combined outputs come first, followed by the three per-category aliases
# (qp..., hasTrueTarget..., isRightCategory...) of each category in turn.
flavor_tagging = ['FBDT_qrCombined', 'FANN_qrCombined', 'qrMC', 'mcFlavorOfOtherB', 'qrGNN'] + [
    aliasPrefix + categoryName
    for categoryName in ('Electron', 'IntermediateElectron',
                         'Muon', 'IntermediateMuon',
                         'KinLepton', 'IntermediateKinLepton',
                         'Kaon', 'SlowPion', 'FastHadron', 'Lambda',
                         'FSC', 'MaximumPstar', 'KaonPion')
    for aliasPrefix in ('qp', 'hasTrueTarget', 'isRightCategory')
]
66 
67 
def add_default_FlavorTagger_aliases():
    """
    This function adds the default aliases for flavor tagging variables
    and defines the collection of flavor tagging variables.

    The combined-output aliases are registered first, then three aliases
    (qp<Category>, hasTrueTarget<Category>, isRightCategory<Category>) for every
    key of AvailableCategories, and finally the 'flavor_tagging' collection.
    """
    combinedAliases = (
        ('FBDT_qrCombined', 'qrOutput(FBDT)'),
        ('FANN_qrCombined', 'qrOutput(FANN)'),
        ('qrMC', 'isRelatedRestOfEventB0Flavor'),
        ('qrGNN', 'extraInfo(qrGNN)'),
    )
    for aliasName, expression in combinedAliases:
        variables.variables.addAlias(aliasName, expression)

    for categoryName in AvailableCategories:
        variables.variables.addAlias(f'qp{categoryName}', f'qpCategory({categoryName})')
        variables.variables.addAlias(f'hasTrueTarget{categoryName}', f'hasTrueTargets({categoryName})')
        variables.variables.addAlias(f'isRightCategory{categoryName}', f'isTrueFTCategory({categoryName})')

    utils.add_collection(flavor_tagging, 'flavor_tagging')
89 
90 
def set_FlavorTagger_pid_aliases():
    """
    This function adds the pid aliases needed by the flavor tagger.

    Per-detector aliases map directly onto pidPairProbabilityExpert. The dE/dx
    aliases depend on getBelleOrBelle2(): for 'Belle' they combine CDC and SVD
    and replace NaN by 0.5, otherwise they use the CDC alone.
    """
    # (alias, arguments of pidPairProbabilityExpert)
    detectorAliases = (
        ('eid_TOP', '11, 211, TOP'),
        ('eid_ARICH', '11, 211, ARICH'),
        ('eid_ECL', '11, 211, ECL'),
        ('muid_TOP', '13, 211, TOP'),
        ('muid_ARICH', '13, 211, ARICH'),
        ('muid_KLM', '13, 211, KLM'),
        ('piid_TOP', '211, 321, TOP'),
        ('piid_ARICH', '211, 321, ARICH'),
        ('Kid_TOP', '321, 211, TOP'),
        ('Kid_ARICH', '321, 211, ARICH'),
    )
    for aliasName, pidArguments in detectorAliases:
        variables.variables.addAlias(aliasName, f'pidPairProbabilityExpert({pidArguments})')

    if getBelleOrBelle2() == "Belle":
        dEdxTemplate = 'ifNANgiveX(pidPairProbabilityExpert({}, CDC, SVD), 0.5)'
    else:
        dEdxTemplate = 'pidPairProbabilityExpert({}, CDC)'

    # (alias, pair of PDG hypotheses)
    dEdxAliases = (
        ('eid_dEdx', '11, 211'),
        ('muid_dEdx', '13, 211'),
        ('piid_dEdx', '211, 321'),
        ('pi_vs_edEdxid', '211, 11'),
        ('Kid_dEdx', '321, 211'),
    )
    for aliasName, hypotheses in dEdxAliases:
        variables.variables.addAlias(aliasName, dEdxTemplate.format(hypotheses))
121 
122 
def set_FlavorTagger_pid_aliases_legacy():
    """
    This function adds the pid aliases needed by the flavor tagger trained for MC13.

    Every alias wraps pidPairProbabilityExpert in ifNANgiveX(..., 0.5). The dE/dx
    aliases use CDC and SVD when getBelleOrBelle2() returns 'Belle' and only the
    CDC otherwise.
    """
    nanSafeTemplate = 'ifNANgiveX(pidPairProbabilityExpert({}, {}), 0.5)'

    # (alias, pair of PDG hypotheses, detector)
    detectorAliases = (
        ('eid_TOP', '11, 211', 'TOP'),
        ('eid_ARICH', '11, 211', 'ARICH'),
        ('eid_ECL', '11, 211', 'ECL'),
        ('muid_TOP', '13, 211', 'TOP'),
        ('muid_ARICH', '13, 211', 'ARICH'),
        ('muid_KLM', '13, 211', 'KLM'),
        ('piid_TOP', '211, 321', 'TOP'),
        ('piid_ARICH', '211, 321', 'ARICH'),
        ('Kid_TOP', '321, 211', 'TOP'),
        ('Kid_ARICH', '321, 211', 'ARICH'),
    )
    for aliasName, hypotheses, detector in detectorAliases:
        variables.variables.addAlias(aliasName, nanSafeTemplate.format(hypotheses, detector))

    # Removed SVD PID for Belle II MC and data as it is absent in release 4.
    dEdxDetectors = 'CDC, SVD' if getBelleOrBelle2() == "Belle" else 'CDC'

    dEdxAliases = (
        ('eid_dEdx', '11, 211'),
        ('muid_dEdx', '13, 211'),
        ('piid_dEdx', '211, 321'),
        ('pi_vs_edEdxid', '211, 11'),
        ('Kid_dEdx', '321, 211'),
    )
    for aliasName, hypotheses in dEdxAliases:
        variables.variables.addAlias(aliasName, nanSafeTemplate.format(hypotheses, dEdxDetectors))
154 
155 
def set_GNNFlavorTagger_aliases(categories):
    """
    This function adds aliases for the GNN-based flavor tagger.

    For each rank from 1 to 16 one alias per requested category and one alias per
    track feature is registered; each selects the value by rank in FT_p from
    pi+:inRoe and replaces NaN by 0.

    Returns the list of all registered alias names, starting with 'qrCombined_bit'.
    """

    # will be used for target variable 0:B0bar, 1:B0
    variables.variables.addAlias('qrCombined_bit', '(qrCombined+1)/2')
    registeredAliases = ['qrCombined_bit']

    # (alias stem, variable expression), in registration order
    trackFeatures = (
        # position
        ('dx', 'dx'),
        ('dy', 'dy'),
        ('dz', 'dz'),
        # mask
        ('E', 'E'),
        # charge
        ('charge', 'charge'),
        # feature
        ('px_c', 'px*charge'),
        ('py_c', 'py*charge'),
        ('pz_c', 'pz*charge'),
        ('electronID_c', 'electronID*charge'),
        ('muonID_c', 'muonID*charge'),
        ('pionID_c', 'pionID*charge'),
        ('kaonID_c', 'kaonID*charge'),
        ('protonID_c', 'protonID*charge'),
        ('deuteronID_c', 'deuteronID*charge'),
        ('electronID_noSVD_noTOP_c', 'electronID_noSVD_noTOP*charge'),
    )

    def addRankedAlias(aliasStem, expression, rank):
        # Register '<stem>_rank<rank>' and remember it in the returned list.
        rankedAlias = f'{aliasStem}_rank{rank}'
        rankedExpression = f'ifNANgiveX(getVariableByRank(pi+:inRoe, FT_p, {expression}, {rank}), 0)'
        variables.variables.addAlias(rankedAlias, rankedExpression)
        registeredAliases.append(rankedAlias)

    # 16 charged particles are used at most
    for rank in range(1, 17):

        for categoryName in categories:
            roeListName = AvailableCategories[categoryName].particleList
            addRankedAlias(
                categoryName,
                f'QpTrack({roeListName}, isRightCategory({categoryName}), isRightCategory({categoryName}))',
                rank)

        for aliasStem, expression in trackFeatures:
            addRankedAlias(aliasStem, expression, rank)

    return registeredAliases
208 
209 
def setInputVariablesWithMask(maskName='all'):
    """
    Set aliases for input variables with ROE mask.

    For each of pMissTag, cosTPTO, ptTracksRoe and pt2TracksRoe the alias
    '<variable>_withMask' is registered for '<variable>(<maskName>)'.
    The mask name is inserted into the expression as given.
    """
    # Each alias is registered exactly once; ptTracksRoe_withMask used to be added twice.
    for variableName in ('pMissTag', 'cosTPTO', 'ptTracksRoe', 'pt2TracksRoe'):
        variables.variables.addAlias(variableName + '_withMask', variableName + '(' + maskName + ')')
219 
220 
def getFastBDTCategories():
    '''
    Helper function for getting the FastBDT categories.
    It's necessary for removing top-level ROOT imports.

    Returns a new basf2_mva.FastBDTOptions object with the settings used for the
    category-level trainings.
    '''
    categoryOptions = basf2_mva.FastBDTOptions()
    for attributeName, value in (('m_nTrees', 500),
                                 ('m_nCuts', 8),
                                 ('m_nLevels', 3),
                                 ('m_shrinkage', 0.10),
                                 ('m_randRatio', 0.5)):
        setattr(categoryOptions, attributeName, value)
    return categoryOptions
233 
234 
def getFastBDTCombiner():
    '''
    Helper function for getting the FastBDT combiner.
    It's necessary for removing top-level ROOT imports.

    Returns a new basf2_mva.FastBDTOptions object with the settings used for the
    combiner training.
    '''
    combinerOptions = basf2_mva.FastBDTOptions()
    # tree ensemble size and shape
    combinerOptions.m_nTrees, combinerOptions.m_nCuts, combinerOptions.m_nLevels = 500, 8, 3
    # boosting parameters
    combinerOptions.m_shrinkage, combinerOptions.m_randRatio = 0.10, 0.5
    return combinerOptions
247 
248 
def getMlpFANNCombiner():
    '''
    Helper function for getting the MLP FANN combiner.
    It's necessary for removing top-level ROOT imports.

    Returns a new basf2_mva.FANNOptions object with the settings used for the
    combiner training. Features are scaled, the target is not.
    '''
    fannSettings = {
        'm_max_epochs': 10000,
        'm_hidden_layers_architecture': "3*N",
        # the attribute names below are spelled as the options object expects them
        'm_hidden_activiation_function': "FANN_SIGMOID_SYMMETRIC",
        'm_output_activiation_function': "FANN_SIGMOID_SYMMETRIC",
        'm_error_function': "FANN_ERRORFUNC_LINEAR",
        'm_training_method': "FANN_TRAIN_RPROP",
        'm_validation_fraction': 0.5,
        'm_random_seeds': 10,
        'm_test_rate': 500,
        'm_number_of_threads': 8,
        'm_scale_features': True,
        'm_scale_target': False,
    }
    combinerOptions = basf2_mva.FANNOptions()
    for attributeName, value in fannSettings.items():
        setattr(combinerOptions, attributeName, value)
    return combinerOptions
269 
270 
# SignalFraction: FBDT feature passed to the MVA expert modules.
#   -1 gives a smooth output but breaks the calibration.
#   -2 gives the correct calibration but leads to a peaky combiner output.
signalFraction = -2

# Maximal number of events to train each method.
# 0 takes all the sampled events. The number in the past was 500000.
maxEventsNumber = 0
278 
279 
FTCategoryParameters = collections.namedtuple(
    'FTCategoryParameters',
    ['particleList', 'trackName', 'eventName', 'variableName', 'code'],
)

# Definition of all available categories, 'standard category name':
# ['ParticleList', 'trackLevel category name', 'eventLevel category name',
# 'combinerLevel variable name', 'category code']
#
# The standard category name is also the eventLevel category name, the category code is the
# position in the table below, and the combinerLevel variable name is always
# '<function>(<ParticleList>, isRightCategory(<eventLevel name>), isRightCategory(<trackLevel name>))'.
AvailableCategories = {
    categoryName: FTCategoryParameters(
        particleList=roeListName,
        trackName=trackLevelName,
        eventName=categoryName,
        variableName=f'{qpFunction}({roeListName}, isRightCategory({categoryName}), '
                     f'isRightCategory({trackLevelName}))',
        code=categoryCode,
    )
    for categoryCode, (categoryName, roeListName, trackLevelName, qpFunction) in enumerate((
        # (standard category name, ParticleList, trackLevel category name, combiner function)
        ('Electron', 'e+:inRoe', 'Electron', 'QpOf'),
        ('IntermediateElectron', 'e+:inRoe', 'IntermediateElectron', 'QpOf'),
        ('Muon', 'mu+:inRoe', 'Muon', 'QpOf'),
        ('IntermediateMuon', 'mu+:inRoe', 'IntermediateMuon', 'QpOf'),
        ('KinLepton', 'mu+:inRoe', 'KinLepton', 'QpOf'),
        ('IntermediateKinLepton', 'mu+:inRoe', 'IntermediateKinLepton', 'QpOf'),
        ('Kaon', 'K+:inRoe', 'Kaon', 'weightedQpOf'),
        ('SlowPion', 'pi+:inRoe', 'SlowPion', 'QpOf'),
        ('FastHadron', 'pi+:inRoe', 'FastHadron', 'QpOf'),
        ('Lambda', 'Lambda0:inRoe', 'Lambda', 'weightedQpOf'),
        ('FSC', 'pi+:inRoe', 'SlowPion', 'QpOf'),
        ('MaximumPstar', 'pi+:inRoe', 'MaximumPstar', 'QpOf'),
        ('KaonPion', 'K+:inRoe', 'Kaon', 'QpOf'),
    ))
}
339 
340 
def getTrainingVariables(category=None):
    """
    Helper function to get training variables.

    Returns a new list with the names of the training variables of the given
    category. The content depends on getBelleOrBelle2(). An empty list is
    returned when the category is None or not one of the known names.

    NOTE: This function is not called in the Expert mode. It is not necessary to be consistent
    with the variables list of weight files.
    """

    experiment = getBelleOrBelle2()
    isBelle = experiment == "Belle"

    # Global PID variables, which are named differently in Belle and Belle II
    kaonId = {'Belle': 'ifNANgiveX(atcPIDBelle(3,2), 0.5)', 'Belle2': 'kaonID'}[experiment]
    muonId = {'Belle': 'muIDBelle', 'Belle2': 'muonID'}[experiment]
    electronId = {'Belle': 'eIDBelle', 'Belle2': 'electronID'}[experiment]

    # Fragments shared by the lepton and kaon categories
    kinematics = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt', 'cosTheta']
    muonPid = [muonId, 'muid_TOP', 'muid_ARICH', 'muid_KLM']
    electronPid = [electronId, 'eid_TOP', 'eid_ARICH', 'eid_ECL']
    wBoson = ['BtagToWBosonVariables(recoilMassSqrd)',
              'BtagToWBosonVariables(pMissCMS)',
              'BtagToWBosonVariables(cosThetaMissCMS)',
              'BtagToWBosonVariables(EW90)']

    if category in ('Electron', 'IntermediateElectron'):
        trainingVariables = kinematics + electronPid + wBoson + ['cosTPTO', 'chiProb']
        if isBelle:
            trainingVariables += ['eid_dEdx', 'ImpactXY', 'distance']
        return trainingVariables

    if category in ('Muon', 'IntermediateMuon'):
        trainingVariables = kinematics + muonPid + wBoson + ['cosTPTO']
        if isBelle:
            trainingVariables += ['muid_dEdx', 'ImpactXY', 'distance', 'chiProb']
        return trainingVariables

    if category in ('KinLepton', 'IntermediateKinLepton'):
        trainingVariables = kinematics + muonPid + electronPid + wBoson + ['cosTPTO']
        if isBelle:
            trainingVariables += ['eid_dEdx', 'muid_dEdx', 'ImpactXY', 'distance', 'chiProb']
        return trainingVariables

    if category == 'Kaon':
        trainingVariables = kinematics + [
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
        ] + wBoson + ['cosTPTO', 'chiProb']
        if isBelle:
            trainingVariables += ['ImpactXY', 'distance']
        return trainingVariables

    if category == 'SlowPion':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'cosTheta',
            'p',
            'pt',
            'pionID',
            'piid_TOP',
            'piid_ARICH',
            'pi_vs_edEdxid',
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
            electronId,
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(EW90)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'BtagToWBosonVariables(pMissCMS)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['piid_dEdx', 'ImpactXY', 'distance', 'chiProb']
        return trainingVariables

    if category == 'FastHadron':
        trainingVariables = [
            'useCMSFrame(p)',
            'useCMSFrame(pt)',
            'cosTheta',
            'p',
            'pt',
            'pionID',
            'piid_dEdx',
            'piid_TOP',
            'piid_ARICH',
            'pi_vs_edEdxid',
            kaonId,
            'Kid_dEdx',
            'Kid_TOP',
            'Kid_ARICH',
            'NumberOfKShortsInRoe',
            'ptTracksRoe',
            electronId,
            'BtagToWBosonVariables(recoilMassSqrd)',
            'BtagToWBosonVariables(EW90)',
            'BtagToWBosonVariables(cosThetaMissCMS)',
            'cosTPTO',
        ]
        if isBelle:
            trainingVariables += ['BtagToWBosonVariables(pMissCMS)', 'ImpactXY', 'distance', 'chiProb']
        return trainingVariables

    if category == 'Lambda':
        trainingVariables = [
            'lambdaFlavor',
            'NumberOfKShortsInRoe',
            'M',
            'cosAngleBetweenMomentumAndVertexVector',
            'lambdaZError',
            'daughter(0,p)',
            'daughter(0,useCMSFrame(p))',
            'daughter(1,p)',
            'daughter(1,useCMSFrame(p))',
            'useCMSFrame(p)',
            'p',
            'chiProb',
        ]
        if experiment == "Belle2":
            trainingVariables.append('daughter(1,protonID)')  # protonID always 0 in B2BII check in future
            trainingVariables.append('daughter(0,pionID)')  # not very powerful in B2BII
        else:
            trainingVariables.append('distance')
        return trainingVariables

    if category == 'MaximumPstar':
        trainingVariables = ['useCMSFrame(p)', 'useCMSFrame(pt)', 'p', 'pt', 'cosTPTO']
        if experiment == "Belle2":
            trainingVariables += ['ImpactXY', 'distance']
        return trainingVariables

    if category == 'FSC':
        return [
            'useCMSFrame(p)',
            'cosTPTO',
            kaonId,
            'FSCVariables(pFastCMS)',
            'FSCVariables(cosSlowFast)',
            'FSCVariables(cosTPTOFast)',
            'FSCVariables(SlowFastHaveOpositeCharges)',
        ]

    if category == 'KaonPion':
        return [
            'extraInfo(isRightCategory(Kaon))',
            'HighestProbInCat(pi+:inRoe, isRightCategory(SlowPion))',
            'KaonPionVariables(cosKaonPion)',
            'KaonPionVariables(HaveOpositeCharges)',
            kaonId,
        ]

    # None or unknown category
    return []
554 
555 
def FillParticleLists(maskName='all', categories=None, path=None):
    """
    Fills the particle Lists for all categories.

    Every particle list required by the given categories is filled once from the
    rest of event, using tracks that pass the ROE mask ``maskName``. Lambda0
    candidates are reconstructed from pi- and p+, vertex-fitted and MC-matched.
    Afterwards the K_S0:inRoe list is prepared, and for Belle II the charged PID
    MVA is applied to the electron and muon lists that were filled.
    """

    from vertex import kFit

    requestedCategories = [] if categories is None else categories
    filledLists = []

    roeTrackCut = 'isInRestOfEvent > 0.5 and passesROEMask(' + maskName + ') > 0.5 and p >= 0'

    for categoryName in requestedCategories:
        listName = AvailableCategories[categoryName].particleList

        # several categories share a particle list; fill each list only once
        if listName in filledLists:
            continue

        # Select particles in ROE for different categories according to mass hypothesis.
        if listName != 'Lambda0:inRoe':
            # Filling particle list for actual category
            ma.fillParticleList(listName, roeTrackCut, path=path)
            filledLists.append(listName)
            continue

        # Lambda0 is built from a pion and a proton of the ROE
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
            filledLists.append('pi+:inRoe')

        ma.fillParticleList('p+:inRoe', roeTrackCut, path=path)
        ma.reconstructDecay(listName + ' -> pi-:inRoe p+:inRoe', '1.00<=M<=1.23', False, path=path)
        kFit(listName, 0.01, path=path)
        ma.matchMCTruth(listName, path=path)
        filledLists.append(listName)

    experiment = getBelleOrBelle2()

    # Additional particleLists for K_S0
    if experiment == 'Belle':
        ma.cutAndCopyList('K_S0:inRoe', 'K_S0:mdst', 'extraInfo(ksnbStandard) == 1 and isInRestOfEvent == 1', path=path)
    else:
        if 'pi+:inRoe' not in filledLists:
            ma.fillParticleList('pi+:inRoe', roeTrackCut, path=path)
        ma.reconstructDecay('K_S0:inRoe -> pi+:inRoe pi-:inRoe', '0.40<=M<=0.60', False, path=path)
        kFit('K_S0:inRoe', 0.01, path=path)

    # Apply BDT-based LID
    if experiment != 'Belle2':
        return

    leptonLists = [listName for listName in ('e+:inRoe', 'mu+:inRoe') if listName in filledLists]
    if not leptonLists:  # neither lepton list was filled
        return

    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=0,  # binary
                          binaryHypoPDGCodes=(11, 211))  # e vs pi
    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=0,  # binary
                          binaryHypoPDGCodes=(13, 211))  # mu vs pi
    ma.applyChargedPidMVA(particleLists=leptonLists, path=path,
                          trainingMode=1)  # Multiclass
619 
620 
def eventLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None, path=None):
    """
    Samples data for training or tests all categories at event level.

    For every requested category the identifier of its event-level FastBDT weight
    file is built. In 'Expert' mode the expert is always scheduled, optionally after
    downloading the weight file or checking that it exists locally. In 'Sampler'
    mode the expert is scheduled only if the weight file already exists; for the
    categories without a weight file a VariablesToNtuple module is added that
    writes the training sample.

    The module-level names downloadFlag, useOnlyLocalFlag, filesDirectory, fileId
    and signalFraction are read here; some of them are set outside this function.

    Modules are added to ``path``, so a valid path has to be given whenever at
    least one category is processed.

    Returns True if an expert was scheduled for every requested category and
    False otherwise.
    """

    from basf2 import create_path
    from basf2 import register_module

    B2INFO('EVENT LEVEL')

    ReadyMethods = 0

    # Experts are grouped by particle list: each list gets its own conditional path
    # so that the experts are skipped if the corresponding particle list is empty
    identifiersExtraInfosDict = dict()
    identifiersExtraInfosKaonPion = []

    if categories is None:
        categories = []

    for category in categories:
        particleList = AvailableCategories[category].particleList

        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        # by default the expert is looked up by name; a local file is used only when requested below
        identifierEventLevel = methodPrefixEventLevel
        extraInfoName = 'isRightCategory(' + category + ')'

        if mode == 'Expert':

            if downloadFlag or useOnlyLocalFlag:
                identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'

            if downloadFlag:
                if not os.path.isfile(identifierEventLevel):
                    basf2_mva.download(methodPrefixEventLevel, identifierEventLevel)
                    if not os.path.isfile(identifierEventLevel):
                        B2FATAL('Flavor Tagger: Weight file ' + identifierEventLevel +
                                ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

            if useOnlyLocalFlag:
                if not os.path.isfile(identifierEventLevel):
                    B2FATAL('Flavor Tagger: ' + particleList + ' Eventlevel was not trained. Weight file ' +
                            identifierEventLevel + ' was not found. Stopped')

            B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

        elif mode == 'Sampler':

            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            if os.path.isfile(identifierEventLevel):
                B2INFO('flavorTagger: MVAExpert ' + methodPrefixEventLevel + ' ready.')

            if 'KaonPion' in categories:
                methodPrefixEventLevelKaonPion = "FlavorTagger_" + getBelleOrBelle2() + \
                    "_" + weightFiles + 'EventLevelKaonPionFBDT'
                identifierEventLevelKaonPion = filesDirectory + '/' + methodPrefixEventLevelKaonPion + '_1.root'
                if not os.path.isfile(identifierEventLevelKaonPion):
                    # Slow Pion and Kaon categories are used if Kaon-Pion is lacking for
                    # sampling. The others are not needed and skipped
                    if category != "SlowPion" and category != "Kaon":
                        continue

        if mode == 'Expert' or (mode == 'Sampler' and os.path.isfile(identifierEventLevel)):

            B2INFO('flavorTagger: Applying MVAExpert ' + methodPrefixEventLevel + '.')

            if category == 'KaonPion':
                identifiersExtraInfosKaonPion.append((extraInfoName, identifierEventLevel))
            else:
                identifiersExtraInfosDict.setdefault(particleList, []).append((extraInfoName, identifierEventLevel))

            ReadyMethods += 1

    # Each particle list has its own Path in order to be skipped if the list is empty
    for particleList, identifiersExtraInfos in identifiersExtraInfosDict.items():
        eventLevelPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_EventLevel_' + particleList)
        SkipEmptyParticleList.param('particleLists', particleList)
        SkipEmptyParticleList.if_true(eventLevelPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaMultipleExperts = register_module('MVAMultipleExperts')
        mvaMultipleExperts.set_name('MVAMultipleExperts_EventLevel_' + particleList)
        mvaMultipleExperts.param('listNames', [particleList])
        mvaMultipleExperts.param('extraInfoNames', [row[0] for row in identifiersExtraInfos])
        mvaMultipleExperts.param('signalFraction', signalFraction)
        mvaMultipleExperts.param('identifiers', [row[1] for row in identifiersExtraInfos])
        eventLevelPath.add_module(mvaMultipleExperts)

    # The KaonPion expert is added after the experts of the other categories
    if 'KaonPion' in categories and identifiersExtraInfosKaonPion:
        eventLevelKaonPionPath = create_path()
        SkipEmptyParticleList = register_module("SkimFilter")
        SkipEmptyParticleList.set_name('SkimFilter_' + 'K+:inRoe')
        SkipEmptyParticleList.param('particleLists', 'K+:inRoe')
        SkipEmptyParticleList.if_true(eventLevelKaonPionPath, basf2.AfterConditionPath.CONTINUE)
        path.add_module(SkipEmptyParticleList)

        mvaExpertKaonPion = register_module("MVAExpert")
        mvaExpertKaonPion.set_name('MVAExpert_KaonPion_' + 'K+:inRoe')
        mvaExpertKaonPion.param('listNames', ['K+:inRoe'])
        mvaExpertKaonPion.param('extraInfoName', identifiersExtraInfosKaonPion[0][0])
        mvaExpertKaonPion.param('signalFraction', signalFraction)
        mvaExpertKaonPion.param('identifier', identifiersExtraInfosKaonPion[0][1])

        eventLevelKaonPionPath.add_module(mvaExpertKaonPion)

    if mode == 'Sampler':

        for category in categories:
            particleList = AvailableCategories[category].particleList

            methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
            identifierEventLevel = filesDirectory + '/' + methodPrefixEventLevel + '_1.root'
            targetVariable = 'isRightCategory(' + category + ')'

            # only categories without a trained weight file are sampled
            if os.path.isfile(identifierEventLevel):
                continue

            if category == 'KaonPion':
                methodPrefixEventLevelSlowPion = "FlavorTagger_" + getBelleOrBelle2() + \
                    "_" + weightFiles + 'EventLevelSlowPionFBDT'
                identifierEventLevelSlowPion = filesDirectory + '/' + methodPrefixEventLevelSlowPion + '_1.root'
                if not os.path.isfile(identifierEventLevelSlowPion):
                    # the sentences are separated by spaces so that the message is readable
                    B2INFO("Flavor Tagger: event level weight file for the Slow Pion category is absent. " +
                           "It is required to sample the training information for the KaonPion category. " +
                           "An additional sampling step will be needed after the following training step.")
                    continue

            sampledFileName = filesDirectory + '/' + methodPrefixEventLevel + "sampled" + fileId + '.root'
            B2INFO('flavorTagger: file ' + sampledFileName + ' will be saved.')

            ma.applyCuts(particleList, 'isRightCategory(mcAssociated) > 0', path)
            eventLevelpath = create_path()
            SkipEmptyParticleList = register_module("SkimFilter")
            SkipEmptyParticleList.set_name('SkimFilter_EventLevel' + category)
            SkipEmptyParticleList.param('particleLists', particleList)
            SkipEmptyParticleList.if_true(eventLevelpath, basf2.AfterConditionPath.CONTINUE)
            path.add_module(SkipEmptyParticleList)

            ntuple = register_module('VariablesToNtuple')
            ntuple.param('fileName', sampledFileName)
            ntuple.param('treeName', methodPrefixEventLevel + "_tree")
            variablesToBeSaved = getTrainingVariables(category) + [targetVariable, 'ancestorHasWhichFlavor',
                                                                   'isSignal', 'mcPDG', 'mcErrors', 'genMotherPDG',
                                                                   'nMCMatches', 'B0mcErrors']
            if category != 'KaonPion' and category != 'FSC':
                variablesToBeSaved = variablesToBeSaved + \
                    ['extraInfo(isRightTrack(' + category + '))',
                     'hasHighestProbInCat(' + particleList + ', isRightTrack(' + category + '))']
            ntuple.param('variables', variablesToBeSaved)
            ntuple.param('particleList', particleList)
            eventLevelpath.add_module(ntuple)

    return ReadyMethods == len(categories)
780 
781 
def eventLevelTeacher(weightFiles='B2JpsiKs_mu', categories=None):
    """
    Trains all categories at event level.

    For every category without a weight file in ``filesDirectory`` the sampled
    ntuples '<prefix>sampled*.root' are collected and a FastBDT is trained on
    them. The new weight file is uploaded if ``uploadFlag`` is set.
    ``filesDirectory`` and ``uploadFlag`` are module-level names set outside this
    function.

    Returns True only if the weight files of all requested categories already
    existed when the function was called. A category trained during this call is
    not counted, so False is returned after a training pass.
    """

    B2INFO('EVENT LEVEL TEACHER')

    ReadyMethods = 0

    if categories is None:
        categories = []

    for category in categories:
        methodPrefixEventLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'EventLevel' + category + 'FBDT'
        targetVariable = 'isRightCategory(' + category + ')'
        weightFile = filesDirectory + '/' + methodPrefixEventLevel + "_1.root"

        # already trained: nothing to do for this category
        if os.path.isfile(weightFile):
            ReadyMethods += 1
            continue

        sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixEventLevel + 'sampled*.root')
        if not sampledFilesList:
            # the message names the file pattern that was actually searched for
            B2INFO('flavorTagger: eventLevelTeacher did not find any ' + methodPrefixEventLevel +
                   'sampled*.root file. Please run the flavorTagger in "Sampler" mode afterwards.')
            continue

        B2INFO('flavorTagger: MVA Teacher training ' + methodPrefixEventLevel + '.')
        trainingOptionsEventLevel = basf2_mva.GeneralOptions()
        trainingOptionsEventLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
        trainingOptionsEventLevel.m_treename = methodPrefixEventLevel + "_tree"
        trainingOptionsEventLevel.m_identifier = weightFile
        trainingOptionsEventLevel.m_variables = basf2_mva.vector(*getTrainingVariables(category))
        trainingOptionsEventLevel.m_target_variable = targetVariable
        trainingOptionsEventLevel.m_max_events = maxEventsNumber

        basf2_mva.teacher(trainingOptionsEventLevel, getFastBDTCategories())

        if uploadFlag:
            basf2_mva.upload(weightFile, methodPrefixEventLevel)

    return ReadyMethods == len(categories)
827 
828 
def _getCombinerWeightFile(methodPrefixCombinerLevel, method, methodDescription):
    """
    Determines the identifier of the combiner weight file of one MVA method and checks that it is usable.

    @param methodPrefixCombinerLevel  Common prefix of the combiner weight files.
    @param method                     Suffix of the method, ``'FBDT'`` or ``'FANN'``.
    @param methodDescription          Name of the method as printed in the error message.
    @return                           The bare identifier ``<prefix><method>`` if neither ``downloadFlag`` nor
                                      ``useOnlyLocalFlag`` is set, otherwise the path of the local weight file.
    """

    identifier = methodPrefixCombinerLevel + method
    if downloadFlag or useOnlyLocalFlag:
        identifier = filesDirectory + '/' + methodPrefixCombinerLevel + method + '_1.root'

    if downloadFlag and not os.path.isfile(identifier):
        basf2_mva.download(methodPrefixCombinerLevel + method, identifier)
        if not os.path.isfile(identifier):
            B2FATAL('Flavor Tagger: Weight file ' + identifier +
                    ' was not downloaded from Database. Please check the buildOrRevision name. Stopped')

    if useOnlyLocalFlag and not os.path.isfile(identifier):
        B2FATAL('flavorTagger: Combinerlevel ' + methodDescription +
                ' was not trained with this combination of categories.' +
                ' Weight file ' + identifier + ' not found. Stopped')

    B2INFO('flavorTagger: Ready to be used with weightFile ' + methodPrefixCombinerLevel + method + '_1.root')

    return identifier


def combinerLevel(mode='Expert', weightFiles='B2JpsiKs_mu', categories=None,
                  variablesCombinerLevel=None, categoriesCombinationCode=None, path=None):
    """
    Samples the input data or tests the combiner according to the selected categories.

    In ``Sampler`` mode a ``VariablesToNtuple`` module saving the combiner input variables and ``qrCombined``
    is added to the path. In ``Expert`` mode the weight files of the requested combiner methods are resolved
    and the corresponding ``MVAExpert`` or ``MVAMultipleExperts`` module is added to the path.

    @param mode                       ``'Sampler'`` or ``'Expert'``. Nothing is added to the path for other values.
    @param weightFiles                Name used to build the identifiers of the weight files.
    @param categories                 List of the selected category names (only printed here).
    @param variablesCombinerLevel     List of the input variables of the combiner.
    @param categoriesCombinationCode  Code string of the combination of categories; it must be given.
    @param path                       Modules are added to this path.
    """

    B2INFO('COMBINER LEVEL')

    if categories is None:
        categories = []
    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    B2INFO("Flavor Tagger: Required Combiner for Categories:")
    for category in categories:
        B2INFO(category)

    B2INFO("Flavor Tagger: which corresponds to a weight file with categories combination code " + categoriesCombinationCode)

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    if mode == 'Sampler':

        # Sampling is refused as soon as a trained combiner exists for this combination of categories.
        if os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root') or \
                os.path.isfile(filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'):
            B2FATAL('flavorTagger: File ' + methodPrefixCombinerLevel + 'FBDT' + "_1.root" + ' or ' + methodPrefixCombinerLevel +
                    'FANN' + '_1.root found. Please run the "Expert" mode or delete the file if a new sampling is desired.')

        sampledFile = filesDirectory + '/' + methodPrefixCombinerLevel + "sampled" + fileId + ".root"

        # Announce the file that is really written (it carries the "sampled" tag and the sampler file id).
        B2INFO('flavorTagger: Sampling Data on Combiner Level. File ' + sampledFile + ' will be saved')

        ntuple = basf2.register_module('VariablesToNtuple')
        ntuple.param('fileName', sampledFile)
        # The tree is tagged with FBDT for both combiner methods.
        ntuple.param('treeName', methodPrefixCombinerLevel + 'FBDT' + "_tree")
        ntuple.param('variables', variablesCombinerLevel + ['qrCombined'])
        ntuple.param('particleList', "")
        path.add_module(ntuple)

    if mode == 'Expert':

        # Check if weight files are ready
        if TMVAfbdt:
            identifierFBDT = _getCombinerWeightFile(methodPrefixCombinerLevel, 'FBDT', 'FastBDT')

        if FANNmlp:
            identifierFANN = _getCombinerWeightFile(methodPrefixCombinerLevel, 'FANN', 'FANNMLP')

        # At this stage, all necessary weight files should be ready.
        # Call MVAExpert or MVAMultipleExperts module.
        if TMVAfbdt and not FANNmlp:
            B2INFO('flavorTagger: Apply FBDTMethod ' + methodPrefixCombinerLevel + 'FBDT')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FBDT', signalFraction=signalFraction,
                            identifier=identifierFBDT)

        if FANNmlp and not TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod on combiner level')
            path.add_module('MVAExpert', listNames=[], extraInfoName='qrCombined' + 'FANN', signalFraction=signalFraction,
                            identifier=identifierFANN)

        if FANNmlp and TMVAfbdt:
            B2INFO('flavorTagger: Apply FANNMethod and FBDTMethod on combiner level')
            mvaMultipleExperts = basf2.register_module('MVAMultipleExperts')
            mvaMultipleExperts.set_name('MVAMultipleExperts_Combiners')
            mvaMultipleExperts.param('listNames', [])
            mvaMultipleExperts.param('extraInfoNames', ['qrCombined' + 'FBDT', 'qrCombined' + 'FANN'])
            mvaMultipleExperts.param('signalFraction', signalFraction)
            mvaMultipleExperts.param('identifiers', [identifierFBDT, identifierFANN])
            path.add_module(mvaMultipleExperts)
929 
930 
def combinerLevelTeacher(weightFiles='B2JpsiKs_mu', variablesCombinerLevel=None,
                         categoriesCombinationCode=None):
    """
    Trains the combiner according to the selected categories.

    The sampled ntuples ``<methodPrefix>sampled*.root`` found in ``filesDirectory`` are used to train
    every requested combiner method (``TMVAfbdt`` and/or ``FANNmlp``) whose weight file does not exist yet.

    @param weightFiles                Name used to build the identifiers of the weight files.
    @param variablesCombinerLevel     List of the input variables of the combiner.
    @param categoriesCombinationCode  Code string of the combination of categories; it must be given.
    """

    B2INFO('COMBINER LEVEL TEACHER')

    if variablesCombinerLevel is None:
        variablesCombinerLevel = []

    methodPrefixCombinerLevel = "FlavorTagger_" + getBelleOrBelle2() + "_" + weightFiles + 'Combiner' \
        + categoriesCombinationCode

    sampledFilesList = glob.glob(filesDirectory + '/' + methodPrefixCombinerLevel + 'sampled*.root')
    if len(sampledFilesList) == 0:
        B2FATAL('FlavorTagger: combinerLevelTeacher did not find any ' +
                methodPrefixCombinerLevel + 'sampled*.root file. Please run the flavorTagger in "Sampler" mode.')

    weightFileFBDT = filesDirectory + '/' + methodPrefixCombinerLevel + 'FBDT' + '_1.root'
    weightFileFANN = filesDirectory + '/' + methodPrefixCombinerLevel + 'FANN' + '_1.root'

    # Take the snapshot before any training: a weight file produced by this very call must not be
    # mistaken for an earlier training when the other method is handled.
    fbdtAlreadyTrained = os.path.isfile(weightFileFBDT)
    fannAlreadyTrained = os.path.isfile(weightFileFANN)

    if TMVAfbdt:

        if not fbdtAlreadyTrained:

            B2INFO('flavorTagger: MVA Teacher training a FastBDT on Combiner Level')

            trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
            trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
            trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
            trainingOptionsCombinerLevel.m_identifier = weightFileFBDT
            trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
            trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
            trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

            basf2_mva.teacher(trainingOptionsCombinerLevel, getFastBDTCombiner())

            if uploadFlag:
                basf2_mva.upload(weightFileFBDT, methodPrefixCombinerLevel + 'FBDT')

        elif FANNmlp and not fannAlreadyTrained:

            B2INFO('flavorTagger: Combinerlevel FBDT was already trained with this combination of categories. Weight file ' +
                   methodPrefixCombinerLevel + 'FBDT' + '_1.root has been found.')

        else:
            B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight files ' +
                    methodPrefixCombinerLevel + 'FBDT' + '_1.root and ' +
                    methodPrefixCombinerLevel + 'FANN' + '_1.root has been found. Please use the "Expert" mode')

    if FANNmlp:

        if not fannAlreadyTrained:

            B2INFO('flavorTagger: MVA Teacher training a FANN MLP on Combiner Level')

            trainingOptionsCombinerLevel = basf2_mva.GeneralOptions()
            trainingOptionsCombinerLevel.m_datafiles = basf2_mva.vector(*sampledFilesList)
            # The tree name carries the FBDT tag for both methods, matching the name set in combinerLevel.
            trainingOptionsCombinerLevel.m_treename = methodPrefixCombinerLevel + 'FBDT' + "_tree"
            trainingOptionsCombinerLevel.m_identifier = weightFileFANN
            trainingOptionsCombinerLevel.m_variables = basf2_mva.vector(*variablesCombinerLevel)
            trainingOptionsCombinerLevel.m_target_variable = 'qrCombined'
            trainingOptionsCombinerLevel.m_max_events = maxEventsNumber

            basf2_mva.teacher(trainingOptionsCombinerLevel, getMlpFANNCombiner())

            if uploadFlag:
                basf2_mva.upload(weightFileFANN, methodPrefixCombinerLevel + 'FANN')

        elif TMVAfbdt and not fbdtAlreadyTrained:

            # Only the FBDT had to be trained in this call; the FANN weight file existed before.
            B2INFO('flavorTagger: Combinerlevel FANN was already trained with this combination of categories. Weight file ' +
                   methodPrefixCombinerLevel + 'FANN' + '_1.root has been found.')

        else:
            B2FATAL('flavorTagger: Combinerlevel was already trained with this combination of categories. Weight files ' +
                    methodPrefixCombinerLevel + 'FBDT' + '_1.root and ' +
                    methodPrefixCombinerLevel + 'FANN' + '_1.root has been found. Please use the "Expert" mode')
1009 
1010 
def getEventLevelParticleLists(categories=None):
    """
    Builds the list of event level tuples ``(particleList, eventName, variableName)`` of the given categories.

    The tuples are taken from ``AvailableCategories`` and kept in the order of ``categories``.
    ``B2FATAL`` is called if a category yields a tuple that is already in the list.

    @param categories  List of category names. ``None`` is treated as an empty list.
    @return            List of tuples, one for each category.
    """

    requested = [] if categories is None else categories
    collected = []

    for category in requested:
        info = AvailableCategories[category]
        entry = (info.particleList, info.eventName, info.variableName)

        if entry in collected:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')
            continue

        collected.append(entry)

    return collected
1028 
1029 
def flavorTagger(
    particleLists=None,
    mode='Expert',
    weightFiles='B2nunubarBGx1',
    workingDirectory='.',
    combinerMethods=['TMVA-FBDT'],
    categories=[
        'Electron',
        'IntermediateElectron',
        'Muon',
        'IntermediateMuon',
        'KinLepton',
        'IntermediateKinLepton',
        'Kaon',
        'SlowPion',
        'FastHadron',
        'Lambda',
        'FSC',
        'MaximumPstar',
        'KaonPion'],
    maskName='FTDefaultMask',
    saveCategoriesInfo=True,
    useOnlyLocalWeightFiles=False,
    downloadFromDatabaseIfNotFound=False,
    uploadToDatabaseAfterTraining=False,
    samplerFileId='',
    prefix='MC15ri_light-2207-bengal_0',
    useGNN=False,
    identifierGNN='GFlaT_MC15ri_light_2303_iriomote_0',
    path=None,
):
    """
      Defines the whole flavor tagging process for each selected Rest of Event (ROE) built in the steering file.
      The flavor is predicted by Multivariate Methods trained with Variables and MetaVariables which use
      Tracks, ECL- and KLMClusters from the corresponding RestOfEvent dataobject.
      This module can be used to sample the training information, to train and/or to test the flavorTagger.

      @param particleLists                     The ROEs for flavor tagging are selected from the given particle lists.
                                               A single particle list name can also be given as a string.
      @param mode                              The available modes are
                                               ``Expert`` (default), ``Sampler``, and ``Teacher``. In the ``Expert`` mode
                                               Flavor Tagging is applied to the analysis. In the ``Sampler`` mode you
                                               save the variables for training. In the ``Teacher`` mode the FlavorTagger is
                                               trained, for this step you do not reconstruct any particle or do any analysis,
                                               you just run the flavorTagger alone.
      @param weightFiles                       Weight files name. Default=
                                               ``B2nunubarBGx1`` (official weight files). If the user self
                                               wants to train the FlavorTagger, the weightfiles name should correspond to the
                                               analysed CP channel in order to avoid confusions. The default name
                                               ``B2nunubarBGx1`` corresponds to
                                               :math:`B^0_{\\rm sig}\\to \\nu \\overline{\\nu}`.
                                               and ``B2JpsiKs_muBGx1`` to
                                               :math:`B^0_{\\rm sig}\\to J/\\psi (\\to \\mu^+ \\mu^-) K_s (\\to \\pi^+ \\pi^-)`.
                                               BGx1 stands for events simulated with background.
      @param workingDirectory                  Path to the directory containing the FlavorTagging/ folder.
      @param combinerMethods                   MVAs for the combiner: ``TMVA-FBDT`` (default).
                                               ``FANN-MLP`` is available only with ``prefix=''`` (MC13 weight files).
      @param categories                        Categories used for flavor tagging. By default all are used.
                                               Duplicated entries are removed keeping the given order.
      @param maskName                          Gets ROE particles from a specified ROE mask.
                                               ``FTDefaultMask`` (default): tentative mask definition that will be created
                                               automatically. The definition is as follows:

                                               - Track (pion): thetaInCDCAcceptance and dr<1 and abs(dz)<3
                                               - ECL-cluster (gamma): thetaInCDCAcceptance and clusterNHits>1.5 and \
                                                 [[clusterReg==1 and E>0.08] or [clusterReg==2 and E>0.03] or \
                                                 [clusterReg==3 and E>0.06]] \
                                                 (Same as gamma:pi0eff30_May2020 and gamma:pi0eff40_May2020)

                                               ``all``: all ROE particles are used.
                                               Or one can give any mask name defined before calling this function.
      @param saveCategoriesInfo                Sets to save information of individual categories.
      @param useOnlyLocalWeightFiles           [Expert] Uses only locally saved weight files.
      @param downloadFromDatabaseIfNotFound    [Expert] Weight files are downloaded from
                                               the conditions database if not available in workingDirectory.
      @param uploadToDatabaseAfterTraining     [Expert] For librarians only: uploads weight files to localdb after training.
      @param samplerFileId                     Identifier to parallelize
                                               sampling. Only used in ``Sampler`` mode.  If you are training by yourself and
                                               want to parallelize the sampling, you can run several sampling scripts in
                                               parallel. By changing this parameter you will not overwrite an older sample.
      @param prefix                            Prefix of weight files.
                                               ``MC15ri_light-2207-bengal_0`` (default): Weight files trained for MC15ri samples.
                                               ``''``: Weight files trained for MC13 samples.
      @param useGNN                            Use GNN-based Flavor Tagger in addition with FastBDT-based one.
                                               Please specify the weight file with the option ``identifierGNN``.
                                               [Expert] In the sampler mode, training files for GNN-based Flavor Tagger is produced.
      @param identifierGNN                     The name of weight file of the GNN-based Flavor Tagger.
                                               [Expert] Multiple identifiers can be given with list(str).
      @param path                              Modules are added to this path

    """

    if not isinstance(particleLists, list):
        particleLists = [particleLists]  # in case user inputs a particle list as string

    if mode != 'Sampler' and mode != 'Teacher' and mode != 'Expert':
        B2FATAL('flavorTagger: Wrong mode given: The available modes are "Sampler", "Teacher" or "Expert"')

    # Remove duplicates keeping the order given by the user, so that the order of the
    # combiner input variables does not depend on the hashing of the category names.
    uniqueCategories = list(dict.fromkeys(categories))
    if len(uniqueCategories) != len(categories):
        dup = [cat for cat in uniqueCategories if categories.count(cat) > 1]
        B2WARNING('Flavor Tagger: There are duplicate elements in the given categories list. '
                  'The following duplicate elements are removed; ' + ', '.join(dup))
        categories = uniqueCategories

    # B2FATAL stops the execution, so the list of valid names has to be part of the same message.
    if len(categories) < 2:
        B2FATAL('Flavor Tagger: Invalid amount of categories. At least two are needed. '
                'Possible categories are  "Electron", "IntermediateElectron", "Muon", "IntermediateMuon", '
                '"KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron",'
                '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')

    for category in categories:
        if category not in AvailableCategories:
            B2FATAL('Flavor Tagger: ' + category + ' is not a valid category name given. '
                    'Available categories are  "Electron", "IntermediateElectron", '
                    '"Muon", "IntermediateMuon", "KinLepton", "IntermediateKinLepton", "Kaon", "SlowPion", "FastHadron", '
                    '"Lambda", "FSC", "MaximumPstar" or "KaonPion" ')

    if mode == 'Expert' and useGNN and identifierGNN == '':
        B2FATAL('Please specify the name of the weight file with ``identifierGNN``')

    # Directory where the weights of the trained Methods are saved
    # workingDirectory = os.environ['BELLE2_LOCAL_DIR'] + '/analysis/data'

    basf2.find_file(workingDirectory)

    global filesDirectory
    filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    # Only the modes that write into the directory create it.
    if mode == 'Sampler' or (mode == 'Expert' and downloadFromDatabaseIfNotFound):
        if not basf2.find_file(workingDirectory + '/FlavorTagging', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging')
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        elif not basf2.find_file(workingDirectory + '/FlavorTagging/TrainedMethods', silent=True):
            os.mkdir(workingDirectory + '/FlavorTagging/TrainedMethods')
        filesDirectory = workingDirectory + '/FlavorTagging/TrainedMethods'

    if len(combinerMethods) < 1 or len(combinerMethods) > 2:
        B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global FANNmlp
    global TMVAfbdt

    FANNmlp = False
    TMVAfbdt = False

    for method in combinerMethods:
        if method == 'TMVA-FBDT':
            TMVAfbdt = True
        elif method == 'FANN-MLP':
            FANNmlp = True
        else:
            B2FATAL('flavorTagger: Invalid list of combinerMethods. The available methods are "TMVA-FBDT" and "FANN-MLP"')

    global fileId
    fileId = samplerFileId

    global useOnlyLocalFlag
    useOnlyLocalFlag = useOnlyLocalWeightFiles

    B2INFO('*** FLAVOR TAGGING ***')
    B2INFO(' ')
    B2INFO('    Working directory is: ' + filesDirectory)
    B2INFO(' ')

    setInteractionWithDatabase(downloadFromDatabaseIfNotFound, uploadToDatabaseAfterTraining)

    if prefix == '':
        set_FlavorTagger_pid_aliases_legacy()
    else:
        set_FlavorTagger_pid_aliases()

    alias_list_for_GNN = []
    if useGNN:
        alias_list_for_GNN = set_GNNFlavorTagger_aliases(categories)

    setInputVariablesWithMask()
    if prefix != '':
        weightFiles = prefix + '_' + weightFiles

    # Create configuration lists and code-name for given category's list
    trackLevelParticleLists = []
    eventLevelParticleLists = []
    variablesCombinerLevel = []
    categoriesCombination = []
    categoriesCombinationCode = 'CatCode'
    for category in categories:
        ftCategory = AvailableCategories[category]

        track_tuple = (ftCategory.particleList, ftCategory.trackName)
        event_tuple = (ftCategory.particleList, ftCategory.eventName, ftCategory.variableName)

        if track_tuple not in trackLevelParticleLists and category != 'MaximumPstar':
            trackLevelParticleLists.append(track_tuple)

        if event_tuple not in eventLevelParticleLists:
            eventLevelParticleLists.append(event_tuple)
            variablesCombinerLevel.append(ftCategory.variableName)
            categoriesCombination.append(ftCategory.code)
        else:
            B2FATAL('Flavor Tagger: ' + category + ' has been already given')

    # The codes are sorted, so the combination code does not depend on the order of the categories.
    for code in sorted(categoriesCombination):
        categoriesCombinationCode = categoriesCombinationCode + '%02d' % code

    # Create default ROE-mask
    if maskName == 'FTDefaultMask':
        FTDefaultMask = (
            'FTDefaultMask',
            'thetaInCDCAcceptance and dr<1 and abs(dz)<3',
            'thetaInCDCAcceptance and clusterNHits>1.5 and [[E>0.08 and clusterReg==1] or [E>0.03 and clusterReg==2] or \
                            [E>0.06 and clusterReg==3]]')
        for name in particleLists:
            ma.appendROEMasks(list_name=name, mask_tuples=[FTDefaultMask], path=path)

    # Start ROE-routine
    roe_path = basf2.create_path()
    deadEndPath = basf2.create_path()

    if mode == 'Sampler':
        # Events containing ROE without B-Meson (but not empty) are discarded for training
        ma.signalSideParticleListsFilter(
            particleLists,
            'nROE_Charged(' + maskName + ', 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            deadEndPath)

        FillParticleLists(maskName, categories, roe_path)

        if useGNN:
            # For the GNN sampling the event level is run in Expert mode.
            if eventLevel('Expert', weightFiles, categories, roe_path):

                ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                                 outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
                ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)
                ma.variablesToNtuple('vpho:dummy',
                                     alias_list_for_GNN,
                                     treename='tree',
                                     filename=f'{filesDirectory}/FlavorTagger_GNN_sampled{fileId}.root',
                                     signalSideParticleList=particleLists[0],
                                     path=roe_path)

        else:
            if eventLevel(mode, weightFiles, categories, roe_path):
                combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # If trigger returns 1 jump into empty path skipping further modules in roe_path
        # run filter with no cut first to get rid of ROEs that are missing the mask of the signal particle
        ma.signalSideParticleListsFilter(particleLists, 'nROE_Charged(' + maskName + ', 0) > 0', roe_path, deadEndPath)

        # Initialization of flavorTaggerInfo dataObject needs to be done in the main path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        FillParticleLists(maskName, categories, roe_path)

        if eventLevel(mode, weightFiles, categories, roe_path):
            combinerLevel(mode, weightFiles, categories, variablesCombinerLevel, categoriesCombinationCode, roe_path)

            flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
            flavorTaggerInfoFiller.param('trackLevelParticleLists', trackLevelParticleLists)
            flavorTaggerInfoFiller.param('eventLevelParticleLists', eventLevelParticleLists)
            flavorTaggerInfoFiller.param('TMVAfbdt', TMVAfbdt)
            flavorTaggerInfoFiller.param('FANNmlp', FANNmlp)
            flavorTaggerInfoFiller.param('qpCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('istrueCategories', saveCategoriesInfo)
            flavorTaggerInfoFiller.param('targetProb', False)
            flavorTaggerInfoFiller.param('trackPointers', False)
            roe_path.add_module(flavorTaggerInfoFiller)  # Add FlavorTag Info filler to roe_path
            add_default_FlavorTagger_aliases()

        if useGNN:
            ma.rankByHighest('pi+:inRoe', 'p', numBest=0, allowMultiRank=False,
                             outputVariable='FT_p_rank', overwriteRank=True, path=roe_path)
            ma.fillParticleListFromDummy('vpho:dummy', path=roe_path)

            if isinstance(identifierGNN, str):
                roe_path.add_module('MVAExpert',
                                    listNames='vpho:dummy',
                                    extraInfoName='qrGNN_raw',  # the range of qrGNN_raw is [0,1]
                                    identifier=identifierGNN)

                ma.variableToSignalSideExtraInfo('vpho:dummy', {'extraInfo(qrGNN_raw)*2-1': 'qrGNN'},
                                                 path=roe_path)

            elif isinstance(identifierGNN, list):
                identifierGNN = list(set(identifierGNN))

                # One extraInfo per identifier; names and identifiers are kept in the same order.
                extraInfoNames = [f'qrGNN_{i_id}' for i_id in identifierGNN]
                roe_path.add_module('MVAMultipleExperts',
                                    listNames='vpho:dummy',
                                    extraInfoNames=extraInfoNames,
                                    identifiers=identifierGNN)

                extraInfoDict = {}
                for extraInfoName in extraInfoNames:
                    extraInfoDict[f'extraInfo({extraInfoName})*2-1'] = extraInfoName
                    variables.variables.addAlias(extraInfoName, f'extraInfo({extraInfoName})')

                ma.variableToSignalSideExtraInfo('vpho:dummy', extraInfoDict,
                                                 path=roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Teacher':
        # The combiner is only trained once all event level weight files are available.
        if eventLevelTeacher(weightFiles, categories):
            combinerLevelTeacher(weightFiles, variablesCombinerLevel, categoriesCombinationCode)
1339 
if __name__ == '__main__':

    # When the script is executed directly, print the signature and the documentation of flavorTagger.
    desc_list = []

    function = globals()["flavorTagger"]
    # inspect.formatargspec was removed in Python 3.11; inspect.signature gives the same "(arg=default, ...)" text.
    signature = str(inspect.signature(function))
    desc_list.append((function.__name__, signature + '\n' + function.__doc__))

    from terminal_utils import Pager
    from basf2.utils import pretty_print_description_list
    with Pager('Flavor Tagger function accepts the following arguments:'):
        pretty_print_description_list(desc_list)
# Cross-reference: isB2BII() is defined in b2bii.py (line 14).