14from basf2
import B2ERROR, B2FATAL
16from variables
import variables
as vm
17import modularAnalysis
as ma
def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
    """ creates variable name pattern requested by the basf2 variable getVariableByRank()

    For every entry of ``variables`` and every rank ``1 .. particleNumber`` one name of the form
    ``getVariableByRank(<particle_list>, <ranked_variable>, <variable>, <rank>)`` is produced.
    Names are grouped by variable; within a group the ranks are ascending.

    :param particle_list: name of the particle list whose candidates are ranked
    :param ranked_variable: variable the candidates are ranked by
    :param variables: iterable of variable names (strings) evaluated at each rank.
        ``None`` (the declared default) is treated as "no variables" and yields an empty list.
    :param particleNumber: number of ranks, starting at 1, generated per variable
    :return: list of variable-name strings
    """
    if variables is None:
        # None is the declared default; iterating it used to fail with a TypeError.
        return []

    # String concatenation (not an f-string) is kept on purpose: non-string entries
    # still raise a TypeError instead of being silently converted with str().
    return ['getVariableByRank(' + particle_list + ', ' + ranked_variable + ', ' + var + ', ' +
            str(i_num) + ')'
            for var in variables
            for i_num in range(1, particleNumber + 1)]
def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
    """ construct default variables (that are sorted by charge and ranked by momentum)

    The ``getVariableByRank()`` names produced by ``get_variables`` for every particle list are
    converted with ``Belle2.MakeROOTCompatible.makeROOTCompatible`` so they are usable as ROOT
    branch names.

    :param particle_lists: particle list names; ``None`` selects
        ``['pi+:pos_charged', 'pi+:neg_charged']``
    :param ranked_variable: variable the candidates are ranked by (default ``'p'``)
    :param variables: variable names evaluated at each rank, forwarded to ``get_variables``
    :param particleNumber: number of ranks generated per variable (default 5)
    :return: list of ROOT-compatible variable names
    """
    from ROOT import Belle2

    if particle_lists is None:
        lists = ['pi+:pos_charged', 'pi+:neg_charged']
    else:
        lists = particle_lists

    raw_names = []
    for list_name in lists:
        raw_names.extend(get_variables(list_name, ranked_variable, variables, particleNumber))

    # ROOT rejects certain special characters in branch names; convert each name.
    return [Belle2.MakeROOTCompatible.makeROOTCompatible(name) for name in raw_names]
60def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
61 target='qrCombined', overwrite=False,
62 transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
63 train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
67 Interfacing for the DeepFlavorTagger. This function can be used
for training (``teacher``), preparation of
68 training datasets (``sampler``)
and inference (``expert``).
70 This function requires reconstructed B meson signal particle list
and where a RestOfEvent
is built.
72 :param particle_lists: string
or list[string], particle list(s) of the reconstructed signal B meson
73 :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler``
74 :param working_dir: string, working directory
for the method
75 :param uniqueIdentifier: string, database identifier
for the method
76 :param variable_list: list[string], name of the basf2 variables used
for discrimination
77 :param target: string, target variable
78 :param overwrite: bool, overwrite an already (locally!) existing training data file (sampler mode)
79 :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training,
80 can only be set during training
81 :param signal_fraction: float, (experimental) signal fraction override,
82 transforms the output to a probability
if an uneven signal/background fraction
is used
in the training data,
83 can only be set during training
84 :param classifier_args: dictionary, customized arguments
for the mlp
85 possible attributes of the dictionary are:
86 lr_dec_rate: learning rate decay rate
87 lr_init: learning rate initial value
88 mom_init: momentum initial value
89 min_epochs: minimal number of epochs
90 max_epochs: maximal number of epochs
91 stop_epochs: epochs to stop without improvements on the validation set
for early stopping
92 batch_size: batch size
93 seed: random seed
for tensorflow
94 layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..]
95 wd_coeffs: weight decay coefficients, length of layers
96 cuda_visible_devices: selection of cuda devices
97 tensorboard_dir: additional directory
for logging the training process
98 :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability
is
99 enabled, train valid fraction will be split into a test set (.5)
100 :param maskName: get ROE particles
from a specified ROE mask
101 :param path: basf2 path object
# Accept a single particle list name as well as a list of names.
105 if isinstance(particle_lists, str):
106 particle_lists = [particle_lists]
108 if mode
not in [
'expert',
'teacher',
'sampler']:
109 B2FATAL(f
'Invalid mode {mode}')
# sampler/teacher fall back to a default variable list when none is given.
# NOTE(review): most of that default list lies on lines not visible in this view.
111 if variable_list
is None and mode
in [
'sampler',
'teacher']:
114 'useCMSFrame(cosTheta)',
# In expert mode the variables stored with the identifier are used, so user input is ignored.
# NOTE(review): execution continues after this message; B2WARNING may be the more fitting level.
127 if variable_list
is not None and mode ==
'expert':
128 B2ERROR(
'DFT: Variables from identifier file are used. Input variables will be ignored.')
# classifier_args is later serialized with json.dumps() for the MVA steering file.
130 if classifier_args
is None:
# NOTE(review): assert is stripped under `python -O`; an explicit TypeError would be more robust.
133 assert isinstance(classifier_args, dict)
135 classifier_args[
'transform_to_prob'] = transform_to_probability
137 output_file_name = os.path.join(working_dir, uniqueIdentifier +
'_training_data.root')
# roe_path is executed once per RestOfEvent (path.for_each at the end of this function).
# dead_end_path is handed to signalSideParticleListsFilter() as the path for ROEs that fail its check.
140 roe_path = basf2.create_path()
141 dead_end_path = basf2.create_path()
# The ROE pion list name is derived from the first signal list, with ':' replaced by '_to_'.
144 extension = particle_lists[0].replace(
':',
'_to_')
145 roe_particle_list_cut =
''
146 roe_particle_list =
'pi+:dft' +
'_' + extension
148 tree_name =
'dft_variables'
151 ma.signalSideParticleListsFilter(particle_lists,
'hasRestOfEventTracks > 0', roe_path, dead_end_path)
# Fill the ROE pion list without any cut (roe_particle_list_cut is empty).
156 ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)
# Split ROE tracks that pass the ROE mask into positively and negatively charged lists.
# 'p < infinity' presumably guards against non-finite momenta.
158 dft_particle_lists = [
'pi+:pos_charged',
'pi+:neg_charged']
160 pos_cut =
'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName +
') > 0.5 and p < infinity'
161 neg_cut =
'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName +
') > 0.5 and p < infinity'
163 ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=
True, path=roe_path)
164 ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=
True, path=roe_path)
# Feature names: each variable evaluated for the top 5 ranked candidates of each charge list.
# NOTE(review): rank_variable is assigned on a line not visible here, and variable_list may still be
# None in expert mode; confirm get_variables() is not reached with None.
172 features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
173 features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)
# Rank each charge-separated list by rank_variable.
175 for particles
in dft_particle_lists:
176 ma.rankByHighest(particles, rank_variable, path=roe_path)
# sampler: write features + target to a ROOT ntuple and save the equivalent
# basf2_mva_teacher command line to <working_dir>/<uniqueIdentifier>_teacher_command.
178 if mode ==
'sampler':
179 if os.path.isfile(output_file_name)
and not overwrite:
180 B2FATAL(f
'Outputfile {output_file_name} already exists. Aborting writeout.')
183 all_variables = features + [target]
186 ma.variablesToNtuple(
'', all_variables, tree_name, output_file_name, roe_path)
189 extern_command = f
'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}' + \
190 f
' --identifier {uniqueIdentifier} ' + \
191 '--variables "{}" '.format(
'" "'.join(features)) + \
192 f
'--target_variable {target}' + \
193 f
' --method Python --training_fraction {train_valid_fraction}' + \
194 f
" --config '{json.dumps(classifier_args)}' --framework tensorflow" + \
195 f
' --steering_file {mva_steering_file}'
197 with open(os.path.join(working_dir, uniqueIdentifier +
'_teacher_command'),
'w')
as f:
198 f.write(extern_command)
# teacher: train the tensorflow MVA in-process from the file written in sampler mode.
200 elif mode ==
'teacher':
201 if not os.path.isfile(output_file_name):
202 B2FATAL(
'There is no training data file available. Run flavor tagger in sampler mode first.')
203 general_options = basf2_mva.GeneralOptions()
204 general_options.m_datafiles = basf2_mva.vector(output_file_name)
206 general_options.m_treename = tree_name
207 general_options.m_target_variable = target
208 general_options.m_variables = basf2_mva.vector(*features)
210 general_options.m_identifier = uniqueIdentifier
212 specific_options = basf2_mva.PythonOptions()
213 specific_options.m_framework =
'tensorflow'
214 specific_options.m_steering_file = mva_steering_file
215 specific_options.m_training_fraction = train_valid_fraction
217 specific_options.m_config = json.dumps(classifier_args)
219 basf2_mva.teacher(general_options, specific_options)
# expert: evaluate the MVA (output stored as extra info 'dnn_output') inside the ROE loop and
# run FlavorTaggerInfoFiller with DNNmlp enabled.
221 elif mode ==
'expert':
223 flavorTaggerInfoBuilder = basf2.register_module(
'FlavorTaggerInfoBuilder')
224 path.add_module(flavorTaggerInfoBuilder)
226 expert_module = basf2.register_module(
'MVAExpert')
227 expert_module.param(
'listNames', particle_lists)
228 expert_module.param(
'identifier', uniqueIdentifier)
230 expert_module.param(
'extraInfoName',
'dnn_output')
231 expert_module.param(
'signalFraction', signal_fraction)
233 roe_path.add_module(expert_module)
235 flavorTaggerInfoFiller = basf2.register_module(
'FlavorTaggerInfoFiller')
236 flavorTaggerInfoFiller.param(
'DNNmlp',
True)
237 roe_path.add_module(flavorTaggerInfoFiller)
# Make the DNN output available under the alias DNN_qrCombined (= qrOutput(DNN)).
240 vm.addAlias(
'DNN_qrCombined',
'qrOutput(DNN)')
# Execute roe_path once for every RestOfEvent object in the 'RestOfEvents' array.
242 path.for_each(
'RestOfEvent',
'RestOfEvents', roe_path)
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.