Belle II Software development
basf2_mva_evaluate.py
1#!/usr/bin/env python3
2
3
10
11import basf2_mva_util
12
13from basf2_mva_evaluation import plotting
14from basf2 import conditions
15import argparse
16import tempfile
17
18import numpy as np
19from B2Tools import b2latex, format
20from basf2 import B2INFO
21
22import os
23import shutil
24import collections
25from typing import List, Any
26
27
def get_argument_parser() -> argparse.ArgumentParser:
    """
    Create the command line parser of the MVA evaluation script.

    The parser is returned without being run; call ``parse_args()`` on it to obtain the options.
    Options declared with ``action='append', nargs='+'`` (identifiers, data files, local
    databases, global tags) are collected as a list of lists, one inner list per occurrence
    of the flag.

    @return argparse.ArgumentParser configured with all options of this tool
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=False, action='append', nargs='+',
                        help='Data file containing ROOT TTree used during training')
    parser.add_argument('-data', '--datafiles', dest='datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.zip',
                        help='Name of the created .zip archive file if not compiling or a pdf file if compilation is successful.')
    parser.add_argument('-w', '--working_directory', dest='working_directory', type=str, default='',
                        help="""Working directory where the created images and root files are stored,
                        default is to create a temporary directory.""")
    parser.add_argument('-l', '--localdb', dest='localdb', type=str, action='append', nargs='+', required=False,
                        help="""path or list of paths to local database(s) containing the mvas of interest.
                        The testing payloads are prepended and take precedence over payloads in global tags.""")
    parser.add_argument('-g', '--globaltag', dest='globaltag', type=str, action='append', nargs='+', required=False,
                        help='globaltag or list of globaltags containing the mvas of interest. The globaltags are prepended.')
    parser.add_argument('-n', '--fillnan', dest='fillnan', action='store_true',
                        help='Fill nan and inf values with actual numbers')
    parser.add_argument('-c', '--compile', dest='compile', action='store_true',
                        help='Compile latex to pdf directly')
    parser.add_argument('-a', '--abbreviation_length', dest='abbreviation_length',
                        action='store', type=int, default=10,
                        help='Number of characters to which variable names are abbreviated.')
    return parser
56
57
def unique(input_list: List[Any]) -> List[Any]:
    """
    Return the distinct elements of input_list in order of first appearance.

    Elements are compared with ``==`` (list membership), so they need not be hashable.
    Any iterable is accepted, e.g. the generator expressions used by this script.
    @param input_list iterable containing the elements
    @return new list without repeated elements
    """
    distinct: List[Any] = []
    for element in input_list:
        already_seen = element in distinct
        if not already_seen:
            distinct.append(element)
    return distinct
68
69
def flatten(input_list: List[List[Any]]) -> List[Any]:
    """
    Concatenate the inner lists of input_list into one list.

    This script uses it to merge the list of lists that argparse builds for options
    declared with ``action='append', nargs='+'``.
    @param input_list list of lists to be flattened
    @return single list with all inner elements in their original order
    """
    flat: List[Any] = []
    for sublist in input_list:
        flat.extend(sublist)
    return flat
76
77
def smart_abbreviation(name):
    """
    Shorten a name by applying a fixed, ordered list of substring substitutions.

    "daughter" becomes "d" and "Angle" becomes "Ang". The substrings "useCMSFrame",
    "useLabFrame", "useRestFrame", "formula" and "conditionalVariableSelector" are removed,
    as are parentheses, commas and spaces. The result is not truncated.
    @param name full name to shorten
    @return shortened name
    """
    # Order matters: each substitution sees the output of the previous one.
    substitutions = (
        ("daughter", "d"),
        ("Angle", "Ang"),
        ("useCMSFrame", ""),
        ("useLabFrame", ""),
        ("useRestFrame", ""),
        ("formula", ""),
        ("(", ""),
        (")", ""),
        ("conditionalVariableSelector", ""),
        (",", ""),
        (" ", ""),
    )
    abbreviated = name
    for old, new in substitutions:
        abbreviated = abbreviated.replace(old, new)
    return abbreviated
92
93
def create_abbreviations(names, length=5):
    """
    Map every name to a short abbreviation that is unique among the given names.

    Each name is shortened with smart_abbreviation and truncated to ``length`` characters.
    A shortened form that belongs to exactly one name is used as it is. Names sharing a
    shortened form get a running number appended (1, 2, ... in order of appearance).
    Numbers whose result would clash with another abbreviation, such as the plain shortened
    form of a different name, are skipped. The returned abbreviations are therefore always
    pairwise distinct. The script relies on this because it uses them as dictionary keys and
    as column names. A name listed more than once is abbreviated only once.

    @param names iterable of names (identifiers, variables or spectators)
    @param length maximal number of characters kept from the shortened name
    @return collections.OrderedDict mapping each name to its abbreviation, in input order
    """
    # Shortened form of every distinct name, in order of first appearance.
    shortened = collections.OrderedDict()
    for name in names:
        if name not in shortened:
            shortened[name] = smart_abbreviation(name)[:length]

    multiplicity = collections.Counter(shortened.values())
    # Forms owned by a single name are reserved up front, so that no numbered abbreviation
    # of a colliding group can take them.
    taken = {short for short, n in multiplicity.items() if n == 1}
    last_index = {}

    abbreviations = collections.OrderedDict()
    for name, short in shortened.items():
        if multiplicity[short] == 1:
            abbreviations[name] = short
            continue
        index = last_index.get(short, 0) + 1
        while short + str(index) in taken:
            index += 1
        last_index[short] = index
        abbreviations[name] = short + str(index)
        taken.add(abbreviations[name])
    return abbreviations
113
114
if __name__ == '__main__':
    # Script entry point. Every given MVA expert is applied to the independent test data (and
    # optionally to the training data). A LaTeX report with variable, classifier and spectator
    # plots is then written. With --compile the report is compiled to a PDF; otherwise the plots
    # and the TeX file are packed into a .zip archive.
    #
    # NOTE(review): this listing was extracted from rendered documentation and several source
    # lines are missing; each gap is marked below. The gaps are exactly where a plotter object
    # `p` is created or where `plotting` classes are listed, so this block does not run as
    # shown. Restore those lines from the upstream basf2_mva_evaluate.py.

    # ROOT is imported only when run as a script. It is configured for non-interactive use
    # (option names: ignore command line options, no GUI thread, batch mode).
    import ROOT  # noqa
    ROOT.PyConfig.IgnoreCommandLineOptions = True
    ROOT.PyConfig.StartGuiThread = False
    ROOT.gROOT.SetBatch(True)

    # Remember the start directory. The report is built in a (temporary) working directory,
    # and the final output file is written relative to this one.
    old_cwd = os.getcwd()
    parser = get_argument_parser()
    args = parser.parse_args()

    # Options using action='append' with nargs='+' arrive as lists of lists.
    identifiers = flatten(args.identifiers)
    identifier_abbreviations = create_abbreviations(identifiers, args.abbreviation_length)

    datafiles = flatten(args.datafiles)
    # Register the local databases / global tags before the methods are loaded, because they
    # may contain the MVAs of interest (see the option help texts).
    if args.localdb is not None:
        for localdb in flatten(args.localdb):
            conditions.prepend_testing_payloads(localdb)

    if args.globaltag is not None:
        for tag in flatten(args.globaltag):
            conditions.prepend_globaltag(tag)

    print("Load methods")
    methods = [basf2_mva_util.Method(identifier) for identifier in identifiers]

    print("Apply experts on independent data")
    # Expert outputs (p) and targets (t) per method, keyed by the identifier abbreviation.
    test_probability = {}
    test_target = {}
    for method in methods:
        p, t = method.apply_expert(datafiles, args.treename)
        test_probability[identifier_abbreviations[method.identifier]] = p
        test_target[identifier_abbreviations[method.identifier]] = t

    print("Apply experts on training data")
    # These stay empty without --train_datafiles. Below, `if train_probability:` decides
    # whether the training-data plots are produced.
    train_probability = {}
    train_target = {}
    if args.train_datafiles is not None:
        # Same result as flatten(args.train_datafiles).
        train_datafiles = sum(args.train_datafiles, [])
        for method in methods:
            p, t = method.apply_expert(train_datafiles, args.treename)
            train_probability[identifier_abbreviations[method.identifier]] = p
            train_target[identifier_abbreviations[method.identifier]] = t

    # Union of the variables / spectators of all methods (first-appearance order) and their
    # abbreviations. The root_* lists are the tree column names passed to chain2dict.
    variables = unique(v for method in methods for v in method.variables)
    variable_abbreviations = create_abbreviations(variables, args.abbreviation_length)
    root_variables = unique(v for method in methods for v in method.root_variables)

    spectators = unique(v for method in methods for v in method.spectators)
    spectator_abbreviations = create_abbreviations(spectators, args.abbreviation_length)
    root_spectators = unique(v for method in methods for v in method.root_spectators)

    print("Load variables array")
    # Two chains over the same files. rootchain feeds the separate variable/spectator dicts;
    # rootchain_spec feeds the combined variables+spectators dict.
    rootchain = ROOT.TChain(args.treename)
    rootchain_spec = ROOT.TChain(args.treename)
    for datafile in datafiles:
        rootchain.Add(datafile)
        rootchain_spec.Add(datafile)

    # The resulting dicts are indexed by the abbreviations (third argument), not by the tree
    # column names.
    variables_data = basf2_mva_util.chain2dict(rootchain, root_variables, list(variable_abbreviations.values()))
    spectators_data = basf2_mva_util.chain2dict(rootchain, root_spectators, list(spectator_abbreviations.values()))
    varSpec_data = basf2_mva_util.chain2dict(
        rootchain_spec,
        root_variables +
        root_spectators,
        list(
            variable_abbreviations.values()) +
        list(
            spectator_abbreviations.values()))

    # Same loading for the training files, if any were given.
    if train_probability:
        rootchain_train = ROOT.TChain(args.treename)
        rootchain_train_spec = ROOT.TChain(args.treename)
        for train_datafile in train_datafiles:
            rootchain_train.Add(train_datafile)
            rootchain_train_spec.Add(train_datafile)
        variables_train_data = basf2_mva_util.chain2dict(rootchain_train, root_variables, list(variable_abbreviations.values()))
        spectators_train_data = basf2_mva_util.chain2dict(rootchain_train, root_spectators, list(spectator_abbreviations.values()))
        varSpec_train_data = basf2_mva_util.chain2dict(
            rootchain_train_spec,
            root_variables +
            root_spectators,
            list(
                variable_abbreviations.values()) +
            list(
                spectator_abbreviations.values()))

    # With --fillnan, NaN/inf entries are replaced in place. The return value of nan_to_num is
    # discarded; this assumes chain2dict returns writable float numpy arrays (TODO confirm).
    if args.fillnan:
        for column in variable_abbreviations.values():
            np.nan_to_num(variables_data[column], copy=False)
            np.nan_to_num(varSpec_data[column], copy=False)
            if train_probability:
                np.nan_to_num(variables_train_data[column], copy=False)
                np.nan_to_num(varSpec_train_data[column], copy=False)

        for column in spectator_abbreviations.values():
            np.nan_to_num(spectators_data[column], copy=False)
            np.nan_to_num(varSpec_data[column], copy=False)
            if train_probability:
                np.nan_to_num(spectators_train_data[column], copy=False)
                np.nan_to_num(varSpec_train_data[column], copy=False)

    print("Create latex file")
    # Change working directory after experts run, because they might want to access
    # a localdb in the current working directory.
    # The temporary directory is always created, but it is only used when no
    # --working_directory is given.
    with tempfile.TemporaryDirectory() as tempdir:
        if args.working_directory == '':
            os.chdir(tempdir)
        else:
            os.chdir(args.working_directory)

        # Plain-text legend of all abbreviations used in the plots.
        with open('abbreviations.txt', 'w') as f:
            f.write('Identifier Abbreviation : Identifier \n')
            for name, abbrev in identifier_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')
            f.write('\n\n\nVariable Abbreviation : Variable \n')
            for name, abbrev in variable_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')
            f.write('\n\n\nSpectator Abbreviation : Spectator \n')
            for name, abbrev in spectator_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')

        o = b2latex.LatexFile()
        o += b2latex.TitlePage(title='Automatic MVA Evaluation',
                               authors=[r'Thomas Keck\\ Moritz Gelb\\ Nils Braun'],
                               abstract='Evaluation plots',
                               add_table_of_contents=True).finish()

        o += b2latex.Section("Classifiers")
        o += b2latex.String(r"""
            This section contains the GeneralOptions and SpecificOptions of all classifiers represented by an XML tree.
            The same information can be retrieved using the basf2\_mva\_info tool.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of identifiers", "{name} & {abbr}", r"Identifier & Abbreviation")
        for identifier in identifiers:
            table.add(name=format.string(identifier), abbr=format.string(identifier_abbreviations[identifier]))
        o += table.finish()

        # One subsection per classifier, with its description rendered as an XML listing.
        for method in methods:
            o += b2latex.SubSection(format.string(method.identifier))
            o += b2latex.Listing(language='XML').add(method.description).finish()

        o += b2latex.Section("Variables")
        o += b2latex.String("""
            This section contains an overview of the importance and correlation of the variables used by the classifiers.
            And distribution plots of the variables on the independent dataset. The distributions are normed for signal and
            background separately, and only the region +- 3 sigma around the mean is shown.

            The importance scores shown are based on the variable importance as estimated by each MVA method internally.
            This means the variable with the lowest importance will have score 0, and the variable
            with the highest importance will have score 100. If the method does not provide such a ranking, all
            importances will be 0.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of variables", "{name} & {abbr}", r"Variable & Abbreviation")
        for v in variables:
            table.add(name=format.string(v), abbr=format.string(variable_abbreviations[v]))
        o += table.finish()

        o += b2latex.SubSection("Importance")
        graphics = b2latex.Graphics()
        # NOTE(review): source line 277 is missing from this listing. It presumably creates the
        # plotter `p` (a `plotting` class) used below; restore it from upstream.
        # Importance of every variable per classifier. Variables a method does not rank get 0.0.
        p.add({identifier_abbreviations[i.identifier]: np.array([i.importances.get(v, 0.0) for v in variables]) for i in methods},
              identifier_abbreviations.values(), variable_abbreviations.values())
        p.finish()
        p.save('importance.pdf')
        graphics.add('importance.pdf', width=1.0)
        o += graphics.finish()

        o += b2latex.SubSection("Correlation")
        # The targets of the first classifier serve as signal/background truth for all variable
        # and spectator plots (presumably all methods share one target; TODO confirm).
        first_identifier_abbr = list(identifier_abbreviations.values())[0]
        graphics = b2latex.Graphics()
        # NOTE(review): source line 288 is missing from this listing. It presumably creates the
        # plotter `p` used below; restore it from upstream.
        p.add(variables_data, variable_abbreviations.values(),
              test_target[first_identifier_abbr] == 1,
              test_target[first_identifier_abbr] == 0)
        p.finish()
        p.save('correlation_plot.pdf')
        graphics.add('correlation_plot.pdf', width=1.0)
        o += graphics.finish()

        if train_probability:
            o += b2latex.SubSection("Correlation on Training Data")
            graphics = b2latex.Graphics()
            # NOTE(review): source line 300 is missing from this listing. It presumably creates the
            # plotter `p` used below; restore it from upstream.
            p.add(variables_train_data, variable_abbreviations.values(),
                  train_target[first_identifier_abbr] == 1,
                  train_target[first_identifier_abbr] == 0)
            p.finish()
            p.save('correlation_plot_train.pdf')
            graphics.add('correlation_plot_train.pdf', width=1.0)
            o += graphics.finish()

        # Per-variable signal (target == 1) / background (target == 0) distributions on the test
        # data, plus the training data when available.
        for v in variables:
            variable_abbr = variable_abbreviations[v]
            o += b2latex.SubSection(format.string(v))
            graphics = b2latex.Graphics()
            p = plotting.VerboseDistribution(normed=True, range_in_std=3, x_axis_label=v)
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 1, label="Sig")
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 0, label="Bkg")
            if train_probability:
                p.add(variables_train_data, variable_abbr, train_target[first_identifier_abbr] == 1, label="Sig_train")
                p.add(variables_train_data, variable_abbr, train_target[first_identifier_abbr] == 0, label="Bkg_train")
            p.finish()
            # hash() of a str is salted per interpreter run (PYTHONHASHSEED), so these file names
            # change between runs. They only need to be distinct and consistent within one run.
            p.save(f'variable_{variable_abbr}_{hash(v)}.pdf')
            graphics.add(f'variable_{variable_abbr}_{hash(v)}.pdf', width=1.0)
            o += graphics.finish()

        o += b2latex.Section("Classifier Plot")
        # NOTE(review): these adjacent literals are joined without separating spaces. The text
        # renders as "..., ...of the classifiers", "data.The legend" and "ROC curvein parenthesis".
        o += b2latex.String("This section contains the receiver operating characteristics (ROC), purity projection, ..."
                            "of the classifiers on training and independent data."
                            "The legend of each plot contains the shortened identifier and the area under the ROC curve"
                            "in parenthesis.")
        # Each entry is a plotter class. Below it is instantiated without arguments, filled via
        # add(data, column, signal_mask, background_mask, label=...), then finished and saved.
        plot_classes = [
            # NOTE(review): the entries of this list (source lines 330-334) are missing from this
            # listing; restore them from upstream.
        ]

        for plot_class in plot_classes:
            # Start section for each plot
            o += b2latex.Section(f"{plot_class.__name__} Plot")

            # All classifiers overlaid on the independent data.
            graphics = b2latex.Graphics()
            p = plot_class()
            for i, identifier in enumerate(identifiers):
                identifier_abbr = identifier_abbreviations[identifier]
                p.add(
                    test_probability,
                    identifier_abbr,
                    test_target[identifier_abbr] == 1,
                    test_target[identifier_abbr] == 0,
                    label=identifier_abbr)
            p.finish()
            p.axis.set_title(f"{plot_class.__name__} Plot on independent data")
            p.save(f'{plot_class.__name__.lower()}_plot_test.pdf')
            graphics.add(f'{plot_class.__name__.lower()}_plot_test.pdf', width=1.0)
            o += graphics.finish()

            # One train-vs-test comparison per classifier.
            if train_probability:
                for i, identifier in enumerate(identifiers):
                    graphics = b2latex.Graphics()
                    p = plot_class()
                    identifier_abbr = identifier_abbreviations[identifier]
                    p.add(train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
                          train_target[identifier_abbr] == 0, label=f'Train {identifier_abbr}')
                    p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                          test_target[identifier_abbr] == 0, label=f'Test {identifier_abbr}')
                    p.finish()
                    p.axis.set_title(f"{plot_class.__name__} Plot for \n" + identifier)
                    p.save(f'{plot_class.__name__.lower()}_plot_{hash(identifier)}.pdf')
                    graphics.add(f'{plot_class.__name__.lower()}_plot_{hash(identifier)}.pdf', width=1.0)
                    o += graphics.finish()

        o += b2latex.Section("Classification Results")
        for identifier in identifiers:
            identifier_abbr = identifier_abbreviations[identifier]
            o += b2latex.SubSection(format.string(identifier_abbr))
            graphics = b2latex.Graphics()
            # Sub-plots 0/1 show the test data (normed / not normed). Sub-plots 2/3 show the
            # training data and are only filled when training data is available.
            if train_probability:
                # NOTE(review): source line 378 is missing from this listing. It presumably creates
                # the multi-panel plotter `p` with room for four sub-plots; restore it from upstream.
            else:
                # NOTE(review): source line 380 is missing from this listing. It presumably creates
                # the multi-panel plotter `p` with room for two sub-plots; restore it from upstream.
            p.add(0, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=True)
            p.sub_plots[0].axis.set_title(f"Classification result in test data for \n{identifier}")

            p.add(1, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=False)
            p.sub_plots[1].axis.set_title(f"Classification result in test data for \n{identifier}")

            if train_probability:
                p.add(2, train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
                      train_target[identifier_abbr] == 0, normed=True)
                p.sub_plots[2].axis.set_title(f"Classification result in training data for \n{identifier}")

                p.add(3, train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
                      train_target[identifier_abbr] == 0, normed=False)
                p.sub_plots[3].axis.set_title(f"Classification result in training data for \n{identifier}")

            p.figure.subplots_adjust(wspace=0.3, hspace=0.3)
            p.finish()
            p.save(f'classification_result_{hash(identifier)}.pdf')
            graphics.add(f'classification_result_{hash(identifier)}.pdf', width=1)
            o += graphics.finish()

        o += b2latex.Section("Diagonal Plot")
        graphics = b2latex.Graphics()
        # NOTE(review): source line 406 is missing from this listing. It presumably creates the
        # plotter `p` used below; restore it from upstream.
        for identifier in identifiers:
            identifier_abbr = identifier_abbreviations[identifier]
            p.add(
                test_probability,
                identifier_abbr,
                test_target[identifier_abbr] == 1,
                test_target[identifier_abbr] == 0,
                label=identifier_abbr)
        p.finish()
        p.axis.set_title("Diagonal plot on independent data")
        p.save('diagonal_plot_test.pdf')
        graphics.add('diagonal_plot_test.pdf', width=1.0)
        o += graphics.finish()

        if train_probability:
            for identifier in identifiers:
                identifier_abbr = identifier_abbreviations[identifier]
                o += b2latex.SubSection(format.string(identifier_abbr))
                graphics = b2latex.Graphics()
                # NOTE(review): source line 426 is missing from this listing. It presumably creates
                # the plotter `p` used below; restore it from upstream.

                p.add(
                    train_probability,
                    identifier_abbr,
                    train_target[identifier_abbr] == 1,
                    train_target[identifier_abbr] == 0,
                    label='Train')
                p.add(
                    test_probability,
                    identifier_abbr,
                    test_target[identifier_abbr] == 1,
                    test_target[identifier_abbr] == 0,
                    label='Test')

                p.finish()
                p.axis.set_title("Diagonal plot for \n" + identifier)
                p.save(f'diagonal_plot_{hash(identifier)}.pdf')
                graphics.add(f'diagonal_plot_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

        if train_probability:
            o += b2latex.Section("Overtraining Plot")
            for identifier in identifiers:
                identifier_abbr = identifier_abbreviations[identifier]
                # Concatenate training and test outputs/targets. train_mask is 1 for entries
                # from the training data and 0 for entries from the test data.
                probability = {identifier_abbr: np.r_[train_probability[identifier_abbr], test_probability[identifier_abbr]]}
                target = np.r_[train_target[identifier_abbr], test_target[identifier_abbr]]
                train_mask = np.r_[np.ones(len(train_target[identifier_abbr])), np.zeros(len(test_target[identifier_abbr]))]
                graphics = b2latex.Graphics()
                # NOTE(review): source line 455 is missing from this listing. It presumably creates
                # the plotter `p` used below; restore it from upstream.
                p.add(probability, identifier_abbr,
                      train_mask == 1, train_mask == 0,
                      target == 1, target == 0, )
                p.finish()
                p.axis.set_title(f"Overtraining check for \n{identifier}")
                p.save(f'overtraining_plot_{hash(identifier)}.pdf')
                graphics.add(f'overtraining_plot_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

        o += b2latex.Section("Spectators")
        # NOTE(review): the two literals are joined without a space and render as "on theclassifier".
        o += b2latex.String("This section contains the distribution and dependence on the"
                            "classifier outputs of all spectator variables.")

        table = b2latex.LongTable(r"ll", "Abbreviations of spectators", "{name} & {abbr}", r"Spectator & Abbreviation")
        for s in spectators:
            table.add(name=format.string(s), abbr=format.string(spectator_abbreviations[s]))
        o += table.finish()

        for spectator in spectators:
            spectator_abbr = spectator_abbreviations[spectator]
            o += b2latex.SubSection(format.string(spectator))
            graphics = b2latex.Graphics()
            # NOTE(review): source line 478 is missing from this listing. It presumably creates the
            # distribution plotter `p` used below; restore it from upstream.
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 1, label="Sig")
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 0, label="Bkg")
            if train_probability:
                p.add(spectators_train_data, spectator_abbr, train_target[first_identifier_abbr] == 1, label="Sig_train")
                p.add(spectators_train_data, spectator_abbr, train_target[first_identifier_abbr] == 0, label="Bkg_train")
            p.finish()
            p.save(f'spectator_{spectator_abbr}_{hash(spectator)}.pdf')
            graphics.add(f'spectator_{spectator_abbr}_{hash(spectator)}.pdf', width=1.0)
            o += graphics.finish()

            # Dependence of this spectator on every classifier output. The values 10, 20, ..., 90
            # are passed as the fourth argument (presumably percentile cuts on the classifier
            # output; confirm against plotting).
            for identifier in identifiers:
                o += b2latex.SubSubSection(format.string(spectator) + " with classifier " + format.string(identifier))
                identifier_abbr = identifier_abbreviations[identifier]
                data = {identifier_abbr: test_probability[identifier_abbr], spectator_abbr: spectators_data[spectator_abbr]}
                graphics = b2latex.Graphics()
                # NOTE(review): source line 494 is missing from this listing. It presumably creates
                # the plotter `p` used below; restore it from upstream.
                p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
                      test_target[identifier_abbr] == 1,
                      test_target[identifier_abbr] == 0)
                p.figure.subplots_adjust(hspace=0.5)
                p.finish()
                p.save(f'correlation_plot_{spectator_abbr}_{hash(spectator)}_{hash(identifier)}.pdf')
                graphics.add(f'correlation_plot_{spectator_abbr}_{hash(spectator)}_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

                if train_probability:
                    o += b2latex.SubSubSection(format.string(spectator) + " with classifier " +
                                               format.string(identifier) + " on training data")
                    data = {identifier_abbr: train_probability[identifier_abbr],
                            spectator_abbr: spectators_train_data[spectator_abbr]}
                    graphics = b2latex.Graphics()
                    # NOTE(review): source line 510 is missing from this listing. It presumably
                    # creates the plotter `p` used below; restore it from upstream.
                    p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
                          train_target[identifier_abbr] == 1,
                          train_target[identifier_abbr] == 0)
                    p.figure.subplots_adjust(hspace=0.5)
                    p.finish()
                    p.save(f'correlation_plot_{spectator_abbr}_{hash(spectator)}_{hash(identifier)}_train.pdf')
                    graphics.add(f'correlation_plot_{spectator_abbr}_{hash(spectator)}_{hash(identifier)}_train.pdf', width=1.0)
                    o += graphics.finish()

        if len(spectators) > 0:
            o += b2latex.SubSection("Correlation of Spectators")
            # Recomputes the same value as in the "Correlation" subsection above.
            first_identifier_abbr = list(identifier_abbreviations.values())[0]
            graphics = b2latex.Graphics()
            # NOTE(review): source line 524 is missing from this listing. It presumably creates the
            # plotter `p` used below; restore it from upstream.
            p.add(
                varSpec_data,
                list(variable_abbreviations.values()) + list(spectator_abbreviations.values()),
                test_target[first_identifier_abbr] == 1,
                test_target[first_identifier_abbr] == 0
            )
            p.finish()
            p.save('correlation_spec_plot.pdf')
            graphics.add('correlation_spec_plot.pdf', width=1.0)
            o += graphics.finish()

            if train_probability:
                o += b2latex.SubSection("Correlation of Spectators on Training Data")
                graphics = b2latex.Graphics()
                # NOTE(review): source line 539 is missing from this listing. It presumably creates
                # the plotter `p` used below; restore it from upstream.
                p.add(
                    varSpec_train_data,
                    list(variable_abbreviations.values()) + list(spectator_abbreviations.values()),
                    train_target[first_identifier_abbr] == 1,
                    train_target[first_identifier_abbr] == 0
                )
                p.finish()
                p.save('correlation_spec_plot_train.pdf')
                graphics.add('correlation_spec_plot_train.pdf', width=1.0)
                o += graphics.finish()

        if args.compile:
            B2INFO(f"Creating a PDF file at {args.outputfile}. Please remove the '-c' switch if this fails.")
            o.save('latex.tex', compile=True)
        else:
            # NOTE(review): the two f-strings are joined without a space and render as ".zip.Please unpack".
            B2INFO(f"Creating a .zip archive containing plots and a TeX file at {args.outputfile}."
                   f"Please unpack the archive and compile the latex.tex file with pdflatex.")
            o.save('latex.tex', compile=False)

        # Return to the start directory while the temporary directory still exists, then
        # copy/pack the results relative to it.
        os.chdir(old_cwd)
        if args.working_directory == '':
            working_directory = tempdir
        else:
            working_directory = args.working_directory

        if args.compile:
            shutil.copy(os.path.join(working_directory, 'latex.pdf'), args.outputfile)
        else:
            # make_archive appends '.zip' itself, so the extension is stripped from the name.
            # NOTE(review): rsplit('.') also splits on dots in directory components, e.g.
            # '../report' -> '.'; os.path.splitext would be safer.
            base_name = os.path.join(old_cwd, args.outputfile.rsplit('.', 1)[0])
            shutil.make_archive(base_name, 'zip', working_directory)
chain2dict(chain, tree_columns, dict_columns=None, max_entries=None)