Belle II Software  release-05-01-25
DeepFlavorTagger.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
13 
14 from ROOT import gSystem
15 gSystem.Load('libanalysis.so')
16 import modularAnalysis as ma
17 import variables.utils as vu
18 from ROOT import Belle2
19 import basf2
20 from basf2 import B2ERROR, B2FATAL
21 import basf2_mva
22 
23 # make ROOT compatible available
26 
27 import os
28 import json
29 
30 
def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
    """ creates variable name pattern requested by the basf2 variable getVariableByRank()

    For every entry of ``variables`` and every rank from 1 to ``particleNumber`` (inclusive) one string
    ``getVariableByRank(<particle_list>, <ranked_variable>, <variable>, <rank>)`` is produced.

    :param particle_list: string, name of the particle list put into the pattern
    :param ranked_variable: string, name of the variable put into the pattern as ranking variable
    :param variables: iterable of strings, variables to wrap; ``None`` (the default) is treated as empty
    :param particleNumber: int, highest rank to generate; values below 1 yield no entries
    :return: list[string], grouped by variable first and by ascending rank within each variable
    """
    # The default of ``None`` used to be iterated directly, which raised a TypeError.
    if variables is None:
        return []

    pattern = 'getVariableByRank({}, {}, {}, {})'
    return [
        pattern.format(particle_list, ranked_variable, var, i_num)
        for var in variables
        for i_num in range(1, particleNumber + 1)
    ]
45 
46 
def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
    """ build the ROOT compatible names of the rank based variables for several particle lists

    The names are generated with ``get_variables`` for each particle list in turn and each resulting name is
    then passed through ``Belle2.makeROOTCompatible``.

    :param particle_lists: list[string], particle lists to build names for, ``None`` selects
        ``['pi+:pos_charged', 'pi+:neg_charged']``
    :param ranked_variable: string, ranking variable, forwarded to ``get_variables``
    :param variables: forwarded unchanged to ``get_variables``
    :param particleNumber: int, forwarded to ``get_variables``
    :return: list[string], converted names in the order of ``particle_lists``
    """
    lists_to_use = ['pi+:pos_charged', 'pi+:neg_charged'] if particle_lists is None else particle_lists

    # collect the plain basf2 variable names of all lists first
    raw_names = []
    for list_name in lists_to_use:
        raw_names.extend(get_variables(list_name, ranked_variable, variables, particleNumber))

    # then convert every name to its ROOT compatible form
    return [Belle2.makeROOTCompatible(name) for name in raw_names]
68 
69 
def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
                     target='qrCombined', overwrite=False,
                     transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
                     train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
                     maskName='',
                     path=None):
    """
    Interfacing for the DeepFlavorTagger. This function can be used for training (``teacher``), preparation of
    training datasets (``sampler``) and inference (``expert``).

    This function requires a reconstructed B meson signal particle list for which a RestOfEvent is built.

    :param particle_lists: string or list[string], particle list(s) of the reconstructed signal B meson
    :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler``
    :param working_dir: string, working directory for the method
    :param uniqueIdentifier: string, database identifier for the method
    :param variable_list: list[string], name of the basf2 variables used for discrimination,
        a default set is used in ``sampler`` and ``teacher`` mode if ``None`` is given
    :param target: string, target variable
    :param overwrite: bool, overwrite already (locally!) existing training
    :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training,
        can only be set during training
    :param signal_fraction: float, (experimental) signal fraction override,
        transform the output to a probability if an uneven signal/background fraction is used in the training data,
        can only be set during training
    :param classifier_args: dictionary, customized arguments for the mlp (the passed dictionary is not modified)
        possible attributes of the dictionary are:
        lr_dec_rate: learning rate decay rate
        lr_init: learning rate initial value
        mom_init: momentum initial value
        min_epochs: minimal number of epochs
        max_epochs: maximal number of epochs
        stop_epochs: epochs to stop without improvements on the validation set for early stopping
        batch_size: batch size
        seed: random seed for tensorflow
        layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..]
        wd_coeffs: weight decay coefficients, length of layers
        cuda_visible_devices: selection of cuda devices
        tensorboard_dir: additional directory for logging the training process
    :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability is
        enabled, train valid fraction will be split to a test set (.5)
    :param mva_steering_file: string, steering file handed to the mva python interface
    :param maskName: get ROE particles from a specified ROE mask
    :param path: basf2 path obj, must be given although it defaults to ``None``, the RestOfEvent loop is
        attached to it in every mode
    :return: None
    """

    if isinstance(particle_lists, str):
        particle_lists = [particle_lists]

    if mode not in ['expert', 'teacher', 'sampler']:
        B2FATAL('Invalid mode %s' % mode)

    if variable_list is None and mode in ['sampler', 'teacher']:
        variable_list = ['useCMSFrame(p)', 'useCMSFrame(cosTheta)', 'useCMSFrame(phi)', 'Kid', 'eid', 'muid', 'prid',
                         'nCDCHits', 'nPXDHits', 'nSVDHits', 'dz', 'dr', 'chiProb']

    # Strings have to be compared by value: ``is`` checks object identity, which only works by accident of
    # string interning and triggers a SyntaxWarning for literals on recent python versions.
    if variable_list is not None and mode == 'expert':
        B2ERROR('DFT: Variables from identifier file are used. Input variables will be ignored.')

    if classifier_args is None:
        classifier_args = {}
    else:
        assert isinstance(classifier_args, dict)
        # work on a copy so that the dictionary of the caller is not modified below
        classifier_args = dict(classifier_args)

    classifier_args['transform_to_prob'] = transform_to_probability

    output_file_name = os.path.join(working_dir, uniqueIdentifier + '_training_data.root')

    # create roe specific paths
    roe_path = basf2.create_path()
    dead_end_path = basf2.create_path()

    # define dft specific lists to enable multiple calls, if someone really wants to do that
    extension = particle_lists[0].replace(':', '_to_')
    roe_particle_list_cut = ''
    roe_particle_list = 'pi+:dft' + '_' + extension

    tree_name = 'dft_variables'

    # filter rest of events only for specific particle list
    ma.signalSideParticleListsFilter(particle_lists, 'hasRestOfEventTracks > 0', roe_path, dead_end_path)

    # TODO: particles with empty rest of events seems not to show up in efficiency statistics anymore

    # create final state particle lists
    ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)

    dft_particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']

    pos_cut = 'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'
    neg_cut = 'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'

    ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=True, path=roe_path)
    ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=True, path=roe_path)

    # sort pattern for tagging specific variables
    rank_variable = 'p'
    # rank_variable = 'useCMSFrame(p)'

    # create tagging specific variables, they are only needed by the sampler and the teacher
    if mode != 'expert':
        features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
        features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)

    for particles in dft_particle_lists:
        ma.rankByHighest(particles, rank_variable, path=roe_path)

    if mode == 'sampler':
        if os.path.isfile(output_file_name) and not overwrite:
            B2FATAL('Outputfile %s already exists. Aborting writeout.' % output_file_name)

        # and add target
        all_variables = features + [target]

        # write to ntuples
        ma.variablesToNtuple('', all_variables, tree_name, output_file_name, roe_path)

        # write the command line output for the extern teacher to a file
        extern_command = (
            'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}'
            ' --identifier {identifier} --variables "{variables_string}" --target_variable {target}'
            ' --method Python --training_fraction {fraction}'
            " --config '{classifier_args}' --framework tensorflow"
            ' --steering_file {steering_file}'
        ).format(output_file_name=output_file_name, tree_name=tree_name,
                 identifier=uniqueIdentifier,
                 variables_string='" "'.join(features), target=target,
                 classifier_args=json.dumps(classifier_args), fraction=train_valid_fraction,
                 steering_file=mva_steering_file)

        with open(os.path.join(working_dir, uniqueIdentifier + '_teacher_command'), 'w') as f:
            f.write(extern_command)

    elif mode == 'teacher':
        if not os.path.isfile(output_file_name):
            B2FATAL('There is no training data file available. Run flavor tagger in sampler mode first.')
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(output_file_name)

        general_options.m_treename = tree_name
        general_options.m_target_variable = target
        general_options.m_variables = basf2_mva.vector(*features)

        general_options.m_identifier = uniqueIdentifier

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = 'tensorflow'
        specific_options.m_steering_file = mva_steering_file
        specific_options.m_training_fraction = train_valid_fraction

        specific_options.m_config = json.dumps(classifier_args)

        basf2_mva.teacher(general_options, specific_options)

    elif mode == 'expert':
        # TODO: implement filling flavor tagger info in the FlavorTaggerInfoMap

        # flavor tagger info
        # mod_ft_info_builder = register_module('FlavorTaggerInfoBuilder')
        # path.add_module(mod_ft_info_builder)

        # fill the flavor tagger info
        # mod_ft_info_filler = register_module('FlavorTaggerInfoFiller')

        expert_module = basf2.register_module('MVAExpert')
        expert_module.param('listNames', particle_lists)
        expert_module.param('identifier', uniqueIdentifier)

        expert_module.param('extraInfoName', 'dnn_output')
        expert_module.param('signalFraction', signal_fraction)

        roe_path.add_module(expert_module)

        # Create standard alias for the output of the flavor tagger
        vu._variablemanager.addAlias('DNN_qrCombined', 'formula(2*extraInfo(dnn_output) - 1)')

    # attach the RestOfEvent specific path to the main path, this is done in every mode
    path.for_each('RestOfEvent', 'RestOfEvents', roe_path)
Belle2::makeROOTCompatible
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Definition: MakeROOTCompatible.cc:74
variables.utils
Definition: utils.py:1
Belle2::Variable::Manager
Global list of available variables.
Definition: Manager.h:108
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27