Belle II Software development
flavorTagger.py
1#!/usr/bin/env python3
2
3
10
11import os
12from basf2 import B2FATAL
13import basf2
14from variables import variables as vm
15import modularAnalysis as ma
16from stdPhotons import stdPhotons
17from vertex import kFit
18import tflat.utils as utils
19import b2bii
20
21
def _fill_belle_lists(config, path):
    """Add modules building ``K_S0:inRoe`` and ``gamma:tflat`` from the Belle (B2BII) mdst lists."""
    # Keep only mdst K_S0 candidates whose two daughters both belong to the rest of event.
    ma.cutAndCopyList(
        'K_S0:inRoe',
        'K_S0:mdst',
        'daughter(0, isInRestOfEvent) > 0.5 and daughter(1, isInRestOfEvent) > 0.5',
        path=path)

    # Attach the beam-background and fake-photon MVA outputs to gamma:mdst ...
    ma.getBeamBackgroundProbability(particleList=['gamma:mdst'], weight=config['VersionBeamBackgroundMVA'], path=path)
    ma.getFakePhotonProbability(particleList=['gamma:mdst'], weight=config['VersionFakePhotonMVA'], path=path)

    # ... and only then select the TFlaT photons from it.
    ma.cutAndCopyList('gamma:tflat', 'gamma:mdst', config['gamma_cut'], path=path)


def _fill_belle2_lists(config, path):
    """Add modules building ``K_S0:inRoe`` and ``gamma:tflat`` for Belle II data."""
    # K_S0 candidates from opposite-charge pairs of the TFlaT pion list, then a kFit
    # (second argument 0.01 -- presumably the confidence-level cut; confirm against vertex.kFit).
    ma.reconstructDecay('K_S0:inRoe -> pi+:tflat pi-:tflat', '0.40<=M<=0.60', False, path=path)
    kFit('K_S0:inRoe', 0.01, path=path)

    # Photon MVAs are evaluated on the full (uncut) photon list.
    ma.fillParticleList("gamma:all", "", path=path)
    ma.getBeamBackgroundProbability("gamma:all", config['VersionBeamBackgroundMVA'], path=path)
    ma.getFakePhotonProbability("gamma:all", config['VersionFakePhotonMVA'], path=path)

    # TFlaT photons are a further selection on top of the standard 'tight' photon list.
    stdPhotons(listtype='tight', path=path)
    ma.cutAndCopyList('gamma:tflat', 'gamma:tight', config['gamma_cut'], path=path)


def fill_particle_lists(config, maskName='TFLATDefaultMask', path=None):
    """
    Add the modules that build the TFlaT input particle lists to ``path``.

    Lists produced:

    * ``pi+:tflat``   -- charged particles passing ``config['trk_cut']``;
    * ``K_S0:inRoe``  -- for Belle (B2BII) data, copied from ``K_S0:mdst`` with both
      daughters in the rest of event; otherwise reconstructed from ``pi+:tflat``
      in the window 0.40 <= M <= 0.60 and fitted with ``kFit``;
    * ``gamma:tflat`` -- photons passing ``config['gamma_cut']``, selected after the
      beam-background and fake-photon MVAs have been attached.

    :param config: TFlaT configuration mapping; the keys ``trk_cut``, ``gamma_cut``,
        ``VersionBeamBackgroundMVA`` and ``VersionFakePhotonMVA`` are read here
    :param maskName: ROE mask name. NOTE(review): not referenced in this function
        body; kept for interface compatibility
    :param path: basf2 path the modules are appended to
    :return: None
    """
    # The charged-track list is needed by both data sources (Belle II builds K_S0 from it).
    ma.fillParticleList('pi+:tflat', config['trk_cut'], path=path)

    if not b2bii.isB2BII():
        _fill_belle2_lists(config, path)
    else:
        _fill_belle_lists(config, path)
60
61
def _setup_roe_inputs(config, particleLists, maskName, filter_cut, rank_variable, roe_path, dead_end_path):
    """
    Add the ROE-path modules shared by the ``Sampler`` and ``Expert`` modes.

    ROEs whose signal side fails ``filter_cut`` are sent to ``dead_end_path``. The TFlaT
    particle lists are then filled, ranked by ``rank_variable``, and the aliases
    ``refdx/refdy/refdz`` (value of the rank-1 ``pi+:tflat`` candidate) and
    ``dxdiff/dydiff/dzdiff`` (difference to that reference) are registered.
    """
    ma.signalSideParticleListsFilter(particleLists, filter_cut, roe_path, dead_end_path)

    fill_particle_lists(config, maskName, roe_path)

    for list_name in ('pi+:tflat', 'gamma:tflat'):
        ma.rankByHighest(list_name, rank_variable, path=roe_path)

    # Same alias order as before: refdx, dxdiff, refdy, dydiff, refdz, dzdiff.
    for coord in ('dx', 'dy', 'dz'):
        vm.addAlias(f'ref{coord}', f'getVariableByRank(pi+:tflat, {rank_variable}, {coord}, 1)')
        vm.addAlias(f'{coord}diff', f'formula({coord}-ref{coord})')


def flavorTagger(particleLists, mode='Expert', working_dir='', uniqueIdentifier='TFlaT_MC16rd_light_2601_hyperion',
                 target='qrCombined', overwrite=False,
                 sampler_id=0,
                 path=None):
    """
    Interface to the Transformer Flavor Tagger (TFlaT).

    Depending on ``mode``, this function either prepares training data (``Sampler``)
    or runs inference (``Expert``). It must be called on reconstructed signal B meson
    particle list(s) for which a RestOfEvent has already been built.

    In both modes the ROE mask ``config['TFLAT_Mask']`` is appended to every signal
    list, and a dedicated ROE path is executed via
    ``path.for_each('RestOfEvent', 'RestOfEvents', ...)``.

    * ``Sampler``: writes the TFlaT input features plus ``target`` to the tree
      ``tflat_variables`` in ``<working_dir>/<uniqueIdentifier>_training_data<sampler_id>.root``.
    * ``Expert``: stores the MVA output as extra info ``tflat_output`` on the signal
      lists, runs ``FlavorTaggerInfoFiller`` with ``TFLATnn=True`` and defines the
      alias ``qrTFLAT`` = ``qrOutput(TFLAT)``.

    :param particleLists: string or list[string], particle list(s) of the reconstructed signal B meson
    :param mode: string, valid modes are ``Expert`` (default), ``Sampler``
    :param working_dir: string, working directory for the output of ``Sampler`` mode
    :param uniqueIdentifier: string, database identifier for the method (also used to load its config)
    :param target: string, target variable written alongside the features in ``Sampler`` mode
    :param overwrite: bool, overwrite an already (locally!) existing training data file
    :param sampler_id: identifier of the sampled file, for parallel sampling
    :param path: basf2 path obj (required)
    :return: None

    ``B2FATAL`` is issued if ``mode`` is invalid, if ``path`` is None, or (``Sampler``)
    if the output file already exists and ``overwrite`` is False.
    """

    if isinstance(particleLists, str):
        particleLists = [particleLists]

    if mode not in ['Expert', 'Sampler']:
        B2FATAL(f'Invalid mode {mode}')

    # Fail with a clear message instead of an AttributeError deep inside appendROEMasks.
    if path is None:
        B2FATAL('flavorTagger requires a basf2 path, but path=None was given.')

    tree_name = 'tflat_variables'
    rank_variable = 'p'

    config = utils.load_config(uniqueIdentifier)

    # create default ROE-mask; its name is the first element of the mask tuple
    TFLAT_mask = config['TFLAT_Mask']
    maskName = TFLAT_mask[0]
    for name in particleLists:
        ma.appendROEMasks(list_name=name, mask_tuples=[TFLAT_mask], path=path)

    # create roe specific paths; ROEs failing the signal-side filter jump to the empty dead-end path
    roe_path = basf2.create_path()
    dead_end_path = basf2.create_path()

    if mode == 'Sampler':
        parameters = config['parameters']
        # create tagging specific variables
        features = utils.get_variables('pi+:tflat', rank_variable, config['trk_variable_list'],
                                       particleNumber=parameters['num_trk'])
        features += utils.get_variables('gamma:tflat', rank_variable, config['ecl_variable_list'],
                                        particleNumber=parameters['num_ecl'])
        features += utils.get_variables('pi+:tflat', rank_variable, config['roe_variable_list'],
                                        particleNumber=parameters['num_roe'])

        output_file_name = os.path.join(working_dir, uniqueIdentifier + f'_training_data{sampler_id}.root')
        if os.path.isfile(output_file_name) and not overwrite:
            B2FATAL(f'Outputfile {output_file_name} already exists. Aborting writeout.')

        # Only ROEs with at least one masked charged track and a well-defined flavour are sampled.
        # NOTE(review): this cut always uses qrCombined, independent of ``target``.
        _setup_roe_inputs(config, particleLists, maskName,
                          f'nROE_Charged({maskName}, 0) > 0 and abs(qrCombined) == 1',
                          rank_variable, roe_path, dead_end_path)

        # write features and target to ntuples
        ma.variablesToNtuple('', features + [target], tree_name, output_file_name, roe_path)

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)

    elif mode == 'Expert':
        # The FlavorTaggerInfo object is initialised on the main path (only once).
        if 'FlavorTaggerInfoBuilder' not in path:
            path.add_module('FlavorTaggerInfoBuilder')

        _setup_roe_inputs(config, particleLists, maskName,
                          f'nROE_Charged({maskName}, 0) > 0',
                          rank_variable, roe_path, dead_end_path)

        expert_module = basf2.register_module('MVAExpert')
        expert_module.param('listNames', particleLists)
        expert_module.param('identifier', uniqueIdentifier)
        expert_module.param('extraInfoName', 'tflat_output')
        roe_path.add_module(expert_module)

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('TFLATnn', True)
        roe_path.add_module(flavorTaggerInfoFiller)

        # Create standard alias for the output of the flavor tagger
        vm.addAlias('qrTFLAT', 'qrOutput(TFLAT)')

        path.for_each('RestOfEvent', 'RestOfEvents', roe_path)
isB2BII()
Definition b2bii.py:14