12from basf2
import B2FATAL
14from variables
import variables
as vm
15import modularAnalysis
as ma
16from stdPhotons
import stdPhotons
17from vertex
import kFit
def fill_particle_lists(config, maskName='TFLATDefaultMask', path=None):
    """
    Build the charged-pion, photon and K_S0 particle lists used by TFlaT.

    The steps below are appended to ``path`` in this order:

    1. fill ``pi+:tflat`` with the selection stored under ``config['trk_cut']``
    2. request the beam-background and fake-photon probabilities for
       ``gamma:all``, using the versions stored under
       ``config['VersionBeamBackgroundMVA']`` and ``config['VersionFakePhotonMVA']``
    3. copy ``gamma:tight`` into ``gamma:tflat`` applying ``config['gamma_cut']``
    4. combine ``pi+:tflat pi-:tflat`` into ``K_S0:inRoe`` within
       ``0.40<=M<=0.60`` and pass the list to ``kFit`` with the value ``0.01``

    NOTE(review): ``gamma:all`` and ``gamma:tight`` are only read here, never
    filled in the statements shown; presumably they are created elsewhere
    (``stdPhotons`` is imported by this file) -- confirm before relying on it.

    :param config: mapping providing the keys ``trk_cut``, ``gamma_cut``,
        ``VersionBeamBackgroundMVA`` and ``VersionFakePhotonMVA``
    :param maskName: string, accepted for interface compatibility; it is not
        referenced by the statements in this function
    :param path: basf2 path obj, forwarded unchanged to every call
    """
    # Charged tracks, treated as pions.
    ma.fillParticleList('pi+:tflat', config['trk_cut'], path=path)

    # Photon MVA outputs are requested on the unfiltered photon list. Each
    # version key is looked up immediately before its own call so that the
    # order of config accesses matches the order of the calls.
    beam_background_version = config['VersionBeamBackgroundMVA']
    ma.getBeamBackgroundProbability('gamma:all', beam_background_version, path=path)
    fake_photon_version = config['VersionFakePhotonMVA']
    ma.getFakePhotonProbability('gamma:all', fake_photon_version, path=path)

    # Photons for the tagger: a cut-and-copy of the tight list.
    ma.cutAndCopyList('gamma:tflat', 'gamma:tight', config['gamma_cut'], path=path)

    # K_S0 candidates from oppositely charged tagger pions.
    # The third positional argument (False) is passed through exactly as in the
    # original call.
    kshort_list = 'K_S0:inRoe'
    ma.reconstructDecay(
        kshort_list + ' -> pi+:tflat pi-:tflat',
        '0.40<=M<=0.60',
        False,
        path=path,
    )
    kFit(kshort_list, 0.01, path=path)
50def flavorTagger(particleLists, mode='Expert', working_dir='', uniqueIdentifier='TFlaT_MC16rd_light_2601_hyperion',
51 target='qrCombined', overwrite=False,
55 Interfacing for the Transformer FlavorTagger (TFlat). This function can be used for preparation of
56 training datasets (``Sampler``) and inference (``Expert``).
58 This function requires a reconstructed B meson signal particle list for which a RestOfEvent has been built.
60 :param particleLists: string or list[string], particle list(s) of the reconstructed signal B meson
61 :param mode: string, valid modes are ``Expert`` (default), ``Sampler``
62 :param working_dir: string, working directory for the method
63 :param uniqueIdentifier: string, database identifier for the method
64 :param target: string, target variable
65 :param overwrite: bool, overwrite an already (locally!) existing training data file
66 :param sampler_id: identifier of sampled file for parallel sampling
67 :param path: basf2 path obj
71 if isinstance(particleLists, str):
72 particleLists = [particleLists]
74 if mode
not in [
'Expert',
'Sampler']:
75 B2FATAL(f
'Invalid mode {mode}')
77 tree_name =
'tflat_variables'
80 config = utils.load_config(uniqueIdentifier)
83 TFLAT_mask = config[
'TFLAT_Mask']
84 maskName = TFLAT_mask[0]
85 for name
in particleLists:
86 ma.appendROEMasks(list_name=name, mask_tuples=[TFLAT_mask], path=path)
89 roe_path = basf2.create_path()
90 dead_end_path = basf2.create_path()
93 trk_variable_list = config[
'trk_variable_list']
94 ecl_variable_list = config[
'ecl_variable_list']
95 roe_variable_list = config[
'roe_variable_list']
97 features = utils.get_variables(
'pi+:tflat', rank_variable, trk_variable_list,
98 particleNumber=config[
'parameters'][
'num_trk'])
99 features += utils.get_variables(
'gamma:tflat', rank_variable, ecl_variable_list,
100 particleNumber=config[
'parameters'][
'num_ecl'])
101 features += utils.get_variables(
'pi+:tflat', rank_variable, roe_variable_list,
102 particleNumber=config[
'parameters'][
'num_roe'])
104 output_file_name = os.path.join(working_dir, uniqueIdentifier + f
'_training_data{sampler_id}.root')
105 if os.path.isfile(output_file_name)
and not overwrite:
106 B2FATAL(f
'Outputfile {output_file_name} already exists. Aborting writeout.')
109 ma.signalSideParticleListsFilter(
111 f
'nROE_Charged({maskName}, 0) > 0 and abs(qrCombined) == 1',
115 fill_particle_lists(config, maskName, roe_path)
117 ma.rankByHighest(
'pi+:tflat', rank_variable, path=roe_path)
118 ma.rankByHighest(
'gamma:tflat', rank_variable, path=roe_path)
120 vm.addAlias(
'refdx',
'getVariableByRank(pi+:tflat, p, dx, 1)')
121 vm.addAlias(
'dxdiff',
'formula(dx-refdx)')
122 vm.addAlias(
'refdy',
'getVariableByRank(pi+:tflat, p, dy, 1)')
123 vm.addAlias(
'dydiff',
'formula(dy-refdy)')
124 vm.addAlias(
'refdz',
'getVariableByRank(pi+:tflat, p, dz, 1)')
125 vm.addAlias(
'dzdiff',
'formula(dz-refdz)')
128 all_variables = features + [target]
131 ma.variablesToNtuple(
'', all_variables, tree_name, output_file_name, roe_path)
133 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)
135 elif mode ==
'Expert':
138 ma.signalSideParticleListsFilter(
140 f
'nROE_Charged({maskName}, 0) > 0',
144 path.add_module(
'FlavorTaggerInfoBuilder')
146 fill_particle_lists(config, maskName, roe_path)
148 ma.rankByHighest(
'pi+:tflat', rank_variable, path=roe_path)
149 ma.rankByHighest(
'gamma:tflat', rank_variable, path=roe_path)
151 vm.addAlias(
'refdx',
'getVariableByRank(pi+:tflat, p, dx, 1)')
152 vm.addAlias(
'dxdiff',
'formula(dx-refdx)')
153 vm.addAlias(
'refdy',
'getVariableByRank(pi+:tflat, p, dy, 1)')
154 vm.addAlias(
'dydiff',
'formula(dy-refdy)')
155 vm.addAlias(
'refdz',
'getVariableByRank(pi+:tflat, p, dz, 1)')
156 vm.addAlias(
'dzdiff',
'formula(dz-refdz)')
158 expert_module = basf2.register_module(
'MVAExpert')
159 expert_module.param(
'listNames', particleLists)
160 expert_module.param(
'identifier', uniqueIdentifier)
161 expert_module.param(
'extraInfoName',
'tflat_output')
163 roe_path.add_module(expert_module)
165 flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
166 flavorTaggerInfoFiller.param(
'TFLATnn',
True)
167 roe_path.add_module(flavorTaggerInfoFiller)
170 vm.addAlias(
'qrTFLAT',
'qrOutput(TFLAT)')
172 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)