Belle II Software  light-2212-foldex
basf2_mva_evaluate.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2_mva_util
13 
14 from basf2_mva_evaluation import plotting
15 from basf2 import conditions
16 import argparse
17 import tempfile
18 
19 import numpy as np
20 from B2Tools import b2latex, format
21 from basf2 import B2INFO
22 
23 import os
24 import shutil
25 import collections
26 from typing import List, Any
27 
28 
def get_argument_parser() -> argparse.ArgumentParser:
    """
    Builds the command line parser of the MVA evaluation tool.

    The parser is only constructed here, parsing the command line is left to the caller.
    The options -id, -train, -data, -l and -g combine action='append' with nargs='+', so each
    of them yields a list of lists (one inner list per occurrence of the option on the command line).
    @return the configured argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=False,
                        action='append', nargs='+',
                        help='Data file containing ROOT TTree used during training')
    parser.add_argument('-data', '--datafiles', dest='datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.zip',
                        help='Name of the created .zip archive file if not compiling or a pdf file if compilation is successful.')
    # The multi-line help texts are written as adjacent literals so that no source
    # indentation ends up inside the help string.
    parser.add_argument('-w', '--working_directory', dest='working_directory', type=str, default='',
                        help='Working directory where the created images and root files are stored, '
                             'default is to create a temporary directory.')
    parser.add_argument('-l', '--localdb', dest='localdb', type=str, action='append', nargs='+', required=False,
                        help='path or list of paths to local database(s) containing the mvas of interest. '
                             'The testing payloads are prepended and take precedence over payloads in global tags.')
    parser.add_argument('-g', '--globaltag', dest='globaltag', type=str, action='append', nargs='+', required=False,
                        help='globaltag or list of globaltags containing the mvas of interest. The globaltags are prepended.')
    parser.add_argument('-n', '--fillnan', dest='fillnan', action='store_true',
                        help='Fill nan and inf values with actual numbers')
    parser.add_argument('-c', '--compile', dest='compile', action='store_true',
                        help='Compile latex to pdf directly')
    return parser
54 
55 
def unique(input_list: List[Any]) -> List[Any]:
    """
    Removes duplicates from the given elements, the order of first occurrence is preserved.
    Elements are compared against the ones collected so far, so they do not have to be hashable.
    @param input_list list containing the elements
    @return new list in which every distinct element appears exactly once
    """
    distinct: List[Any] = []
    for element in input_list:
        if element in distinct:
            # already collected, only the first occurrence is kept
            continue
        distinct.append(element)
    return distinct
66 
67 
def flatten(input_list: List[List[Any]]) -> List[Any]:
    """
    Concatenates the elements of all sublists into one new list (only one level is flattened)
    @param input_list list of lists to be flattened
    @return list with the items of every sublist, in their original order
    """
    flat: List[Any] = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
74 
75 
def create_abbreviations(names, length=5):
    """
    Maps every name to a short abbreviation made of its first characters.
    If several names share the same prefix, a running number (starting at 1, in the order
    of the names) is appended to the prefix so that the abbreviations stay distinguishable.
    @param names names to abbreviate, iterated twice
    @param length number of leading characters used for the abbreviation
    @return collections.OrderedDict mapping name to abbreviation, in the order of names
    """
    # First pass: how many names share each prefix
    occurrences = collections.Counter(name[:length] for name in names)

    # Second pass: assign the abbreviations, numbering only the ambiguous prefixes
    running_index = collections.Counter()
    result = collections.OrderedDict()
    for name in names:
        prefix = name[:length]
        if occurrences[prefix] > 1:
            running_index[prefix] += 1
            result[name] = prefix + str(running_index[prefix])
        else:
            result[name] = prefix
    return result
95 
96 
if __name__ == '__main__':
    # Script entry point: applies the given MVA methods to the test (and optionally training)
    # data, produces evaluation plots and collects them in a LaTeX report that is delivered
    # either as a compiled PDF or as a .zip archive.
    #
    # NOTE(review): this block comes from a listing that lost its indentation and several
    # source lines; the nesting below is reconstructed from the statements themselves and
    # the places where a line is missing are marked individually - confirm against the
    # repository version.

    import ROOT  # noqa
    # Configure PyROOT before it is used: ignore command line options, no GUI thread, batch mode
    ROOT.PyConfig.IgnoreCommandLineOptions = True
    ROOT.PyConfig.StartGuiThread = False
    ROOT.gROOT.SetBatch(True)

    # Remembered so that the original directory can be restored after the report is written
    old_cwd = os.getcwd()
    parser = get_argument_parser()
    args = parser.parse_args()

    # Appending options arrive as lists of lists, hence the flatten calls
    identifiers = flatten(args.identifiers)
    identifier_abbreviations = create_abbreviations(identifiers)

    datafiles = flatten(args.datafiles)
    if args.localdb is not None:
        for localdb in flatten(args.localdb):
            conditions.prepend_testing_payloads(localdb)

    if args.globaltag is not None:
        for tag in flatten(args.globaltag):
            conditions.prepend_globaltag(tag)

    print("Load methods")
    methods = [basf2_mva_util.Method(identifier) for identifier in identifiers]

    print("Apply experts on independent data")
    # Both dictionaries are keyed by the abbreviation of the method identifier
    test_probability = {}
    test_target = {}
    for method in methods:
        p, t = method.apply_expert(datafiles, args.treename)
        test_probability[identifier_abbreviations[method.identifier]] = p
        test_target[identifier_abbreviations[method.identifier]] = t

    print("Apply experts on training data")
    # Stay empty if no training files were given; the emptiness is used below to skip
    # the train/test comparison plots
    train_probability = {}
    train_target = {}
    if args.train_datafiles is not None:
        # NOTE(review): sum(..., []) flattens the list of lists exactly like flatten() above;
        # using flatten() here would be consistent with the other options
        train_datafiles = sum(args.train_datafiles, [])
        for method in methods:
            p, t = method.apply_expert(train_datafiles, args.treename)
            train_probability[identifier_abbreviations[method.identifier]] = p
            train_target[identifier_abbreviations[method.identifier]] = t

    # Union of the variables / spectators of all methods, order of first occurrence is kept
    variables = unique(v for method in methods for v in method.variables)
    variable_abbreviations = create_abbreviations(variables)
    root_variables = unique(v for method in methods for v in method.root_variables)

    spectators = unique(v for method in methods for v in method.spectators)
    spectator_abbreviations = create_abbreviations(spectators)
    root_spectators = unique(v for method in methods for v in method.root_spectators)

    print("Load variables array")
    rootchain = ROOT.TChain(args.treename)
    for datafile in datafiles:
        rootchain.Add(datafile)

    # The abbreviations are passed as third argument and are the keys used to access the
    # returned data below
    variables_data = basf2_mva_util.tree2dict(rootchain, root_variables, list(variable_abbreviations.values()))
    spectators_data = basf2_mva_util.tree2dict(rootchain, root_spectators, list(spectator_abbreviations.values()))

    if args.fillnan:
        # copy=False makes numpy replace nan and inf in place
        for column in variable_abbreviations.values():
            np.nan_to_num(variables_data[column], copy=False)

        for column in spectator_abbreviations.values():
            np.nan_to_num(spectators_data[column], copy=False)

    print("Create latex file")
    # Change working directory after experts run, because they might want to access
    # a localdb in the current working directory.
    # NOTE(review): the temporary directory is created even if a working directory was given.
    with tempfile.TemporaryDirectory() as tempdir:
        if args.working_directory == '':
            os.chdir(tempdir)
        else:
            os.chdir(args.working_directory)

        o = b2latex.LatexFile()
        o += b2latex.TitlePage(title='Automatic MVA Evaluation',
                               authors=[r'Thomas Keck\\ Moritz Gelb\\ Nils Braun'],
                               abstract='Evaluation plots',
                               add_table_of_contents=True).finish()

        o += b2latex.Section("Classifiers")
        o += b2latex.String(r"""
        This section contains the GeneralOptions and SpecificOptions of all classifiers represented by an XML tree.
        The same information can be retrieved using the basf2\_mva\_info tool.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of identifiers", "{name} & {abbr}", r"Identifier & Abbreviation")
        for identifier in identifiers:
            table.add(name=format.string(identifier), abbr=format.string(identifier_abbreviations[identifier]))
        o += table.finish()

        # One subsection per method, listing its description as XML
        for method in methods:
            o += b2latex.SubSection(format.string(method.identifier))
            o += b2latex.Listing(language='XML').add(method.description).finish()

        o += b2latex.Section("Variables")
        o += b2latex.String("""
        This section contains an overview of the importance and correlation of the variables used by the classifiers.
        And distribution plots of the variables on the independent dataset. The distributions are normed for signal and
        background separately, and only the region +- 3 sigma around the mean is shown.

        The importance scores shown are based on the variable importance as estimated by each MVA method internally.
        This means the variable with the lowest importance will have score 0, and the variable
        with the highest importance will have score 100. If the method does not provide such a ranking, all
        importances will be 0.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of variables", "{name} & {abbr}", r"Variable & Abbreviation")
        for v in variables:
            table.add(name=format.string(v), abbr=format.string(variable_abbreviations[v]))
        o += table.finish()

        o += b2latex.SubSection("Importance")
        graphics = b2latex.Graphics()
        p = plotting.Importance()
        # One importance array per method; variables a method does not know get 0.0
        p.add({identifier_abbreviations[i.identifier]: np.array([i.importances.get(v, 0.0) for v in variables]) for i in methods},
              identifier_abbreviations.values(), variable_abbreviations.values())
        p.finish()
        p.save('importance.pdf')
        graphics.add('importance.pdf', width=1.0)
        o += graphics.finish()

        o += b2latex.SubSection("Correlation")
        # Signal / background masks in the variable plots are taken from the target of the first method
        first_identifier_abbr = list(identifier_abbreviations.values())[0]
        graphics = b2latex.Graphics()
        # NOTE(review): one source line is missing from the listing here (its numbering skips a
        # line); as shown, p is still the Importance plot from above - presumably a new plot
        # object is created at this point, restore it from the repository.
        p.add(variables_data, variable_abbreviations.values(),
              test_target[first_identifier_abbr] == 1,
              test_target[first_identifier_abbr] == 0)
        p.finish()
        p.save('correlation_plot.pdf')
        graphics.add('correlation_plot.pdf', width=1.0)
        o += graphics.finish()

        # One distribution plot per variable, signal and background overlaid
        for v in variables:
            variable_abbr = variable_abbreviations[v]
            o += b2latex.SubSection(format.string(v))
            graphics = b2latex.Graphics()
            p = plotting.VerboseDistribution(normed=True, range_in_std=3)
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 0, label="Background")
            p.finish()
            # NOTE(review): hash() of a str is salted per interpreter run unless PYTHONHASHSEED is
            # fixed, so these file names differ from run to run (and may contain a minus sign).
            # This applies to every hash()-based file name below as well.
            p.save('variable_{}.pdf'.format(hash(v)))
            graphics.add('variable_{}.pdf'.format(hash(v)), width=1.0)
            o += graphics.finish()

        o += b2latex.Section("Classifier Plot")
        # NOTE(review): the adjacent literals are joined without separating spaces, the emitted
        # text reads "...of the", "data.The" and "curvein".
        o += b2latex.String("This section contains the receiver operating characteristics (ROC), purity projection, ..."
                            "of the classifiers on training and independent data."
                            "The legend of each plot contains the shortened identifier and the area under the ROC curve"
                            "in parenthesis.")

        o += b2latex.Section("ROC Plot")
        graphics = b2latex.Graphics()
        # NOTE(review): one source line is missing from the listing here; as shown, p is the last
        # VerboseDistribution of the variable loop - presumably a new plot object is created here.
        # In this loop 'identifier' holds the abbreviation, not the full identifier
        for identifier in identifier_abbreviations.values():
            p.add(test_probability, identifier, test_target[identifier] == 1, test_target[identifier] == 0)
        p.finish()
        p.axis.set_title("ROC Rejection Plot on independent data")
        p.save('roc_plot_test.pdf')
        graphics.add('roc_plot_test.pdf', width=1.0)
        o += graphics.finish()

        # Train versus test comparison per method, only if training data was given
        if train_probability:
            # NOTE(review): the index i provided by enumerate is not used in the visible lines
            for i, identifier in enumerate(identifiers):
                graphics = b2latex.Graphics()
                # NOTE(review): one source line is missing from the listing here; presumably a new
                # plot object is created for every identifier.
                identifier_abbr = identifier_abbreviations[identifier]
                p.add(train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
                      train_target[identifier_abbr] == 0, label='Train')
                p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                      test_target[identifier_abbr] == 0, label='Test')
                p.finish()
                p.axis.set_title(identifier)
                p.save('roc_test_{}.pdf'.format(hash(identifier)))
                graphics.add('roc_test_{}.pdf'.format(hash(identifier)), width=1.0)
                o += graphics.finish()

        o += b2latex.Section("Classification Results")

        for identifier in identifiers:
            identifier_abbr = identifier_abbreviations[identifier]
            o += b2latex.SubSection(format.string(identifier_abbr))
            graphics = b2latex.Graphics()
            # NOTE(review): one source line is missing from the listing here; the calls below use
            # p.sub_plots, so presumably a plot object with two sub plots is created at this point.
            # Sub plot 0 is filled with normed=True, sub plot 1 with normed=False
            p.add(0, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=True)
            p.sub_plots[0].axis.set_title("Classification result in test data for {identifier}".format(identifier=identifier))

            p.add(1, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=False)
            p.sub_plots[1].axis.set_title("Classification result in test data for {identifier}".format(identifier=identifier))
            p.finish()

            p.save('classification_result_{identifier}.pdf'.format(identifier=hash(identifier)))
            graphics.add('classification_result_{identifier}.pdf'.format(identifier=hash(identifier)), width=1)
            o += graphics.finish()

        o += b2latex.Section("Diagonal Plot")
        graphics = b2latex.Graphics()
        p = plotting.Diagonal()
        # All identifiers are added to the same Diagonal plot, which is saved once after the loop
        for identifier in identifiers:
            # NOTE(review): identifier_abbr is read here before it is assigned for the current
            # identifier, so the subsection title uses the value of the previous iteration (in the
            # first iteration the one left over from the loop above). The two lines look swapped.
            o += b2latex.SubSection(format.string(identifier_abbr))
            identifier_abbr = identifier_abbreviations[identifier]
            p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1, test_target[identifier_abbr] == 0)
        p.finish()
        p.axis.set_title("Diagonal plot on independent data")
        p.save('diagonal_plot_test.pdf')
        graphics.add('diagonal_plot_test.pdf', width=1.0)
        o += graphics.finish()

        if train_probability:
            o += b2latex.SubSection("Overtraining Plot")
            for identifier in identifiers:
                identifier_abbr = identifier_abbreviations[identifier]
                # Training and test sample are concatenated (training first); train_mask is 1 for
                # the training entries and 0 for the test entries
                probability = {identifier_abbr: np.r_[train_probability[identifier_abbr], test_probability[identifier_abbr]]}
                target = np.r_[train_target[identifier_abbr], test_target[identifier_abbr]]
                train_mask = np.r_[np.ones(len(train_target[identifier_abbr])), np.zeros(len(test_target[identifier_abbr]))]
                graphics = b2latex.Graphics()
                # NOTE(review): one source line is missing from the listing here; presumably a new
                # plot object is created for every identifier.
                p.add(probability, identifier_abbr,
                      train_mask == 1, train_mask == 0,
                      target == 1, target == 0, )
                p.finish()
                p.axis.set_title("Overtraining check for {}".format(identifier))
                p.save('overtraining_plot_{}.pdf'.format(hash(identifier)))
                graphics.add('overtraining_plot_{}.pdf'.format(hash(identifier)), width=1.0)
                o += graphics.finish()

        o += b2latex.Section("Spectators")
        # NOTE(review): the two literals are joined without a space, the emitted text reads
        # "on theclassifier outputs".
        o += b2latex.String("This section contains the distribution and dependence on the"
                            "classifier outputs of all spectator variables.")

        table = b2latex.LongTable(r"ll", "Abbreviations of spectators", "{name} & {abbr}", r"Spectator & Abbreviation")
        for s in spectators:
            table.add(name=format.string(s), abbr=format.string(spectator_abbreviations[s]))
        o += table.finish()

        for spectator in spectators:
            spectator_abbr = spectator_abbreviations[spectator]
            o += b2latex.SubSection(format.string(spectator))
            graphics = b2latex.Graphics()
            # NOTE(review): one source line is missing from the listing here; presumably a new
            # plot object is created for every spectator.
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 0, label="Background")
            p.finish()
            p.save('spectator_{}.pdf'.format(hash(spectator)))
            graphics.add('spectator_{}.pdf'.format(hash(spectator)), width=1.0)
            o += graphics.finish()

            # Dependence of the current spectator on the output of every classifier
            for identifier in identifiers:
                o += b2latex.SubSubSection(format.string(spectator) + " with classifier " + format.string(identifier))
                identifier_abbr = identifier_abbreviations[identifier]
                data = {identifier_abbr: test_probability[identifier_abbr], spectator_abbr: spectators_data[spectator_abbr]}
                graphics = b2latex.Graphics()
                # NOTE(review): one source line is missing from the listing here; presumably a new
                # plot object is created for every spectator / identifier pair.
                # The list passed below is 10, 20, ..., 90
                p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
                      test_target[identifier_abbr] == 1,
                      test_target[identifier_abbr] == 0)
                p.finish()
                p.save('correlation_plot_{}_{}.pdf'.format(hash(spectator), hash(identifier)))
                graphics.add('correlation_plot_{}_{}.pdf'.format(hash(spectator), hash(identifier)), width=1.0)
                o += graphics.finish()

        if args.compile:
            B2INFO(f"Creating a PDF file at {args.outputfile}. Please remove the '-c' switch if this fails.")
            o.save('latex.tex', compile=True)
        else:
            # NOTE(review): the two f-strings are joined without a space after the full stop
            B2INFO(f"Creating a .zip archive containing plots and a TeX file at {args.outputfile}."
                   f"Please unpack the archive and compile the latex.tex file with pdflatex.")
            o.save('latex.tex', compile=False)

        # Back to the original directory, so that a relative output file name refers to it.
        # NOTE(review): this line is not reached if anything above raises; the process then
        # stays in the working directory.
        os.chdir(old_cwd)
        if args.working_directory == '':
            working_directory = tempdir
        else:
            working_directory = args.working_directory

        if args.compile:
            # Only the PDF is delivered, under the requested output file name
            shutil.copy(os.path.join(working_directory, 'latex.pdf'), args.outputfile)
        else:
            # make_archive appends '.zip' itself, so the text after the last dot of the output
            # file name is cut off before
            base_name = os.path.join(old_cwd, args.outputfile.rsplit('.', 1)[0])
            shutil.make_archive(base_name, 'zip', working_directory)
def tree2dict(tree, tree_columns, dict_columns=None)