Belle II Software light-2601-hyperion
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11import os
12from basf2 import B2FATAL
13import basf2
14from variables import variables as vm
15import modularAnalysis as ma
16from stdPhotons import stdPhotons
17from vertex import kFit
18import tflat.utils as utils
19
20
def fill_particle_lists(config, maskName='TFLATDefaultMask', path=None):
    """
    Fills the particle lists used as TFlaT inputs.

    Creates:

    * ``pi+:tflat``  -- tracks passing ``config['trk_cut']``
    * ``gamma:all``  -- all photons; the beam-background and fake-photon MVAs are run on this list
    * ``gamma:tflat`` -- photons of the standard ``tight`` list passing ``config['gamma_cut']``
    * ``K_S0:inRoe`` -- ``pi+:tflat pi-:tflat`` combinations with ``0.40 <= M <= 0.60``,
      followed by a ``kFit`` vertex fit

    :param config: dict, TFlaT configuration; the keys ``trk_cut``, ``gamma_cut``,
        ``VersionBeamBackgroundMVA`` and ``VersionFakePhotonMVA`` are read here
    :param maskName: string, ROE mask name. It is NOT used inside this function and is kept
        for interface compatibility. NOTE(review): presumably any mask selection is encoded
        in ``config['trk_cut']`` / ``config['gamma_cut']`` -- confirm.
    :param path: basf2 path obj the modules are added to
    :return: None
    """

    # create particle list with pions
    trk_cut = config['trk_cut']
    ma.fillParticleList('pi+:tflat', trk_cut, path=path)

    # create particle list with gammas

    # load MVA's for all gamma
    ma.fillParticleList(
        "gamma:all",
        "",
        path=path,
    )
    ma.getBeamBackgroundProbability("gamma:all", config['VersionBeamBackgroundMVA'], path=path)
    ma.getFakePhotonProbability("gamma:all", config['VersionFakePhotonMVA'], path=path)

    stdPhotons(listtype='tight', path=path)

    gamma_cut = config['gamma_cut']
    ma.cutAndCopyList('gamma:tflat', 'gamma:tight', gamma_cut, path=path)

    # The third positional parameter of reconstructDecay is dmID, not writeOut; pass writeOut
    # by keyword so the intent is explicit (dmID keeps its default of 0, as before).
    ma.reconstructDecay('K_S0:inRoe -> pi+:tflat pi-:tflat', '0.40<=M<=0.60', writeOut=False, path=path)
    # second argument of kFit is the confidence-level threshold of the fit
    kFit('K_S0:inRoe', 0.01, path=path)
48
49
def _prepare_roe_particles(config, maskName, rank_variable, roe_path):
    """
    Fill, rank and alias the TFlaT input particles inside the per-ROE path.

    Shared by the ``Sampler`` and ``Expert`` modes so that training and inference
    are set up with exactly the same input lists and aliases.

    :param config: dict, TFlaT configuration loaded via ``utils.load_config``
    :param maskName: string, ROE mask name (forwarded to ``fill_particle_lists``)
    :param rank_variable: string, variable by which tracks and photons are ranked
    :param roe_path: basf2 path executed for each RestOfEvent
    :return: None
    """
    fill_particle_lists(config, maskName, roe_path)

    ma.rankByHighest('pi+:tflat', rank_variable, path=roe_path)
    ma.rankByHighest('gamma:tflat', rank_variable, path=roe_path)

    # d{x,y,z} of each track relative to the highest-ranked track (rank 1 by rank_variable).
    # Added in the order refdx, dxdiff, refdy, dydiff, refdz, dzdiff.
    for axis in ('x', 'y', 'z'):
        vm.addAlias(f'refd{axis}', f'getVariableByRank(pi+:tflat, {rank_variable}, d{axis}, 1)')
        vm.addAlias(f'd{axis}diff', f'formula(d{axis}-refd{axis})')


def flavorTagger(particleLists, mode='Expert', working_dir='', uniqueIdentifier='TFlaT_MC16rd_light_2601_hyperion',
                 target='qrCombined', overwrite=False,
                 sampler_id=0,
                 path=None):
    """
    Interface to the Transformer Flavor Tagger (TFlaT). This function can be used to prepare
    training datasets (``Sampler``) and to run inference (``Expert``).

    It requires a reconstructed signal B meson particle list for which a RestOfEvent has been built.

    :param particleLists: string or list[string], particle list(s) of the reconstructed signal B meson
    :param mode: string, valid modes are ``Expert`` (default), ``Sampler``
    :param working_dir: string, directory in which the ``Sampler`` output file is written
    :param uniqueIdentifier: string, database identifier for the method (also used to load the config)
    :param target: string, target variable appended to the ``Sampler`` ntuple
    :param overwrite: bool, overwrite an already (locally!) existing ``Sampler`` output file
    :param sampler_id: int, identifier of the sampled file for parallel sampling
    :param path: basf2 path obj (required)
    :return: None
    """

    if isinstance(particleLists, str):
        particleLists = [particleLists]
    # normalise any other iterable (e.g. a tuple) to a list for the module parameters below
    particleLists = list(particleLists)

    if not particleLists:
        B2FATAL('flavorTagger: no signal particle list given.')

    if mode not in ['Expert', 'Sampler']:
        B2FATAL(f'Invalid mode {mode}')

    if path is None:
        B2FATAL('flavorTagger: a basf2 path must be provided via the "path" argument.')

    tree_name = 'tflat_variables'
    rank_variable = 'p'

    config = utils.load_config(uniqueIdentifier)

    # create default ROE-mask; the first entry of the mask tuple is its name
    TFLAT_mask = config['TFLAT_Mask']
    maskName = TFLAT_mask[0]
    for name in particleLists:
        ma.appendROEMasks(list_name=name, mask_tuples=[TFLAT_mask], path=path)

    # create roe specific paths
    roe_path = basf2.create_path()
    dead_end_path = basf2.create_path()

    if mode == 'Sampler':
        trk_variable_list = config['trk_variable_list']
        ecl_variable_list = config['ecl_variable_list']
        roe_variable_list = config['roe_variable_list']
        parameters = config['parameters']
        # create tagging specific variables
        features = utils.get_variables('pi+:tflat', rank_variable, trk_variable_list,
                                       particleNumber=parameters['num_trk'])
        features += utils.get_variables('gamma:tflat', rank_variable, ecl_variable_list,
                                        particleNumber=parameters['num_ecl'])
        features += utils.get_variables('pi+:tflat', rank_variable, roe_variable_list,
                                        particleNumber=parameters['num_roe'])

        output_file_name = os.path.join(working_dir, uniqueIdentifier + f'_training_data{sampler_id}.root')
        if os.path.isfile(output_file_name) and not overwrite:
            B2FATAL(f'Outputfile {output_file_name} already exists. Aborting writeout.')

        # filter rest of events only for specific particle list; for training, additionally
        # require |qrCombined| == 1 (NOTE: independent of `target`)
        ma.signalSideParticleListsFilter(
            particleLists,
            f'nROE_Charged({maskName}, 0) > 0 and abs(qrCombined) == 1',
            roe_path,
            dead_end_path)

        _prepare_roe_particles(config, maskName, rank_variable, roe_path)

        # and add target, then write to ntuples
        all_variables = features + [target]
        ma.variablesToNtuple('', all_variables, tree_name, output_file_name, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':

        # filter rest of events only for specific particle list
        ma.signalSideParticleListsFilter(
            particleLists,
            f'nROE_Charged({maskName}, 0) > 0',
            roe_path,
            dead_end_path)

        path.add_module('FlavorTaggerInfoBuilder')

        _prepare_roe_particles(config, maskName, rank_variable, roe_path)

        expert_module = basf2.register_module('MVAExpert')
        expert_module.param('listNames', particleLists)
        expert_module.param('identifier', uniqueIdentifier)
        expert_module.param('extraInfoName', 'tflat_output')
        roe_path.add_module(expert_module)

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('TFLATnn', True)
        roe_path.add_module(flavorTaggerInfoFiller)

        # Create standard alias for the output of the flavor tagger
        vm.addAlias('qrTFLAT', 'qrOutput(TFLAT)')

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)