Belle II Software  release-05-02-19
DeepFlavorTagger.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
13 
14 import json
15 import os
16 import basf2_mva
17 from basf2 import B2ERROR, B2FATAL
18 import basf2
19 from ROOT import Belle2
20 import variables.utils as vu
21 import modularAnalysis as ma
22 from ROOT import gSystem
23 gSystem.Load('libanalysis.so')
24 
25 # make ROOT compatible available
28 
29 
def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
    """Create the variable names in the pattern expected by the basf2 variable getVariableByRank().

    For every entry of ``variables`` and every rank from 1 to ``particleNumber`` (inclusive) one
    string of the form ``getVariableByRank(<particle_list>, <ranked_variable>, <variable>, <rank>)``
    is produced. All ranks of the first variable come first, then all ranks of the second, and so on.

    :param particle_list: string, name of the particle list inserted into the pattern
    :param ranked_variable: string, name of the variable the ranking refers to
    :param variables: iterable of strings, variables to be read out per rank; ``None`` (the default)
        is treated as "no variables" and yields an empty list
    :param particleNumber: int, highest rank for which names are generated
    :return: list[string], the generated variable names
    """
    # The default used to be iterated directly, which raised a TypeError for ``variables=None``.
    if variables is None:
        return []

    pattern = 'getVariableByRank({}, {}, {}, {})'
    return [
        pattern.format(particle_list, ranked_variable, var, rank)
        for var in variables
        for rank in range(1, particleNumber + 1)
    ]
44 
45 
def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
    """Construct the default variable names (sorted by charge and ranked by momentum).

    The ``getVariableByRank`` names are generated for every particle list via ``get_variables``
    and then passed through ``Belle2.makeROOTCompatible``.

    :param particle_lists: list[string], particle lists to generate names for; if ``None`` the
        lists ``pi+:pos_charged`` and ``pi+:neg_charged`` are used
    :param ranked_variable: string, variable the ranking refers to
    :param variables: list[string], variables to be read out per rank
    :param particleNumber: int, highest rank for which names are generated
    :return: list[string], ROOT compatible variable names
    """
    lists_to_use = ['pi+:pos_charged', 'pi+:neg_charged'] if particle_lists is None else particle_lists

    # collect the plain names of all lists first ...
    ranked_names = []
    for list_name in lists_to_use:
        ranked_names.extend(get_variables(list_name, ranked_variable, variables, particleNumber))

    # ... and convert them afterwards
    return [Belle2.makeROOTCompatible(name) for name in ranked_names]
67 
68 
def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
                     target='qrCombined', overwrite=False,
                     transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
                     train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
                     maskName='',
                     path=None):
    """
    Interfacing for the DeepFlavorTagger. This function can be used for training (``teacher``), preparation of
    training datasets (``sampler``) and inference (``expert``).

    This function requires a reconstructed B meson signal particle list for which a RestOfEvent is built.

    In every mode a rest-of-event path is set up that fills positively and negatively charged particle lists
    from the rest of event and ranks them; this path is attached to ``path`` via ``for_each`` at the end.
    Depending on the mode the function additionally

    * ``sampler``: writes the input variables and the target to an ntuple and stores the matching
      ``basf2_mva_teacher`` command line in ``<working_dir>/<uniqueIdentifier>_teacher_command``,
    * ``teacher``: runs ``basf2_mva.teacher`` on the previously sampled file,
    * ``expert``: adds the ``FlavorTaggerInfoBuilder``, ``MVAExpert`` and ``FlavorTaggerInfoFiller`` modules and
      registers the alias ``DNN_qrCombined`` for ``qrOutput(DNN)``.

    :param particle_lists: string or list[string], particle list(s) of the reconstructed signal B meson
    :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler``
    :param working_dir: string, working directory for the method
    :param uniqueIdentifier: string, database identifier for the method
    :param variable_list: list[string], name of the basf2 variables used for discrimination
    :param target: string, target variable
    :param overwrite: bool, overwrite already (locally!) existing training
    :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training,
        can only be set during training
    :param signal_fraction: float, (experimental) signal fraction override,
        transform to output to a probability if an uneven signal/background fraction is used in the training data,
        can only be set during training
    :param classifier_args: dictionary, customized arguments for the mlp (the passed dictionary is not modified)
        possible attributes of the dictionary are:
        lr_dec_rate: learning rate decay rate
        lr_init: learning rate initial value
        mom_init: momentum initial value
        min_epochs: minimal number of epochs
        max_epochs: maximal number of epochs
        stop_epochs: epochs to stop without improvements on the validation set for early stopping
        batch_size: batch size
        seed: random seed for tensorflow
        layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..]
        wd_coeffs: weight decay coefficients, length of layers
        cuda_visible_devices: selection of cuda devices
        tensorboard_dir: addition directory for logging the training process
    :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability is
        enabled, train valid fraction will be split to a test set (.5)
    :param maskName: get ROE particles from a specified ROE mask
    :param path: basf2 path obj; it is used in every mode, so it has to be provided
    :return: None
    """

    if isinstance(particle_lists, str):
        particle_lists = [particle_lists]

    if mode not in ['expert', 'teacher', 'sampler']:
        B2FATAL('Invalid mode %s' % mode)

    if variable_list is None and mode in ['sampler', 'teacher']:
        variable_list = ['useCMSFrame(p)', 'useCMSFrame(cosTheta)', 'useCMSFrame(phi)', 'Kid', 'eid', 'muid', 'prid',
                         'nCDCHits', 'nPXDHits', 'nSVDHits', 'dz', 'dr', 'chiProb']

    # Modes are compared by value (==): identity comparison of strings only works by accident of
    # string interning and silently skips every branch for a mode string created at runtime.
    if variable_list is not None and mode == 'expert':
        B2ERROR('DFT: Variables from identifier file are used. Input variables will be ignored.')

    if classifier_args is None:
        classifier_args = {}
    else:
        assert isinstance(classifier_args, dict)
        # work on a copy so that the dictionary of the caller is not altered below
        classifier_args = dict(classifier_args)

    classifier_args['transform_to_prob'] = transform_to_probability

    output_file_name = os.path.join(working_dir, uniqueIdentifier + '_training_data.root')

    # create roe specific paths
    roe_path = basf2.create_path()
    dead_end_path = basf2.create_path()

    # define dft specific lists to enable multiple calls, if someone really wants to do that
    extension = particle_lists[0].replace(':', '_to_')
    roe_particle_list_cut = ''
    roe_particle_list = 'pi+:dft' + '_' + extension

    tree_name = 'dft_variables'

    # filter rest of events only for specific particle list
    ma.signalSideParticleListsFilter(particle_lists, 'hasRestOfEventTracks > 0', roe_path, dead_end_path)

    # TODO: particles with empty rest of events seems not to show up in efficiency statistics anymore

    # create final state particle lists
    ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)

    dft_particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']

    pos_cut = 'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'
    neg_cut = 'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'

    ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=True, path=roe_path)
    ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=True, path=roe_path)

    # sort pattern for tagging specific variables
    rank_variable = 'p'
    # rank_variable = 'useCMSFrame(p)'

    # create tagging specific variables (only needed for sampler and teacher)
    if mode != 'expert':
        features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
        features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)

    for particles in dft_particle_lists:
        ma.rankByHighest(particles, rank_variable, path=roe_path)

    if mode == 'sampler':
        if os.path.isfile(output_file_name) and not overwrite:
            B2FATAL('Outputfile %s already exists. Aborting writeout.' % output_file_name)

        # and add target
        all_variables = features + [target]

        # write to ntuples
        ma.variablesToNtuple('', all_variables, tree_name, output_file_name, roe_path)

        # write the command line output for the extern teacher to a file
        extern_command = 'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}' \
                         ' --identifier {identifier} --variables "{variables_string}" --target_variable {target}' \
                         ' --method Python --training_fraction {fraction}' \
                         " --config '{classifier_args}' --framework tensorflow" \
                         ' --steering_file {steering_file}'\
                         ''.format(output_file_name=output_file_name, tree_name=tree_name,
                                   identifier=uniqueIdentifier,
                                   variables_string='" "'.join(features), target=target,
                                   classifier_args=json.dumps(classifier_args), fraction=train_valid_fraction,
                                   steering_file=mva_steering_file)

        with open(os.path.join(working_dir, uniqueIdentifier + '_teacher_command'), 'w') as f:
            f.write(extern_command)

    elif mode == 'teacher':
        if not os.path.isfile(output_file_name):
            B2FATAL('There is no training data file available. Run flavor tagger in sampler mode first.')
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(output_file_name)

        general_options.m_treename = tree_name
        general_options.m_target_variable = target
        general_options.m_variables = basf2_mva.vector(*features)

        general_options.m_identifier = uniqueIdentifier

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = 'tensorflow'
        specific_options.m_steering_file = mva_steering_file
        specific_options.m_training_fraction = train_valid_fraction

        specific_options.m_config = json.dumps(classifier_args)

        basf2_mva.teacher(general_options, specific_options)

    elif mode == 'expert':

        # the info builder is added to the main path, expert and filler run inside the ROE path
        flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
        path.add_module(flavorTaggerInfoBuilder)

        expert_module = basf2.register_module('MVAExpert')
        expert_module.param('listNames', particle_lists)
        expert_module.param('identifier', uniqueIdentifier)

        expert_module.param('extraInfoName', 'dnn_output')
        expert_module.param('signalFraction', signal_fraction)

        roe_path.add_module(expert_module)

        flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
        flavorTaggerInfoFiller.param('DNNmlp', True)
        roe_path.add_module(flavorTaggerInfoFiller)

        # Create standard alias for the output of the flavor tagger
        vu._variablemanager.addAlias('DNN_qrCombined', 'qrOutput(DNN)')

    # execute the ROE path for each RestOfEvent object
    path.for_each('RestOfEvent', 'RestOfEvents', roe_path)
Belle2::makeROOTCompatible
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Definition: MakeROOTCompatible.cc:74
variables.utils
Definition: utils.py:1
Belle2::Variable::Manager
Global list of available variables.
Definition: Manager.h:108
Belle2::Variable::Manager::Instance
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:27