14 from ROOT
import gSystem
15 gSystem.Load(
'libanalysis.so')
16 import modularAnalysis
as ma
18 from ROOT
import Belle2
20 from basf2
import B2ERROR, B2FATAL
def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
    """ creates variable name pattern requested by the basf2 variable getVariableByRank()

    For every entry of ``variables`` and every rank from 1 to ``particleNumber`` (inclusive) one string
    ``getVariableByRank(<particle_list>, <ranked_variable>, <variable>, <rank>)`` is generated.

    :param particle_list: string, particle list name inserted as first argument
    :param ranked_variable: string, variable name inserted as second argument
    :param variables: list[string], variable names inserted as third argument; ``None`` gives an empty result
    :param particleNumber: int, highest rank that is generated for each variable
    :return: list[string], ordered by variable first and by rank second
    """
    # The default of ``variables`` is None, which cannot be iterated: treat it as "no variables requested".
    if variables is None:
        return []

    # NOTE(review): the extracted listing breaks off after the last "', ' +"; closing the pattern with the
    # rank and a parenthesis is reconstructed from the loop over i_num - confirm against the repository.
    return ['getVariableByRank(' + particle_list + ', ' + ranked_variable + ', ' + var + ', ' + str(i_num) + ')'
            for var in variables
            for i_num in range(1, particleNumber + 1)]
def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
    """ construct default variables (that are sorted by charge and ranked by momentum)

    :param particle_lists: list[string], particle lists the names are built for,
        defaults to ``['pi+:pos_charged', 'pi+:neg_charged']``
    :param ranked_variable: string, ranking variable forwarded to get_variables()
    :param variables: list[string], variables forwarded to get_variables()
    :param particleNumber: int, number of ranks forwarded to get_variables()
    :return: list[string], one converted name per generated variable, in generation order
    """
    if particle_lists is None:
        particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']

    # collect the getVariableByRank(...) names list by list, keeping the order of particle_lists
    variable_names = []
    for p_list in particle_lists:
        variable_names += get_variables(p_list, ranked_variable, variables, particleNumber)

    # NOTE(review): the statement inside the "for var in variable_names" loop is missing from the extracted
    # listing; Belle2.makeROOTCompatible is presumed from the name root_compatible_list - confirm.
    return [Belle2.makeROOTCompatible(var) for var in variable_names]
def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
                     target='qrCombined', overwrite=False,
                     transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
                     train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
                     maskName='',
                     path=None):
    """
    Interfacing for the DeepFlavorTagger. This function can be used for training (``teacher``), preparation of
    training datasets (``sampler``) and inference (``expert``).

    This function requires a reconstructed B meson signal particle list for which a RestOfEvent is built.

    :param particle_lists: string or list[string], particle list(s) of the reconstructed signal B meson
    :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler``
    :param working_dir: string, working directory for the method
    :param uniqueIdentifier: string, database identifier for the method
    :param variable_list: list[string], name of the basf2 variables used for discrimination
    :param target: string, target variable
    :param overwrite: bool, overwrite already (locally!) existing training
    :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training,
        can only be set during training
    :param signal_fraction: float, (experimental) signal fraction override,
        transform to output to a probability if an uneven signal/background fraction is used in the training data,
        can only be set during training
    :param classifier_args: dictionary, customized arguments for the mlp
        possible attributes of the dictionary are:
        lr_dec_rate: learning rate decay rate
        lr_init: learning rate initial value
        mom_init: momentum initial value
        min_epochs: minimal number of epochs
        max_epochs: maximal number of epochs
        stop_epochs: epochs to stop without improvements on the validation set for early stopping
        batch_size: batch size
        seed: random seed for tensorflow
        layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..]
        wd_coeffs: weight decay coefficients, length of layers
        cuda_visible_devices: selection of cuda devices
        tensorboard_dir: additional directory for logging the training process
    :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability is
        enabled, train valid fraction will be split to a test set (.5)
    :param maskName: get ROE particles from a specified ROE mask
    :param path: basf2 path obj
    """
    # NOTE(review): the end of the signature is missing from the extracted listing. maskName and path are
    # documented and used below; their defaults ('' and None) are presumed - confirm against the repository.

    if isinstance(particle_lists, str):
        particle_lists = [particle_lists]

    if mode not in ['expert', 'teacher', 'sampler']:
        B2FATAL('Invalid mode %s' % mode)

    if variable_list is None and mode in ['sampler', 'teacher']:
        variable_list = [
            'useCMSFrame(p)', 'useCMSFrame(cosTheta)', 'useCMSFrame(phi)', 'Kid', 'eid', 'muid', 'prid',
            'nCDCHits', 'nPXDHits', 'nSVDHits', 'dz', 'dr', 'chiProb']

    # Modes are compared by value: an identity test ("is") against a string literal only succeeds when the
    # interpreter happens to share the string object, so a mode string built at runtime could pass the
    # validation above and then match none of the branches.
    if variable_list is not None and mode == 'expert':
        B2ERROR('DFT: Variables from identifier file are used. Input variables will be ignored.')

    # NOTE(review): the lines following "if classifier_args is None:" are missing from the extracted
    # listing; an empty dictionary as default is presumed.
    if classifier_args is None:
        classifier_args = {}
    else:
        assert isinstance(classifier_args, dict)
        # work on a copy, the key added below must not show up in the dictionary of the caller
        classifier_args = dict(classifier_args)

    classifier_args['transform_to_prob'] = transform_to_probability

    output_file_name = os.path.join(working_dir, uniqueIdentifier + '_training_data.root')

    # roe_path is executed for each RestOfEvent (see for_each at the end of the function),
    # dead_end_path is handed to the signal side filter together with it
    roe_path = basf2.create_path()
    dead_end_path = basf2.create_path()

    # the name of the first signal list makes the ROE particle list name specific to this call
    extension = particle_lists[0].replace(':', '_to_')
    roe_particle_list_cut = ''
    roe_particle_list = 'pi+:dft' + '_' + extension

    tree_name = 'dft_variables'

    ma.signalSideParticleListsFilter(particle_lists, 'hasRestOfEventTracks > 0', roe_path, dead_end_path)

    ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)

    dft_particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']

    # split the ROE particles passing the mask by the sign of their charge
    pos_cut = 'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'
    neg_cut = 'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'

    ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=True, path=roe_path)
    ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=True, path=roe_path)

    # NOTE(review): the definition of rank_variable is missing from the extracted listing; 'p' is presumed
    # (it is the default ranked_variable of construct_default_variable_names) - confirm.
    rank_variable = 'p'

    # features are only built for sampler and teacher, the expert branch does not use them
    if mode != 'expert':
        features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
        features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)

    for particles in dft_particle_lists:
        ma.rankByHighest(particles, rank_variable, path=roe_path)

    if mode == 'sampler':
        if os.path.isfile(output_file_name) and not overwrite:
            B2FATAL('Outputfile %s already exists. Aborting writeout.' % output_file_name)

        all_variables = features + [target]

        ma.variablesToNtuple('', all_variables, tree_name, output_file_name, roe_path)

        # write the basf2_mva_teacher command line for this sample next to the training data
        extern_command = 'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}' \
                         ' --identifier {identifier} --variables "{variables_string}" --target_variable {target}' \
                         ' --method Python --training_fraction {fraction}' \
                         " --config '{classifier_args}' --framework tensorflow" \
                         ' --steering_file {steering_file}'\
                         ''.format(output_file_name=output_file_name, tree_name=tree_name,
                                   identifier=uniqueIdentifier,
                                   variables_string='" "'.join(features), target=target,
                                   classifier_args=json.dumps(classifier_args), fraction=train_valid_fraction,
                                   steering_file=mva_steering_file)

        with open(os.path.join(working_dir, uniqueIdentifier + '_teacher_command'), 'w') as f:
            f.write(extern_command)

    elif mode == 'teacher':
        if not os.path.isfile(output_file_name):
            B2FATAL('There is no training data file available. Run flavor tagger in sampler mode first.')
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(output_file_name)
        general_options.m_treename = tree_name
        general_options.m_target_variable = target
        general_options.m_variables = basf2_mva.vector(*features)
        general_options.m_identifier = uniqueIdentifier

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = 'tensorflow'
        specific_options.m_steering_file = mva_steering_file
        specific_options.m_training_fraction = train_valid_fraction
        specific_options.m_config = json.dumps(classifier_args)

        basf2_mva.teacher(general_options, specific_options)

    elif mode == 'expert':
        expert_module = basf2.register_module('MVAExpert')
        expert_module.param('listNames', particle_lists)
        expert_module.param('identifier', uniqueIdentifier)
        expert_module.param('extraInfoName', 'dnn_output')
        expert_module.param('signalFraction', signal_fraction)

        roe_path.add_module(expert_module)

        # alias for the classifier output stored as extraInfo(dnn_output)
        vu._variablemanager.addAlias('DNN_qrCombined', 'formula(2*extraInfo(dnn_output) - 1)')

    # NOTE(review): indentation is lost in the extracted listing; this call is placed at function level
    # because roe_path is filled in every mode - confirm.
    path.for_each('RestOfEvent', 'RestOfEvents', roe_path)