Belle II Software  release-08-01-10
basf2_mva_evaluate.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import basf2_mva_util
12 
13 from basf2_mva_evaluation import plotting
14 from basf2 import conditions
15 import argparse
16 import tempfile
17 
18 import numpy as np
19 from B2Tools import b2latex, format
20 from basf2 import B2INFO
21 
22 import os
23 import shutil
24 import collections
25 from typing import List, Any
26 
27 
def get_argument_parser() -> argparse.ArgumentParser:
    """
    Builds the command line parser of the MVA evaluation script.

    The options that accept several values (identifiers, data files, local databases and
    global tags) combine action='append' with nargs='+'. Their parsed value is therefore a
    list of lists (one inner list per occurrence of the option) which the caller has to
    flatten; the optional ones among them are None when they are not given.

    @return the configured parser, parsing the arguments is left to the caller
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=False, action='append',
                        nargs='+', help='Data file containing ROOT TTree used during training')
    parser.add_argument('-data', '--datafiles', dest='datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.zip',
                        help='Name of the created .zip archive file if not compiling or a pdf file if compilation is successful.')
    # The default argparse help formatter collapses runs of whitespace, so the multi-line
    # help texts below are written as plain single-spaced strings.
    parser.add_argument('-w', '--working_directory', dest='working_directory', type=str, default='',
                        help='Working directory where the created images and root files are stored, '
                             'default is to create a temporary directory.')
    parser.add_argument('-l', '--localdb', dest='localdb', type=str, action='append', nargs='+', required=False,
                        help='path or list of paths to local database(s) containing the mvas of interest. '
                             'The testing payloads are prepended and take precedence over payloads in global tags.')
    parser.add_argument('-g', '--globaltag', dest='globaltag', type=str, action='append', nargs='+', required=False,
                        help='globaltag or list of globaltags containing the mvas of interest. The globaltags are prepended.')
    parser.add_argument('-n', '--fillnan', dest='fillnan', action='store_true',
                        help='Fill nan and inf values with actual numbers')
    parser.add_argument('-c', '--compile', dest='compile', action='store_true',
                        help='Compile latex to pdf directly')
    parser.add_argument('-a', '--abbreviation_length', dest='abbreviation_length',
                        action='store', type=int, default=5,
                        help='Number of characters to which variable names are abbreviated.')
    return parser
56 
57 
def unique(input_list: List[Any]) -> List[Any]:
    """
    Removes duplicates from a sequence while preserving the order of first occurrence
    @param input_list elements to be filtered, compared by equality
    @return new list in which every element appears exactly once
    """
    seen: List[Any] = []
    for element in input_list:
        if element in seen:
            continue
        seen.append(element)
    return seen
68 
69 
def flatten(input_list: List[List[Any]]) -> List[Any]:
    """
    Concatenates the elements of all inner lists into one flat list (one level only)
    @param input_list list whose entries are themselves iterable
    @return new list holding the items of every inner list in their original order
    """
    flat: List[Any] = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
76 
77 
def create_abbreviations(names, length=5):
    """
    Maps every name to an abbreviation consisting of its first characters.
    If several names share the same prefix, a running number (starting at 1, in the order
    of the names) is appended to the prefix of each of them.
    @param names names to be abbreviated, iterated over twice
    @param length number of leading characters kept from each name
    @return collections.OrderedDict mapping name to abbreviation, in the order of names
    """
    prefix_totals = collections.Counter(name[:length] for name in names)
    running_index = collections.Counter()

    abbreviations = collections.OrderedDict()
    for name in names:
        prefix = name[:length]
        if prefix_totals[prefix] > 1:
            # Ambiguous prefix: number the names sharing it consecutively
            running_index[prefix] += 1
            abbreviations[name] = prefix + str(running_index[prefix])
        else:
            abbreviations[name] = prefix
    return abbreviations
97 
98 
# Script entry point: applies every requested MVA method to the independent test data
# (and optionally to the training data), creates the evaluation plots and collects them
# in a LaTeX report which is either compiled to a PDF or packed into a .zip archive.
99 if __name__ == '__main__':
100 
# ROOT is imported only when run as a script and is configured to ignore the command
# line options, to not start a GUI thread and to run in batch mode.
101  import ROOT # noqa
102  ROOT.PyConfig.IgnoreCommandLineOptions = True
103  ROOT.PyConfig.StartGuiThread = False
104  ROOT.gROOT.SetBatch(True)
105 
# Remember the start directory; the script changes into a working directory further
# below and returns here before the output file is written.
106  old_cwd = os.getcwd()
107  parser = get_argument_parser()
108  args = parser.parse_args()
109 
# The multi-valued options are parsed as lists of lists (action='append' with nargs='+'
# in get_argument_parser), hence the flatten() calls.
110  identifiers = flatten(args.identifiers)
111  identifier_abbreviations = create_abbreviations(identifiers, args.abbreviation_length)
112 
113  datafiles = flatten(args.datafiles)
# Register the optional local databases and global tags before the methods are loaded.
114  if args.localdb is not None:
115  for localdb in flatten(args.localdb):
116  conditions.prepend_testing_payloads(localdb)
117 
118  if args.globaltag is not None:
119  for tag in flatten(args.globaltag):
120  conditions.prepend_globaltag(tag)
121 
122  print("Load methods")
123  methods = [basf2_mva_util.Method(identifier) for identifier in identifiers]
124 
125  print("Apply experts on independent data")
# Expert outputs and targets are stored per method, keyed by the abbreviated identifier.
126  test_probability = {}
127  test_target = {}
128  for method in methods:
129  p, t = method.apply_expert(datafiles, args.treename)
130  test_probability[identifier_abbreviations[method.identifier]] = p
131  test_target[identifier_abbreviations[method.identifier]] = t
132 
133  print("Apply experts on training data")
# These dictionaries stay empty when no training files are given; the train/test
# comparison plots below are guarded by 'if train_probability:'.
134  train_probability = {}
135  train_target = {}
136  if args.train_datafiles is not None:
# NOTE(review): sum(..., []) flattens the list of lists just like flatten() does for the
# other options; consider using flatten() here for consistency.
137  train_datafiles = sum(args.train_datafiles, [])
138  for method in methods:
139  p, t = method.apply_expert(train_datafiles, args.treename)
140  train_probability[identifier_abbreviations[method.identifier]] = p
141  train_target[identifier_abbreviations[method.identifier]] = t
142 
# Collect the variables and spectators of all methods without duplicates, keeping the
# order of first occurrence (see unique()).
# NOTE(review): root_variables / root_spectators are presumably the ROOT branch names
# belonging to the same variables - confirm in basf2_mva_util.Method.
143  variables = unique(v for method in methods for v in method.variables)
144  variable_abbreviations = create_abbreviations(variables, args.abbreviation_length)
145  root_variables = unique(v for method in methods for v in method.root_variables)
146 
147  spectators = unique(v for method in methods for v in method.spectators)
148  spectator_abbreviations = create_abbreviations(spectators, args.abbreviation_length)
149  root_spectators = unique(v for method in methods for v in method.root_spectators)
150 
151  print("Load variables array")
# Chain all independent data files into a single TChain.
152  rootchain = ROOT.TChain(args.treename)
153  for datafile in datafiles:
154  rootchain.Add(datafile)
155 
# The abbreviations are passed as dictionary columns; the data is addressed by
# abbreviation everywhere below.
156  variables_data = basf2_mva_util.chain2dict(rootchain, root_variables, list(variable_abbreviations.values()))
157  spectators_data = basf2_mva_util.chain2dict(rootchain, root_spectators, list(spectator_abbreviations.values()))
158 
# With copy=False numpy.nan_to_num replaces nan/inf in the given array itself, which is
# why the return value is not used.
# NOTE(review): assumes the columns are numpy floating point arrays - confirm in chain2dict.
159  if args.fillnan:
160  for column in variable_abbreviations.values():
161  np.nan_to_num(variables_data[column], copy=False)
162 
163  for column in spectator_abbreviations.values():
164  np.nan_to_num(spectators_data[column], copy=False)
165 
166  print("Create latex file")
167  # Change working directory after experts run, because they might want to access
168  # a localdb in the current working directory.
# NOTE(review): the temporary directory is created even if a working directory was
# given on the command line; it simply stays unused in that case.
169  with tempfile.TemporaryDirectory() as tempdir:
170  if args.working_directory == '':
171  os.chdir(tempdir)
172  else:
173  os.chdir(args.working_directory)
174 
# Plain text overview of all abbreviations, written into the working directory.
175  with open('abbreviations.txt', 'w') as f:
176  f.write('Identifier Abbreviation : Identifier \n')
177  for name, abbrev in identifier_abbreviations.items():
178  f.write(f'\t{abbrev} : {name}\n')
179  f.write('\n\n\nVariable Abbreviation : Variable \n')
180  for name, abbrev in variable_abbreviations.items():
181  f.write(f'\t{abbrev} : {name}\n')
182  f.write('\n\n\nSpectator Abbreviation : Spectator \n')
183  for name, abbrev in spectator_abbreviations.items():
184  f.write(f'\t{abbrev} : {name}\n')
185 
# The report is accumulated in 'o'; every part is appended with '+='.
186  o = b2latex.LatexFile()
187  o += b2latex.TitlePage(title='Automatic MVA Evaluation',
188  authors=[r'Thomas Keck\\ Moritz Gelb\\ Nils Braun'],
189  abstract='Evaluation plots',
190  add_table_of_contents=True).finish()
191 
192  o += b2latex.Section("Classifiers")
193  o += b2latex.String(r"""
194  This section contains the GeneralOptions and SpecificOptions of all classifiers represented by an XML tree.
195  The same information can be retrieved using the basf2\_mva\_info tool.
196  """)
197 
198  table = b2latex.LongTable(r"ll", "Abbreviations of identifiers", "{name} & {abbr}", r"Identifier & Abbreviation")
199  for identifier in identifiers:
200  table.add(name=format.string(identifier), abbr=format.string(identifier_abbreviations[identifier]))
201  o += table.finish()
202 
# One subsection per method containing its description as an XML listing.
203  for method in methods:
204  o += b2latex.SubSection(format.string(method.identifier))
205  o += b2latex.Listing(language='XML').add(method.description).finish()
206 
207  o += b2latex.Section("Variables")
208  o += b2latex.String("""
209  This section contains an overview of the importance and correlation of the variables used by the classifiers.
210  And distribution plots of the variables on the independent dataset. The distributions are normed for signal and
211  background separately, and only the region +- 3 sigma around the mean is shown.
212 
213  The importance scores shown are based on the variable importance as estimated by each MVA method internally.
214  This means the variable with the lowest importance will have score 0, and the variable
215  with the highest importance will have score 100. If the method does not provide such a ranking, all
216  importances will be 0.
217  """)
218 
219  table = b2latex.LongTable(r"ll", "Abbreviations of variables", "{name} & {abbr}", r"Variable & Abbreviation")
220  for v in variables:
221  table.add(name=format.string(v), abbr=format.string(variable_abbreviations[v]))
222  o += table.finish()
223 
224  o += b2latex.SubSection("Importance")
225  graphics = b2latex.Graphics()
226  p = plotting.Importance()
# Variables a method does not provide an importance for are filled with 0.0.
227  p.add({identifier_abbreviations[i.identifier]: np.array([i.importances.get(v, 0.0) for v in variables]) for i in methods},
228  identifier_abbreviations.values(), variable_abbreviations.values())
229  p.finish()
230  p.save('importance.pdf')
231  graphics.add('importance.pdf', width=1.0)
232  o += graphics.finish()
233 
234  o += b2latex.SubSection("Correlation")
# Signal and background masks of all variable and spectator plots are taken from the
# target of the first identifier only.
235  first_identifier_abbr = list(identifier_abbreviations.values())[0]
236  graphics = b2latex.Graphics()
# NOTE(review): the listing numbers jump from 236 to 238, so one statement is missing
# from this view - presumably the creation of a new plot object 'p' (the previous one
# was already finished and saved). Confirm against the real file.
238  p.add(variables_data, variable_abbreviations.values(),
239  test_target[first_identifier_abbr] == 1,
240  test_target[first_identifier_abbr] == 0)
241  p.finish()
242  p.save('correlation_plot.pdf')
243  graphics.add('correlation_plot.pdf', width=1.0)
244  o += graphics.finish()
245 
# One distribution plot per variable, signal and background drawn separately.
# NOTE(review): hash() of a str differs between interpreter runs unless PYTHONHASHSEED
# is fixed, so the file names built with hash() are only stable within a single run.
246  for v in variables:
247  variable_abbr = variable_abbreviations[v]
248  o += b2latex.SubSection(format.string(v))
249  graphics = b2latex.Graphics()
250  p = plotting.VerboseDistribution(normed=True, range_in_std=3)
251  p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
252  p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 0, label="Background")
253  p.finish()
254  p.save(f'variable_{hash(v)}.pdf')
255  graphics.add(f'variable_{hash(v)}.pdf', width=1.0)
256  o += graphics.finish()
257 
258  o += b2latex.Section("Classifier Plot")
# NOTE(review): the adjacent string literals below are joined without separating spaces
# ("...of the", "data.The", "curvein"); the rendered text is missing blanks.
259  o += b2latex.String("This section contains the receiver operating characteristics (ROC), purity projection, ..."
260  "of the classifiers on training and independent data."
261  "The legend of each plot contains the shortened identifier and the area under the ROC curve"
262  "in parenthesis.")
263 
264  o += b2latex.Section("ROC Plot")
265  graphics = b2latex.Graphics()
# NOTE(review): listing line 266 is missing from this view - presumably the creation of
# the plot object 'p' shared by all classifiers in the loop below. Confirm.
# In this loop 'identifier' iterates over the abbreviations, not the full identifiers.
267  for identifier in identifier_abbreviations.values():
268  p.add(test_probability, identifier, test_target[identifier] == 1, test_target[identifier] == 0)
269  p.finish()
270  p.axis.set_title("ROC Rejection Plot on independent data")
271  p.save('roc_plot_test.pdf')
272  graphics.add('roc_plot_test.pdf', width=1.0)
273  o += graphics.finish()
274 
# Train versus test comparison per classifier, only if training data was provided.
275  if train_probability:
# NOTE(review): the index 'i' from enumerate() is not used in the visible loop body.
276  for i, identifier in enumerate(identifiers):
277  graphics = b2latex.Graphics()
# NOTE(review): listing line 278 is missing from this view - presumably the creation of
# a new plot object 'p' for this classifier. Confirm.
279  identifier_abbr = identifier_abbreviations[identifier]
280  p.add(train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
281  train_target[identifier_abbr] == 0, label='Train')
282  p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
283  test_target[identifier_abbr] == 0, label='Test')
284  p.finish()
285  p.axis.set_title(identifier)
286  p.save(f'roc_test_{hash(identifier)}.pdf')
287  graphics.add(f'roc_test_{hash(identifier)}.pdf', width=1.0)
288  o += graphics.finish()
289 
290  o += b2latex.Section("Classification Results")
291 
292  for identifier in identifiers:
293  identifier_abbr = identifier_abbreviations[identifier]
294  o += b2latex.SubSection(format.string(identifier_abbr))
295  graphics = b2latex.Graphics()
# NOTE(review): listing line 296 is missing from this view - presumably the creation of
# a plot object with two sub plots (p.add() receives the sub plot index first and
# p.sub_plots is indexed below). Confirm.
# Sub plot 0 is filled with normed=True, sub plot 1 with normed=False.
297  p.add(0, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
298  test_target[identifier_abbr] == 0, normed=True)
299  p.sub_plots[0].axis.set_title(f"Classification result in test data for {identifier}")
300 
301  p.add(1, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
302  test_target[identifier_abbr] == 0, normed=False)
303  p.sub_plots[1].axis.set_title(f"Classification result in test data for {identifier}")
304  p.finish()
305 
306  p.save(f'classification_result_{hash(identifier)}.pdf')
307  graphics.add(f'classification_result_{hash(identifier)}.pdf', width=1)
308  o += graphics.finish()
309 
310  o += b2latex.Section("Diagonal Plot")
311  graphics = b2latex.Graphics()
312  p = plotting.Diagonal()
# All classifiers are added to the same diagonal plot, which is saved once after the loop.
313  for identifier in identifiers:
# NOTE(review): identifier_abbr is read here before it is reassigned for the current
# identifier, so the subsection title uses the value left over from the previous
# iteration (or from the preceding loop). The two statements look swapped - confirm.
314  o += b2latex.SubSection(format.string(identifier_abbr))
315  identifier_abbr = identifier_abbreviations[identifier]
316  p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1, test_target[identifier_abbr] == 0)
317  p.finish()
318  p.axis.set_title("Diagonal plot on independent data")
319  p.save('diagonal_plot_test.pdf')
320  graphics.add('diagonal_plot_test.pdf', width=1.0)
321  o += graphics.finish()
322 
323  if train_probability:
324  o += b2latex.SubSection("Overtraining Plot")
325  for identifier in identifiers:
326  identifier_abbr = identifier_abbreviations[identifier]
# Training and test samples are concatenated (training first); train_mask is 1 for the
# training entries and 0 for the test entries.
327  probability = {identifier_abbr: np.r_[train_probability[identifier_abbr], test_probability[identifier_abbr]]}
328  target = np.r_[train_target[identifier_abbr], test_target[identifier_abbr]]
329  train_mask = np.r_[np.ones(len(train_target[identifier_abbr])), np.zeros(len(test_target[identifier_abbr]))]
330  graphics = b2latex.Graphics()
# NOTE(review): listing line 331 is missing from this view - presumably the creation of
# the overtraining plot object 'p'. Confirm.
332  p.add(probability, identifier_abbr,
333  train_mask == 1, train_mask == 0,
334  target == 1, target == 0, )
335  p.finish()
336  p.axis.set_title(f"Overtraining check for {identifier}")
337  p.save(f'overtraining_plot_{hash(identifier)}.pdf')
338  graphics.add(f'overtraining_plot_{hash(identifier)}.pdf', width=1.0)
339  o += graphics.finish()
340 
341  o += b2latex.Section("Spectators")
# NOTE(review): the two string literals below are joined without a separating space
# ("on theclassifier outputs").
342  o += b2latex.String("This section contains the distribution and dependence on the"
343  "classifier outputs of all spectator variables.")
344 
345  table = b2latex.LongTable(r"ll", "Abbreviations of spectators", "{name} & {abbr}", r"Spectator & Abbreviation")
346  for s in spectators:
347  table.add(name=format.string(s), abbr=format.string(spectator_abbreviations[s]))
348  o += table.finish()
349 
# Per spectator: its signal/background distribution, followed by one plot per classifier
# relating the spectator to the classifier output.
350  for spectator in spectators:
351  spectator_abbr = spectator_abbreviations[spectator]
352  o += b2latex.SubSection(format.string(spectator))
353  graphics = b2latex.Graphics()
# NOTE(review): listing line 354 is missing from this view - presumably the creation of
# the distribution plot object 'p'. Confirm.
355  p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
356  p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 0, label="Background")
357  p.finish()
358  p.save(f'spectator_{hash(spectator)}.pdf')
359  graphics.add(f'spectator_{hash(spectator)}.pdf', width=1.0)
360  o += graphics.finish()
361 
362  for identifier in identifiers:
363  o += b2latex.SubSubSection(format.string(spectator) + " with classifier " + format.string(identifier))
364  identifier_abbr = identifier_abbreviations[identifier]
365  data = {identifier_abbr: test_probability[identifier_abbr], spectator_abbr: spectators_data[spectator_abbr]}
366  graphics = b2latex.Graphics()
# NOTE(review): listing line 367 is missing from this view - presumably the creation of
# the correlation plot object 'p'. Confirm.
# list(range(10, 100, 10)) evaluates to [10, 20, ..., 90].
368  p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
369  test_target[identifier_abbr] == 1,
370  test_target[identifier_abbr] == 0)
371  p.finish()
372  p.save(f'correlation_plot_{hash(spectator)}_{hash(identifier)}.pdf')
373  graphics.add(f'correlation_plot_{hash(spectator)}_{hash(identifier)}.pdf', width=1.0)
374  o += graphics.finish()
375 
# Write the TeX file into the working directory, compiling it only if requested.
376  if args.compile:
377  B2INFO(f"Creating a PDF file at {args.outputfile}. Please remove the '-c' switch if this fails.")
378  o.save('latex.tex', compile=True)
379  else:
# NOTE(review): the two f-strings below are joined without a space after the full stop.
380  B2INFO(f"Creating a .zip archive containing plots and a TeX file at {args.outputfile}."
381  f"Please unpack the archive and compile the latex.tex file with pdflatex.")
382  o.save('latex.tex', compile=False)
383 
# Go back to the start directory so that a relative output file name refers to it.
# NOTE(review): this is not reached if anything above raises; the process then stays in
# the working directory.
384  os.chdir(old_cwd)
385  if args.working_directory == '':
386  working_directory = tempdir
387  else:
388  working_directory = args.working_directory
389 
# NOTE(review): 'latex.pdf' is presumably what o.save(..., compile=True) produces in the
# working directory - confirm in b2latex.
390  if args.compile:
391  shutil.copy(os.path.join(working_directory, 'latex.pdf'), args.outputfile)
392  else:
# shutil.make_archive appends '.zip' to the base name itself, therefore the last
# extension of the output file name is cut off first; the archive contains the whole
# working directory.
393  base_name = os.path.join(old_cwd, args.outputfile.rsplit('.', 1)[0])
394  shutil.make_archive(base_name, 'zip', working_directory)
def chain2dict(chain, tree_columns, dict_columns=None)