12from basf2
import B2FATAL
14from variables
import variables
as vm
15import modularAnalysis
as ma
16from stdPhotons
import stdPhotons
17from vertex
import kFit
22def fill_particle_lists(config, maskName='TFLATDefaultMask', path=None):
24 Fills the particle lists.
# Purpose: append to `path` the modules that build the particle lists used by
# TFlaT: 'pi+:tflat' (tracks), 'gamma:tflat' (photons) and 'K_S0:inRoe'.
# All selections and MVA weight identifiers are read from `config`
# (keys used here: 'trk_cut', 'gamma_cut', 'VersionBeamBackgroundMVA',
# 'VersionFakePhotonMVA').
#
# NOTE(review): this view of the function is incomplete -- the embedded
# original line numbers jump (29 -> 36 -> 40, 49 -> 52, 54 -> 58), so some
# statements (docstring delimiters, call openers, possibly if/else headers)
# are not visible. `maskName` is not used in any visible line.
#
# Track list: fill 'pi+:tflat' with the track selection taken from the config.
28 trk_cut = config[
'trk_cut']
29 ma.fillParticleList(
'pi+:tflat', trk_cut, path=path)
# NOTE(review): the next line is a cut-string argument whose enclosing call is
# not visible here. It requires both daughters to be in the RestOfEvent;
# which list it is applied to cannot be told from this view -- confirm.
36 'daughter(0, isInRestOfEvent) > 0.5 and daughter(1, isInRestOfEvent) > 0.5',
# Photon sequence 1: attach beam-background and fake-photon MVA outputs to
# 'gamma:mdst', using the weight identifiers from the config.
40 ma.getBeamBackgroundProbability(particleList=[
'gamma:mdst'], weight=config[
'VersionBeamBackgroundMVA'], path=path)
41 ma.getFakePhotonProbability(particleList=[
'gamma:mdst'], weight=config[
'VersionFakePhotonMVA'], path=path)
# Copy the photons passing config['gamma_cut'] from 'gamma:mdst' into
# 'gamma:tflat'.
44 gamma_cut = config[
'gamma_cut']
45 ma.cutAndCopyList(
'gamma:tflat',
'gamma:mdst', gamma_cut, path=path)
# Build K_S0 candidates from pairs of 'pi+:tflat' tracks in the mass window
# 0.40 <= M <= 0.60, then run kFit on them.
# NOTE(review): 0.01 is presumably the confidence-level cut of the vertex
# fit -- confirm against vertex.kFit.
48 ma.reconstructDecay(
'K_S0:inRoe -> pi+:tflat pi-:tflat',
'0.40<=M<=0.60',
False, path=path)
49 kFit(
'K_S0:inRoe', 0.01, path=path)
# Photon sequence 2: fill 'gamma:all' without any cut and attach the same two
# MVA outputs to it.
52 ma.fillParticleList(
"gamma:all",
"", path=path)
53 ma.getBeamBackgroundProbability(
"gamma:all", config[
'VersionBeamBackgroundMVA'], path=path)
54 ma.getFakePhotonProbability(
"gamma:all", config[
'VersionFakePhotonMVA'], path=path)
# Copy the photons passing config['gamma_cut'] from 'gamma:tight' into
# 'gamma:tflat'.
# NOTE(review): 'gamma:tight' is not created in any visible line; presumably
# a stdPhotons call sits in the missing lines 55-57 -- confirm.
# NOTE(review): 'gamma:tflat' is produced twice (from 'gamma:mdst' above and
# from 'gamma:tight' here); presumably the two photon sequences live in
# alternative branches whose headers are not visible -- confirm.
58 gamma_cut = config[
'gamma_cut']
59 ma.cutAndCopyList(
'gamma:tflat',
'gamma:tight', gamma_cut, path=path)
62def flavorTagger(particleLists, mode='Expert', working_dir='', uniqueIdentifier='TFlaT_MC16rd_light_2601_hyperion',
63 target='qrCombined', overwrite=False,
67 Interfacing for the Transformer FlavorTagger (TFlat). This function can be used for preparation of
68 training datasets (``Sampler``) and inference (``Expert``).
70 This function requires reconstructed B meson signal particle list and where an RestOfEvent is built.
72 :param particleLists: string or list[string], particle list(s) of the reconstructed signal B meson
73 :param mode: string, valid modes are ``Expert`` (default), ``Sampler``
74 :param working_dir: string, working directory for the method
75 :param uniqueIdentifier: string, database identifier for the method
76 :param target: string, target variable
77 :param overwrite: bool, overwrite already (locally!) existing training
78 :param sampler_id: identifier of sampled file for parallel sampling
79 :param path: basf2 path obj
# NOTE(review): the end of the signature and the docstring delimiters are not
# visible in this view. `sampler_id` and `path` are documented above and used
# below, so they are presumably declared on the missing signature line(s);
# their defaults cannot be determined from here.
# NOTE(review): `utils`, `os`, `basf2` and `rank_variable` are used below but
# are not imported or defined in any visible line -- confirm they are provided
# by lines missing from this view.
#
# Accept a single list name as well as a list of names.
83 if isinstance(particleLists, str):
84 particleLists = [particleLists]
# Only the two modes handled below are valid; anything else is fatal.
86 if mode
not in [
'Expert',
'Sampler']:
87 B2FATAL(f
'Invalid mode {mode}')
# Name of the tree written in Sampler mode.
89 tree_name =
'tflat_variables'
# Load the configuration that belongs to the given identifier.
92 config = utils.load_config(uniqueIdentifier)
# The config provides one ROE mask tuple; its first element is used as the
# mask name, and the tuple is appended as ROE mask for every signal list.
95 TFLAT_mask = config[
'TFLAT_Mask']
96 maskName = TFLAT_mask[0]
97 for name
in particleLists:
98 ma.appendROEMasks(list_name=name, mask_tuples=[TFLAT_mask], path=path)
# roe_path collects everything that is executed per RestOfEvent (see the
# path.for_each calls below). dead_end_path is created here; its use is not
# visible (presumably passed to signalSideParticleListsFilter -- confirm).
101 roe_path = basf2.create_path()
102 dead_end_path = basf2.create_path()
# ---- Sampler mode: write the training variables to an ntuple ----
104 if mode ==
'Sampler':
105 trk_variable_list = config[
'trk_variable_list']
106 ecl_variable_list = config[
'ecl_variable_list']
107 roe_variable_list = config[
'roe_variable_list']
# Build the feature names for tracks, ECL photons and ROE variables; the
# number of particles per group comes from config['parameters'].
# NOTE(review): the ROE group also uses the list name 'pi+:tflat' --
# confirm this is intended.
109 features = utils.get_variables(
'pi+:tflat', rank_variable, trk_variable_list,
110 particleNumber=config[
'parameters'][
'num_trk'])
111 features += utils.get_variables(
'gamma:tflat', rank_variable, ecl_variable_list,
112 particleNumber=config[
'parameters'][
'num_ecl'])
113 features += utils.get_variables(
'pi+:tflat', rank_variable, roe_variable_list,
114 particleNumber=config[
'parameters'][
'num_roe'])
# Output file: <working_dir>/<uniqueIdentifier>_training_data<sampler_id>.root
# An existing file is never replaced unless `overwrite` is set.
116 output_file_name = os.path.join(working_dir, uniqueIdentifier + f
'_training_data{sampler_id}.root')
117 if os.path.isfile(output_file_name)
and not overwrite:
118 B2FATAL(f
'Outputfile {output_file_name} already exists. Aborting writeout.')
# Signal-side filter. Only the cut string is visible: at least one charged
# ROE particle under the mask and abs(qrCombined) == 1. The remaining
# arguments of the call are missing from this view.
121 ma.signalSideParticleListsFilter(
123 f
'nROE_Charged({maskName}, 0) > 0 and abs(qrCombined) == 1',
# Build the tflat particle lists inside the ROE path.
127 fill_particle_lists(config, maskName, roe_path)
# Rank tracks and photons by `rank_variable` (highest first).
129 ma.rankByHighest(
'pi+:tflat', rank_variable, path=roe_path)
130 ma.rankByHighest(
'gamma:tflat', rank_variable, path=roe_path)
# Aliases: refd{x,y,z} is d{x,y,z} of the rank-1 'pi+:tflat' candidate
# (ranked by p); d{x,y,z}diff is the difference of a candidate's value to
# that reference.
132 vm.addAlias(
'refdx',
'getVariableByRank(pi+:tflat, p, dx, 1)')
133 vm.addAlias(
'dxdiff',
'formula(dx-refdx)')
134 vm.addAlias(
'refdy',
'getVariableByRank(pi+:tflat, p, dy, 1)')
135 vm.addAlias(
'dydiff',
'formula(dy-refdy)')
136 vm.addAlias(
'refdz',
'getVariableByRank(pi+:tflat, p, dz, 1)')
137 vm.addAlias(
'dzdiff',
'formula(dz-refdz)')
# The ntuple contains all features plus the training target.
140 all_variables = features + [target]
# Empty decay string as first argument; the tree is written from roe_path.
143 ma.variablesToNtuple(
'', all_variables, tree_name, output_file_name, roe_path)
# Run roe_path once per RestOfEvent object of the event.
145 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)
# ---- Expert mode: evaluate the trained model and fill FlavorTaggerInfo ----
147 elif mode ==
'Expert':
# Signal-side filter. Only the cut string is visible: at least one charged
# ROE particle under the mask (no requirement on qrCombined here). The
# remaining arguments of the call are missing from this view.
150 ma.signalSideParticleListsFilter(
152 f
'nROE_Charged({maskName}, 0) > 0',
# Add the FlavorTaggerInfoBuilder to the main path only if it is not
# already part of it.
156 if 'FlavorTaggerInfoBuilder' not in path:
157 path.add_module(
'FlavorTaggerInfoBuilder')
# Same list building, ranking and aliases as in the Sampler branch.
159 fill_particle_lists(config, maskName, roe_path)
161 ma.rankByHighest(
'pi+:tflat', rank_variable, path=roe_path)
162 ma.rankByHighest(
'gamma:tflat', rank_variable, path=roe_path)
164 vm.addAlias(
'refdx',
'getVariableByRank(pi+:tflat, p, dx, 1)')
165 vm.addAlias(
'dxdiff',
'formula(dx-refdx)')
166 vm.addAlias(
'refdy',
'getVariableByRank(pi+:tflat, p, dy, 1)')
167 vm.addAlias(
'dydiff',
'formula(dy-refdy)')
168 vm.addAlias(
'refdz',
'getVariableByRank(pi+:tflat, p, dz, 1)')
169 vm.addAlias(
'dzdiff',
'formula(dz-refdz)')
# MVAExpert is configured with the signal lists and the identifier; its
# output is stored under the extraInfo name 'tflat_output'.
171 expert_module = basf2.register_module(
'MVAExpert')
172 expert_module.param(
'listNames', particleLists)
173 expert_module.param(
'identifier', uniqueIdentifier)
174 expert_module.param(
'extraInfoName',
'tflat_output')
176 roe_path.add_module(expert_module)
# FlavorTaggerInfoFiller is added to the ROE path after the expert, with
# its 'TFLATnn' parameter switched on.
178 flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
179 flavorTaggerInfoFiller.param(
'TFLATnn',
True)
180 roe_path.add_module(flavorTaggerInfoFiller)
# Convenience alias for reading the tagger output.
183 vm.addAlias(
'qrTFLAT',
'qrOutput(TFLAT)')
# Run roe_path once per RestOfEvent object of the event.
185 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)