Belle II Software  release-06-02-00
DeepFlavorTagger.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import json
12 import os
13 import basf2_mva
14 from basf2 import B2ERROR, B2FATAL
15 import basf2
16 from ROOT import Belle2
17 import variables.utils as vu
18 import modularAnalysis as ma
19 from ROOT import gSystem
20 gSystem.Load('libanalysis.so')
21 
22 # make ROOT compatible available
25 
26 
def get_variables(particle_list, ranked_variable, variables=None, particleNumber=1):
    """ creates variable name pattern requested by the basf2 variable getVariableByRank()

    For every entry of ``variables`` and every rank from 1 to ``particleNumber`` a string of the form
    ``getVariableByRank(<particle_list>, <ranked_variable>, <variable>, <rank>)`` is produced.
    The outer loop runs over the variables, the inner loop over the ranks.

    :param particle_list: string, name of the particle list the variables are requested for
    :param ranked_variable: string, name of the variable the particles are ranked by
    :param variables: list[string], variables requested for each rank; ``None`` is treated as an empty list
    :param particleNumber: int, number of ranks (1 .. particleNumber) requested per variable
    :return: list[string], the generated variable names
    """
    if variables is None:
        # the documented default previously crashed with "'NoneType' object is not iterable"
        return []
    return ['getVariableByRank(' + particle_list + ', ' + ranked_variable + ', ' + var + ', ' + str(i_num) + ')'
            for var in variables
            for i_num in range(1, particleNumber + 1)]
41 
42 
43 def construct_default_variable_names(particle_lists=None, ranked_variable='p', variables=None, particleNumber=5):
44  """ construct default variables (that are sorted by charge and ranked by momentum)
45  :param particle_lists:
46  :param ranked_variable:
47  :param variables:
48  :param particleNumber:
49  :return:
50  """
51  if particle_lists is None:
52  particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']
53 
54  variable_names = []
55  for p_list in particle_lists:
56  variable_names += get_variables(p_list, ranked_variable, variables, particleNumber)
57 
58  # make root compatible
59  root_compatible_list = []
60  for var in variable_names:
61  root_compatible_list.append(Belle2.makeROOTCompatible(var))
62 
63  return root_compatible_list
64 
65 
66 def DeepFlavorTagger(particle_lists, mode='expert', working_dir='', uniqueIdentifier='standard', variable_list=None,
67  target='qrCombined', overwrite=False,
68  transform_to_probability=False, signal_fraction=-1.0, classifier_args=None,
69  train_valid_fraction=.92, mva_steering_file='analysis/scripts/dft/tensorflow_dnn_interface.py',
70  maskName='',
71  path=None):
72  """
73  Interfacing for the DeepFlavorTagger. This function can be used for training (``teacher``), preparation of
74  training datasets (``sampler``) and inference (``expert``).
75 
76  This function requires reconstructed B meson signal particle list and where an RestOfEvent is built.
77 
78  :param particle_lists: string or list[string], particle list(s) of the reconstructed signal B meson
79  :param mode: string, valid modes are ``expert`` (default), ``teacher``, ``sampler``
80  :param working_dir: string, working directory for the method
81  :param uniqueIdentifier: string, database identifier for the method
82  :param variable_list: list[string], name of the basf2 variables used for discrimination
83  :param target: string, target variable
84  :param overwrite: bool, overwrite already (locally!) existing training
85  :param transform_to_probability: bool, enable a purity transformation to compensate potential over-training,
86  can only be set during training
87  :param signal_fraction: float, (experimental) signal fraction override,
88  transform to output to a probability if an uneven signal/background fraction is used in the training data,
89  can only be set during training
90  :param classifier_args: dictionary, customized arguments for the mlp
91  possible attributes of the dictionary are:
92  lr_dec_rate: learning rate decay rate
93  lr_init: learning rate initial value
94  mom_init: momentum initial value
95  min_epochs: minimal number of epochs
96  max_epochs: maximal number of epochs
97  stop_epochs: epochs to stop without improvements on the validation set for early stopping
98  batch_size: batch size
99  seed: random seed for tensorflow
100  layers: [[layer name, activation function, input_width, output_width, init_bias, init_weights],..]
101  wd_coeffs: weight decay coefficients, length of layers
102  cuda_visible_devices: selection of cuda devices
103  tensorboard_dir: addition directory for logging the training process
104  :param train_valid_fraction: float, train-valid fraction (.92). If transform to probability is
105  enabled, train valid fraction will be split into a test set (.5)
106  :param maskName: get ROE particles from a specified ROE mask
107  :param path: basf2 path obj
108  :return: None
109  """
110 
111  if isinstance(particle_lists, str):
112  particle_lists = [particle_lists]
113 
114  if mode not in ['expert', 'teacher', 'sampler']:
115  B2FATAL('Invalid mode %s' % mode)
116 
117  if variable_list is None and mode in ['sampler', 'teacher']:
118  variable_list = [
119  'useCMSFrame(p)',
120  'useCMSFrame(cosTheta)',
121  'useCMSFrame(phi)',
122  'kaonID',
123  'electronID',
124  'muonID',
125  'protonID',
126  'nCDCHits',
127  'nPXDHits',
128  'nSVDHits',
129  'dz',
130  'dr',
131  'chiProb']
132 
133  if variable_list is not None and mode == 'expert':
134  B2ERROR('DFT: Variables from identifier file are used. Input variables will be ignored.')
135 
136  if classifier_args is None:
137  classifier_args = {}
138  else:
139  assert isinstance(classifier_args, dict)
140 
141  classifier_args['transform_to_prob'] = transform_to_probability
142 
143  output_file_name = os.path.join(working_dir, uniqueIdentifier + '_training_data.root')
144 
145  # create roe specific paths
146  roe_path = basf2.create_path()
147  dead_end_path = basf2.create_path()
148 
149  # define dft specific lists to enable multiple calls, if someone really wants to do that
150  extension = particle_lists[0].replace(':', '_to_')
151  roe_particle_list_cut = ''
152  roe_particle_list = 'pi+:dft' + '_' + extension
153 
154  tree_name = 'dft_variables'
155 
156  # filter rest of events only for specific particle list
157  ma.signalSideParticleListsFilter(particle_lists, 'hasRestOfEventTracks > 0', roe_path, dead_end_path)
158 
159  # TODO: particles with empty rest of events seems not to show up in efficiency statistics anymore
160 
161  # create final state particle lists
162  ma.fillParticleList(roe_particle_list, roe_particle_list_cut, path=roe_path)
163 
164  dft_particle_lists = ['pi+:pos_charged', 'pi+:neg_charged']
165 
166  pos_cut = 'charge > 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'
167  neg_cut = 'charge < 0 and isInRestOfEvent == 1 and passesROEMask(' + maskName + ') > 0.5 and p < infinity'
168 
169  ma.cutAndCopyList(dft_particle_lists[0], roe_particle_list, pos_cut, writeOut=True, path=roe_path)
170  ma.cutAndCopyList(dft_particle_lists[1], roe_particle_list, neg_cut, writeOut=True, path=roe_path)
171 
172  # sort pattern for tagging specific variables
173  rank_variable = 'p'
174  # rank_variable = 'useCMSFrame(p)'
175 
176  # create tagging specific variables
177  if mode != 'expert':
178  features = get_variables(dft_particle_lists[0], rank_variable, variable_list, particleNumber=5)
179  features += get_variables(dft_particle_lists[1], rank_variable, variable_list, particleNumber=5)
180 
181  for particles in dft_particle_lists:
182  ma.rankByHighest(particles, rank_variable, path=roe_path)
183 
184  if mode == 'sampler':
185  if os.path.isfile(output_file_name) and not overwrite:
186  B2FATAL('Outputfile %s already exists. Aborting writeout.' % output_file_name)
187 
188  # and add target
189  all_variables = features + [target]
190 
191  # write to ntuples
192  ma.variablesToNtuple('', all_variables, tree_name, output_file_name, roe_path)
193 
194  # write the command line output for the extern teacher to a file
195  extern_command = 'basf2_mva_teacher --datafile {output_file_name} --treename {tree_name}' \
196  ' --identifier {identifier} --variables "{variables_string}" --target_variable {target}' \
197  ' --method Python --training_fraction {fraction}' \
198  " --config '{classifier_args}' --framework tensorflow" \
199  ' --steering_file {steering_file}'\
200  ''.format(output_file_name=output_file_name, tree_name=tree_name,
201  identifier=uniqueIdentifier,
202  variables_string='" "'.join(features), target=target,
203  classifier_args=json.dumps(classifier_args), fraction=train_valid_fraction,
204  steering_file=mva_steering_file)
205 
206  with open(os.path.join(working_dir, uniqueIdentifier + '_teacher_command'), 'w') as f:
207  f.write(extern_command)
208 
209  elif mode == 'teacher':
210  if not os.path.isfile(output_file_name):
211  B2FATAL('There is no training data file available. Run flavor tagger in sampler mode first.')
212  general_options = basf2_mva.GeneralOptions()
213  general_options.m_datafiles = basf2_mva.vector(output_file_name)
214 
215  general_options.m_treename = tree_name
216  general_options.m_target_variable = target
217  general_options.m_variables = basf2_mva.vector(*features)
218 
219  general_options.m_identifier = uniqueIdentifier
220 
221  specific_options = basf2_mva.PythonOptions()
222  specific_options.m_framework = 'tensorflow'
223  specific_options.m_steering_file = mva_steering_file
224  specific_options.m_training_fraction = train_valid_fraction
225 
226  specific_options.m_config = json.dumps(classifier_args)
227 
228  basf2_mva.teacher(general_options, specific_options)
229 
230  elif mode == 'expert':
231 
232  flavorTaggerInfoBuilder = basf2.register_module('FlavorTaggerInfoBuilder')
233  path.add_module(flavorTaggerInfoBuilder)
234 
235  expert_module = basf2.register_module('MVAExpert')
236  expert_module.param('listNames', particle_lists)
237  expert_module.param('identifier', uniqueIdentifier)
238 
239  expert_module.param('extraInfoName', 'dnn_output')
240  expert_module.param('signalFraction', signal_fraction)
241 
242  roe_path.add_module(expert_module)
243 
244  flavorTaggerInfoFiller = basf2.register_module('FlavorTaggerInfoFiller')
245  flavorTaggerInfoFiller.param('DNNmlp', True)
246  roe_path.add_module(flavorTaggerInfoFiller)
247 
248  # Create standard alias for the output of the flavor tagger
249  vu._variablemanager.addAlias('DNN_qrCombined', 'qrOutput(DNN)')
250 
251  path.for_each('RestOfEvent', 'RestOfEvents', roe_path)
Global list of available variables.
Definition: Manager.h:98
static Manager & Instance()
get singleton instance.
Definition: Manager.cc:25
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.