Belle II Software  light-2205-abys
basf2_mva_evaluate.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2_mva_util
13 
14 from basf2_mva_evaluation import plotting
15 from basf2 import conditions
16 import argparse
17 import tempfile
18 
19 import numpy as np
20 from B2Tools import b2latex, format
21 from basf2 import B2INFO
22 
23 import os
24 import shutil
25 import collections
26 from typing import List, Any
27 
28 
def get_argument_parser() -> argparse.ArgumentParser:
    """
    Creates the command line parser of the MVA evaluation tool.

    Only the parser is built here; call ``parse_args()`` on the returned object to obtain
    the options. The options declared with ``action='append'`` and ``nargs='+'``
    (identifiers, datafiles, train_datafiles, localdb, globaltag) are parsed into a list of
    lists, one inner list per occurrence of the flag, and are None when an optional one of
    them is not given.
    @return argparse.ArgumentParser configured with all options of this tool
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=False,
                        action='append', nargs='+',
                        help='Data file containing ROOT TTree used during training')
    parser.add_argument('-data', '--datafiles', dest='datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.zip',
                        help='Name of the created .zip archive file if not compiling or a pdf file if compilation is successful.')
    parser.add_argument('-w', '--working_directory', dest='working_directory', type=str, default='',
                        help="""Working directory where the created images and root files are stored,
                              default is to create a temporary directory.""")
    parser.add_argument('-l', '--localdb', dest='localdb', type=str, action='append', nargs='+', required=False,
                        help="""path or list of paths to local database(s) containing the mvas of interest.
                              The testing payloads are prepended and take precedence over payloads in global tags.""")
    parser.add_argument('-g', '--globaltag', dest='globaltag', type=str, action='append', nargs='+', required=False,
                        help='globaltag or list of globaltags containing the mvas of interest. The globaltags are prepended.')
    parser.add_argument('-n', '--fillnan', dest='fillnan', action='store_true',
                        help='Fill nan and inf values with actual numbers')
    parser.add_argument('-c', '--compile', dest='compile', action='store_true',
                        help='Compile latex to pdf directly')
    return parser
54 
55 
56 def unique(input_list: List[Any]) -> List[Any]:
57  """
58  Returns a list containing only unique elements, keeps the original order of the list
59  @param input_list list containing the elements
60  """
61  output = []
62  for x in input_list:
63  if x not in output:
64  output.append(x)
65  return output
66 
67 
def flatten(input_list: List[List[Any]]) -> List[Any]:
    """
    Concatenates the inner sequences of a list of lists into a single flat list.

    Exactly one level of nesting is removed; element order is preserved.
    @param input_list list of lists to be flattened
    """
    flat = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
74 
75 
def create_abbreviations(names, length=5):
    """
    Maps every name to a short abbreviation, used as plot label and data column key.

    The abbreviation is the first ``length`` characters of the name. If several names share
    the same prefix, each of them gets a running number (starting at 1, in input order)
    appended, so distinct names always receive distinct abbreviations.
    @param names sequence of names (identifiers, variables, ...); it is iterated twice, so a
        one-shot generator is not supported
    @param length number of leading characters kept in the abbreviation
    @return collections.OrderedDict mapping each name to its abbreviation, in input order
    """
    prefix_count = collections.Counter(name[:length] for name in names)
    numbered = collections.Counter()  # running number per shared prefix
    abbreviations = collections.OrderedDict()
    for name in names:
        prefix = name[:length]
        if prefix_count[prefix] > 1:
            numbered[prefix] += 1
            abbreviations[name] = prefix + str(numbered[prefix])
        else:
            abbreviations[name] = prefix
    return abbreviations
95 
96 
if __name__ == '__main__':

    # Importing ROOT only here keeps a plain `import` of this module free of ROOT;
    # PyROOT is configured for non-interactive batch use.
    import ROOT  # noqa
    ROOT.PyConfig.IgnoreCommandLineOptions = True
    ROOT.PyConfig.StartGuiThread = False
    ROOT.gROOT.SetBatch(True)

    # Remember the starting directory: the report is built elsewhere and the output file
    # is written relative to this one.
    old_cwd = os.getcwd()
    args = get_argument_parser().parse_args()

    identifiers = flatten(args.identifiers)
    identifier_abbreviations = create_abbreviations(identifiers)
    datafiles = flatten(args.datafiles)

    # Register the user-supplied conditions sources before any method is loaded.
    for localdb in flatten(args.localdb or []):
        conditions.prepend_testing_payloads(localdb)
    for tag in flatten(args.globaltag or []):
        conditions.prepend_globaltag(tag)

    print("Load methods")
    methods = list(map(basf2_mva_util.Method, identifiers))
122 
123  print("Apply experts on independent data")
124  test_probability = {}
125  test_target = {}
126  for method in methods:
127  p, t = method.apply_expert(datafiles, args.treename)
128  test_probability[identifier_abbreviations[method.identifier]] = p
129  test_target[identifier_abbreviations[method.identifier]] = t
130 
131  print("Apply experts on training data")
132  train_probability = {}
133  train_target = {}
134  if args.train_datafiles is not None:
135  train_datafiles = sum(args.train_datafiles, [])
136  for method in methods:
137  p, t = method.apply_expert(train_datafiles, args.treename)
138  train_probability[identifier_abbreviations[method.identifier]] = p
139  train_target[identifier_abbreviations[method.identifier]] = t
140 
141  variables = unique(v for method in methods for v in method.variables)
142  variable_abbreviations = create_abbreviations(variables)
143  root_variables = unique(v for method in methods for v in method.root_variables)
144 
145  spectators = unique(v for method in methods for v in method.spectators)
146  spectator_abbreviations = create_abbreviations(spectators)
147  root_spectators = unique(v for method in methods for v in method.root_spectators)
148 
149  print("Load variables array")
150  rootchain = ROOT.TChain(args.treename)
151  for datafile in datafiles:
152  rootchain.Add(datafile)
153 
154  variables_data = basf2_mva_util.tree2dict(rootchain, root_variables, list(variable_abbreviations.values()))
155  spectators_data = basf2_mva_util.tree2dict(rootchain, root_spectators, list(spectator_abbreviations.values()))
156 
157  if args.fillnan:
158  for column in variable_abbreviations.values():
159  np.nan_to_num(variables_data[column], copy=False)
160 
161  for column in spectator_abbreviations.values():
162  np.nan_to_num(spectators_data[column], copy=False)
163 
164  print("Create latex file")
165  # Change working directory after experts run, because they might want to access
166  # a localdb in the current working directory.
167  with tempfile.TemporaryDirectory() as tempdir:
168  if args.working_directory == '':
169  os.chdir(tempdir)
170  else:
171  os.chdir(args.working_directory)
172 
173  o = b2latex.LatexFile()
174  o += b2latex.TitlePage(title='Automatic MVA Evaluation',
175  authors=[r'Thomas Keck\\ Moritz Gelb\\ Nils Braun'],
176  abstract='Evaluation plots',
177  add_table_of_contents=True).finish()
178 
179  o += b2latex.Section("Classifiers")
180  o += b2latex.String(r"""
181  This section contains the GeneralOptions and SpecificOptions of all classifiers represented by an XML tree.
182  The same information can be retrieved using the basf2\_mva\_info tool.
183  """)
184 
185  table = b2latex.LongTable(r"ll", "Abbreviations of identifiers", "{name} & {abbr}", r"Identifier & Abbreviation")
186  for identifier in identifiers:
187  table.add(name=format.string(identifier), abbr=format.string(identifier_abbreviations[identifier]))
188  o += table.finish()
189 
190  for method in methods:
191  o += b2latex.SubSection(format.string(method.identifier))
192  o += b2latex.Listing(language='XML').add(method.description).finish()
193 
194  o += b2latex.Section("Variables")
195  o += b2latex.String("""
196  This section contains an overview of the importance and correlation of the variables used by the classifiers.
197  And distribution plots of the variables on the independent dataset. The distributions are normed for signal and
198  background separately, and only the region +- 3 sigma around the mean is shown.
199  """)
200 
201  table = b2latex.LongTable(r"ll", "Abbreviations of variables", "{name} & {abbr}", r"Variable & Abbreviation")
202  for v in variables:
203  table.add(name=format.string(v), abbr=format.string(variable_abbreviations[v]))
204  o += table.finish()
205 
206  o += b2latex.SubSection("Importance")
207  graphics = b2latex.Graphics()
208  p = plotting.Importance()
209  p.add({identifier_abbreviations[i.identifier]: np.array([i.importances.get(v, 0.0) for v in variables]) for i in methods},
210  identifier_abbreviations.values(), variable_abbreviations.values())
211  p.finish()
212  p.save('importance.pdf')
213  graphics.add('importance.pdf', width=1.0)
214  o += graphics.finish()
215 
# Correlation subsection: one figure of the correlations between all variables, with the
# signal/background masks taken from the targets of the first classifier.
216  o += b2latex.SubSection("Correlation")
# The first classifier's targets are also reused as signal/background masks by the variable
# and spectator distribution plots further down.
# NOTE(review): this assumes all classifiers share the same target definition — confirm.
217  first_identifier_abbr = list(identifier_abbreviations.values())[0]
218  graphics = b2latex.Graphics()
# NOTE(review): original line 219, which creates the plot object `p` used below, is missing
# from this extract; the add(data, columns, signal_mask, background_mask) call suggests
# `p = plotting.CorrelationMatrix()` — confirm against the upstream file. Without that line
# `p` would still be the importance plot created above.
220  p.add(variables_data, variable_abbreviations.values(),
221  test_target[first_identifier_abbr] == 1,
222  test_target[first_identifier_abbr] == 0)
223  p.finish()
224  p.save('correlation_plot.pdf')
225  graphics.add('correlation_plot.pdf', width=1.0)
226  o += graphics.finish()
227 
228  for v in variables:
229  variable_abbr = variable_abbreviations[v]
230  o += b2latex.SubSection(format.string(v))
231  graphics = b2latex.Graphics()
232  p = plotting.VerboseDistribution(normed=True, range_in_std=3)
233  p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
234  p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 0, label="Background")
235  p.finish()
236  p.save('variable_{}.pdf'.format(hash(v)))
237  graphics.add('variable_{}.pdf'.format(hash(v)), width=1.0)
238  o += graphics.finish()
239 
240  o += b2latex.Section("Classifier Plot")
241  o += b2latex.String("This section contains the receiver operating characteristics (ROC), purity projection, ..."
242  "of the classifiers on training and independent data."
243  "The legend of each plot contains the shortened identifier and the area under the ROC curve"
244  "in parenthesis.")
245 
# ROC section: one figure (titled "ROC Rejection Plot on independent data") comparing all
# classifiers, followed — only if training data was evaluated — by one figure per classifier
# overlaying its 'Train' and 'Test' curves.
# NOTE(review): this is a top-level Section, so the preceding "Classifier Plot" section only
# holds its introductory text; a SubSection may have been intended.
246  o += b2latex.Section("ROC Plot")
247  graphics = b2latex.Graphics()
# NOTE(review): original line 248, which creates the plot object `p`, is missing from this
# extract — confirm the plotter class against the upstream file.
# Here `identifier` runs over the abbreviations, i.e. the keys of test_probability/test_target.
249  for identifier in identifier_abbreviations.values():
250  p.add(test_probability, identifier, test_target[identifier] == 1, test_target[identifier] == 0)
251  p.finish()
252  p.axis.set_title("ROC Rejection Plot on independent data")
253  p.save('roc_plot_test.pdf')
254  graphics.add('roc_plot_test.pdf', width=1.0)
255  o += graphics.finish()
256 
257  if train_probability:
# NOTE(review): the loop index `i` is unused.
258  for i, identifier in enumerate(identifiers):
259  graphics = b2latex.Graphics()
# NOTE(review): original line 260 (presumably a fresh plot object `p` per classifier) is
# missing from this extract — confirm against the upstream file.
261  identifier_abbr = identifier_abbreviations[identifier]
262  p.add(train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
263  train_target[identifier_abbr] == 0, label='Train')
264  p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
265  test_target[identifier_abbr] == 0, label='Test')
266  p.finish()
267  p.axis.set_title(identifier)
268  p.save('roc_test_{}.pdf'.format(hash(identifier)))
269  graphics.add('roc_test_{}.pdf'.format(hash(identifier)), width=1.0)
270  o += graphics.finish()
271 
# Classification Results: per classifier, a two-panel figure of its output on the independent
# data — panel 0 added with normed=True, panel 1 with normed=False.
272  o += b2latex.Section("Classification Results")
273 
274  for identifier in identifiers:
275  identifier_abbr = identifier_abbreviations[identifier]
276  o += b2latex.SubSection(format.string(identifier_abbr))
277  graphics = b2latex.Graphics()
# NOTE(review): original line 278, which creates the multi-panel plot `p` (it exposes
# p.sub_plots[0] and p.sub_plots[1]), is missing from this extract — confirm upstream.
279  p.add(0, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
280  test_target[identifier_abbr] == 0, normed=True)
# NOTE(review): both panels get the identical title, so the normed/absolute difference is not
# visible from the titles.
281  p.sub_plots[0].axis.set_title("Classification result in test data for {identifier}".format(identifier=identifier))
282 
283  p.add(1, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
284  test_target[identifier_abbr] == 0, normed=False)
285  p.sub_plots[1].axis.set_title("Classification result in test data for {identifier}".format(identifier=identifier))
286  p.finish()
287 
288  p.save('classification_result_{identifier}.pdf'.format(identifier=hash(identifier)))
289  graphics.add('classification_result_{identifier}.pdf'.format(identifier=hash(identifier)), width=1)
290  o += graphics.finish()
291 
292  o += b2latex.Section("Diagonal Plot")
293  graphics = b2latex.Graphics()
294  p = plotting.Diagonal()
295  for identifier in identifiers:
296  o += b2latex.SubSection(format.string(identifier_abbr))
297  identifier_abbr = identifier_abbreviations[identifier]
298  p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1, test_target[identifier_abbr] == 0)
299  p.finish()
300  p.axis.set_title("Diagonal plot on independent data")
301  p.save('diagonal_plot_test.pdf')
302  graphics.add('diagonal_plot_test.pdf', width=1.0)
303  o += graphics.finish()
304 
# Overtraining check (only when training data was evaluated): per classifier, the training and
# test outputs are concatenated into one array; train_mask is 1 for the training entries (which
# come first) and 0 for the test entries, and `target` selects signal/background.
# NOTE(review): this is a SubSection, so it lands under the preceding "Diagonal Plot" section;
# a Section of its own may have been intended.
305  if train_probability:
306  o += b2latex.SubSection("Overtraining Plot")
307  for identifier in identifiers:
308  identifier_abbr = identifier_abbreviations[identifier]
309  probability = {identifier_abbr: np.r_[train_probability[identifier_abbr], test_probability[identifier_abbr]]}
310  target = np.r_[train_target[identifier_abbr], test_target[identifier_abbr]]
311  train_mask = np.r_[np.ones(len(train_target[identifier_abbr])), np.zeros(len(test_target[identifier_abbr]))]
312  graphics = b2latex.Graphics()
# NOTE(review): original line 313, which creates the plot object `p`, is missing from this
# extract — confirm against the upstream file.
314  p.add(probability, identifier_abbr,
315  train_mask == 1, train_mask == 0,
316  target == 1, target == 0, )
317  p.finish()
318  p.axis.set_title("Overtraining check for {}".format(identifier))
319  p.save('overtraining_plot_{}.pdf'.format(hash(identifier)))
320  graphics.add('overtraining_plot_{}.pdf'.format(hash(identifier)), width=1.0)
321  o += graphics.finish()
322 
323  o += b2latex.Section("Spectators")
324  o += b2latex.String("This section contains the distribution and dependence on the"
325  "classifier outputs of all spectator variables.")
326 
327  table = b2latex.LongTable(r"ll", "Abbreviations of spectators", "{name} & {abbr}", r"Spectator & Abbreviation")
328  for s in spectators:
329  table.add(name=format.string(s), abbr=format.string(spectator_abbreviations[s]))
330  o += table.finish()
331 
# Per spectator: a signal/background distribution (masks from the first classifier's targets),
# then one sub-subsection per classifier relating the spectator to that classifier's output.
332  for spectator in spectators:
333  spectator_abbr = spectator_abbreviations[spectator]
334  o += b2latex.SubSection(format.string(spectator))
335  graphics = b2latex.Graphics()
# NOTE(review): original line 336, which creates the distribution plot `p`, is missing from this
# extract — confirm against the upstream file.
337  p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
338  p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 0, label="Background")
339  p.finish()
340  p.save('spectator_{}.pdf'.format(hash(spectator)))
341  graphics.add('spectator_{}.pdf'.format(hash(spectator)), width=1.0)
342  o += graphics.finish()
343 
344  for identifier in identifiers:
345  o += b2latex.SubSubSection(format.string(spectator) + " with classifier " + format.string(identifier))
346  identifier_abbr = identifier_abbreviations[identifier]
# NOTE(review): if the identifier abbreviation equals the spectator abbreviation, the second key
# overwrites the first and the classifier output is lost from `data`.
347  data = {identifier_abbr: test_probability[identifier_abbr], spectator_abbr: spectators_data[spectator_abbr]}
348  graphics = b2latex.Graphics()
# NOTE(review): original line 349, which creates the plot object `p`, is missing from this
# extract — confirm against the upstream file.
# The 4th argument is the list [10, 20, ..., 90]; presumably quantiles of the classifier output
# — confirm in the plotter's documentation.
350  p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
351  test_target[identifier_abbr] == 1,
352  test_target[identifier_abbr] == 0)
353  p.finish()
354  p.save('correlation_plot_{}_{}.pdf'.format(hash(spectator), hash(identifier)))
355  graphics.add('correlation_plot_{}_{}.pdf'.format(hash(spectator), hash(identifier)), width=1.0)
356  o += graphics.finish()
357 
358  if args.compile:
359  B2INFO(f"Creating a PDF file at {args.outputfile}. Please remove the '-c' switch if this fails.")
360  o.save('latex.tex', compile=True)
361  else:
362  B2INFO(f"Creating a .zip archive containing plots and a TeX file at {args.outputfile}."
363  f"Please unpack the archive and compile the latex.tex file with pdflatex.")
364  o.save('latex.tex', compile=False)
365 
366  os.chdir(old_cwd)
367  if args.working_directory == '':
368  working_directory = tempdir
369  else:
370  working_directory = args.working_directory
371 
372  if args.compile:
373  shutil.copy(os.path.join(working_directory, 'latex.pdf'), args.outputfile)
374  else:
375  base_name = os.path.join(old_cwd, args.outputfile.rsplit('.', 1)[0])
376  shutil.make_archive(base_name, 'zip', working_directory)
def tree2dict(tree, tree_columns, dict_columns=None)