Belle II Software development
basf2_mva_evaluate.py
1#!/usr/bin/env python3
2
3
10
11import basf2_mva_util
12
13from basf2_mva_evaluation import plotting
14from basf2 import conditions
15import argparse
16import tempfile
17
18import numpy as np
19from B2Tools import b2latex, format
20from basf2 import B2INFO
21
22import os
23import shutil
24import collections
25from typing import List, Any
26
27
def get_argument_parser() -> argparse.ArgumentParser:
    """
    Builds the command line parser of the MVA evaluation script.

    The parser is only constructed here, the command line is not parsed; call
    ``parse_args()`` on the returned object to obtain the arguments.

    All list-like options use ``action='append'`` together with ``nargs='+'``, so after
    parsing they are lists of lists (one inner list per occurrence of the option).

    @return the configured argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=False, action='append', nargs='+',
                        help='Data file containing ROOT TTree used during training')
    parser.add_argument('-data', '--datafiles', dest='datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.zip',
                        help='Name of the created .zip archive file if not compiling or a pdf file if compilation is successful.')
    parser.add_argument('-w', '--working_directory', dest='working_directory', type=str, default='',
                        help='Working directory where the created images and root files are stored, '
                             'default is to create a temporary directory.')
    parser.add_argument('-l', '--localdb', dest='localdb', type=str, action='append', nargs='+', required=False,
                        help='path or list of paths to local database(s) containing the mvas of interest. '
                             'The testing payloads are prepended and take precedence over payloads in global tags.')
    parser.add_argument('-g', '--globaltag', dest='globaltag', type=str, action='append', nargs='+', required=False,
                        help='globaltag or list of globaltags containing the mvas of interest. The globaltags are prepended.')
    parser.add_argument('-n', '--fillnan', dest='fillnan', action='store_true',
                        help='Fill nan and inf values with actual numbers')
    parser.add_argument('-c', '--compile', dest='compile', action='store_true',
                        help='Compile latex to pdf directly')
    parser.add_argument('-a', '--abbreviation_length', dest='abbreviation_length',
                        action='store', type=int, default=5,
                        help='Number of characters to which variable names are abbreviated.')
    return parser
56
57
def unique(input_list: List[Any]) -> List[Any]:
    """
    Returns a list containing only unique elements, keeps the original order of the list.
    Of several equal elements only the first occurrence is kept.
    Any iterable is accepted (this script passes generator expressions), the result is always a new list.
    @param input_list list (or other iterable) containing the elements
    """
    # Materialise first: a generator can only be consumed once, and the
    # hashing attempt below may fail after having consumed part of it.
    items = list(input_list)
    try:
        # dict keys keep insertion order, so this removes duplicates in linear
        # time while preserving the position of each first occurrence.
        return list(dict.fromkeys(items))
    except TypeError:
        # At least one element is unhashable (e.g. a list): fall back to the
        # quadratic equality scan, which works for any comparable element.
        output = []
        for x in items:
            if x not in output:
                output.append(x)
        return output
68
69
def flatten(input_list: List[List[Any]]) -> List[Any]:
    """
    Removes exactly one level of nesting: the elements of all inner lists are
    concatenated, in order, into a single new list.
    @param input_list list whose entries are themselves lists
    """
    flattened = []
    for inner in input_list:
        flattened.extend(inner)
    return flattened
76
77
def create_abbreviations(names, length=5):
    """
    Maps every name to an abbreviation consisting of its first `length` characters.
    If several names share the same leading characters, a running number (starting at 1,
    counted in the order of `names`) is appended to the abbreviation of each of them.
    @param names names to abbreviate; iterated twice
    @param length number of leading characters kept in the abbreviation
    @return collections.OrderedDict mapping name to abbreviation, in the order of `names`
    """
    # First pass: how many names fall onto each prefix.
    prefix_total = collections.Counter(name[:length] for name in names)

    # Second pass: assign the abbreviations, numbering only the shared prefixes.
    prefix_seen = collections.Counter()
    result = collections.OrderedDict()
    for name in names:
        prefix = name[:length]
        if prefix_total[prefix] > 1:
            prefix_seen[prefix] += 1
            result[name] = prefix + str(prefix_seen[prefix])
        else:
            result[name] = prefix
    return result
97
98
if __name__ == '__main__':
    # Script entry point: applies the given MVA experts to the test (and optionally
    # training) data, produces evaluation plots, and assembles them into a LaTeX
    # report that is delivered either as a compiled PDF or as a .zip archive.

    import ROOT  # noqa
    ROOT.PyConfig.IgnoreCommandLineOptions = True
    ROOT.PyConfig.StartGuiThread = False
    ROOT.gROOT.SetBatch(True)

    # Remember where we started: the report is written in a different directory,
    # but the output file path is interpreted relative to the original one.
    old_cwd = os.getcwd()
    parser = get_argument_parser()
    args = parser.parse_args()

    # Every list option arrives as a list of lists (action='append' + nargs='+').
    identifiers = flatten(args.identifiers)
    identifier_abbreviations = create_abbreviations(identifiers, args.abbreviation_length)

    datafiles = flatten(args.datafiles)
    if args.localdb is not None:
        for localdb in flatten(args.localdb):
            conditions.prepend_testing_payloads(localdb)

    if args.globaltag is not None:
        for tag in flatten(args.globaltag):
            conditions.prepend_globaltag(tag)

    print("Load methods")
    methods = [basf2_mva_util.Method(identifier) for identifier in identifiers]

    print("Apply experts on independent data")
    test_probability = {}
    test_target = {}
    for method in methods:
        p, t = method.apply_expert(datafiles, args.treename)
        test_probability[identifier_abbreviations[method.identifier]] = p
        test_target[identifier_abbreviations[method.identifier]] = t

    print("Apply experts on training data")
    train_probability = {}
    train_target = {}
    if args.train_datafiles is not None:
        # Use the same helper as for every other list-of-lists option.
        train_datafiles = flatten(args.train_datafiles)
        for method in methods:
            p, t = method.apply_expert(train_datafiles, args.treename)
            train_probability[identifier_abbreviations[method.identifier]] = p
            train_target[identifier_abbreviations[method.identifier]] = t

    variables = unique(v for method in methods for v in method.variables)
    variable_abbreviations = create_abbreviations(variables, args.abbreviation_length)
    root_variables = unique(v for method in methods for v in method.root_variables)

    spectators = unique(v for method in methods for v in method.spectators)
    spectator_abbreviations = create_abbreviations(spectators, args.abbreviation_length)
    root_spectators = unique(v for method in methods for v in method.root_spectators)

    print("Load variables array")
    rootchain = ROOT.TChain(args.treename)
    for datafile in datafiles:
        rootchain.Add(datafile)

    variables_data = basf2_mva_util.chain2dict(rootchain, root_variables, list(variable_abbreviations.values()))
    spectators_data = basf2_mva_util.chain2dict(rootchain, root_spectators, list(spectator_abbreviations.values()))

    if args.fillnan:
        # copy=False: replace nan/inf in place inside the loaded columns.
        for column in variable_abbreviations.values():
            np.nan_to_num(variables_data[column], copy=False)

        for column in spectator_abbreviations.values():
            np.nan_to_num(spectators_data[column], copy=False)

    print("Create latex file")
    # Change working directory after experts run, because they might want to access
    # a localdb in the current working directory.
    with tempfile.TemporaryDirectory() as tempdir:
        if args.working_directory == '':
            os.chdir(tempdir)
        else:
            os.chdir(args.working_directory)

        # Plain-text lookup table so the abbreviations used in the plots can be resolved.
        with open('abbreviations.txt', 'w') as f:
            f.write('Identifier Abbreviation : Identifier \n')
            for name, abbrev in identifier_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')
            f.write('\n\n\nVariable Abbreviation : Variable \n')
            for name, abbrev in variable_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')
            f.write('\n\n\nSpectator Abbreviation : Spectator \n')
            for name, abbrev in spectator_abbreviations.items():
                f.write(f'\t{abbrev} : {name}\n')

        # NOTE(review): in the reviewed listing every line constructing a plot object
        # (`p = plotting.<Class>(...)`) was missing, except the VerboseDistribution in
        # the per-variable loop. Without them `p` still refers to the previous plot
        # (or to a probability array). The constructors below are restored from the
        # upstream script -- TODO confirm the class names against
        # basf2_mva_evaluation.plotting.

        o = b2latex.LatexFile()
        o += b2latex.TitlePage(title='Automatic MVA Evaluation',
                               authors=[r'Thomas Keck\\ Moritz Gelb\\ Nils Braun'],
                               abstract='Evaluation plots',
                               add_table_of_contents=True).finish()

        o += b2latex.Section("Classifiers")
        o += b2latex.String(r"""
            This section contains the GeneralOptions and SpecificOptions of all classifiers represented by an XML tree.
            The same information can be retrieved using the basf2\_mva\_info tool.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of identifiers", "{name} & {abbr}", r"Identifier & Abbreviation")
        for identifier in identifiers:
            table.add(name=format.string(identifier), abbr=format.string(identifier_abbreviations[identifier]))
        o += table.finish()

        for method in methods:
            o += b2latex.SubSection(format.string(method.identifier))
            o += b2latex.Listing(language='XML').add(method.description).finish()

        o += b2latex.Section("Variables")
        o += b2latex.String("""
            This section contains an overview of the importance and correlation of the variables used by the classifiers.
            And distribution plots of the variables on the independent dataset. The distributions are normed for signal and
            background separately, and only the region +- 3 sigma around the mean is shown.

            The importance scores shown are based on the variable importance as estimated by each MVA method internally.
            This means the variable with the lowest importance will have score 0, and the variable
            with the highest importance will have score 100. If the method does not provide such a ranking, all
            importances will be 0.
        """)

        table = b2latex.LongTable(r"ll", "Abbreviations of variables", "{name} & {abbr}", r"Variable & Abbreviation")
        for v in variables:
            table.add(name=format.string(v), abbr=format.string(variable_abbreviations[v]))
        o += table.finish()

        o += b2latex.SubSection("Importance")
        graphics = b2latex.Graphics()
        p = plotting.Importance()  # restored, see NOTE(review) above
        # Variables a method does not know get importance 0.0.
        p.add({identifier_abbreviations[i.identifier]: np.array([i.importances.get(v, 0.0) for v in variables]) for i in methods},
              identifier_abbreviations.values(), variable_abbreviations.values())
        p.finish()
        p.save('importance.pdf')
        graphics.add('importance.pdf', width=1.0)
        o += graphics.finish()

        o += b2latex.SubSection("Correlation")
        # Signal/background masks for variable and spectator plots are taken from
        # the targets obtained with the first identifier.
        first_identifier_abbr = list(identifier_abbreviations.values())[0]
        graphics = b2latex.Graphics()
        p = plotting.CorrelationMatrix()  # restored, see NOTE(review) above
        p.add(variables_data, variable_abbreviations.values(),
              test_target[first_identifier_abbr] == 1,
              test_target[first_identifier_abbr] == 0)
        p.finish()
        p.save('correlation_plot.pdf')
        graphics.add('correlation_plot.pdf', width=1.0)
        o += graphics.finish()

        for v in variables:
            variable_abbr = variable_abbreviations[v]
            o += b2latex.SubSection(format.string(v))
            graphics = b2latex.Graphics()
            p = plotting.VerboseDistribution(normed=True, range_in_std=3)
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
            p.add(variables_data, variable_abbr, test_target[first_identifier_abbr] == 0, label="Background")
            p.finish()
            # hash() keeps the file name free of special characters; it only has to
            # be consistent within this process.
            p.save(f'variable_{hash(v)}.pdf')
            graphics.add(f'variable_{hash(v)}.pdf', width=1.0)
            o += graphics.finish()

        o += b2latex.Section("Classifier Plot")
        # Trailing spaces keep the implicitly concatenated sentences apart.
        o += b2latex.String("This section contains the receiver operating characteristics (ROC), purity projection, ... "
                            "of the classifiers on training and independent data. "
                            "The legend of each plot contains the shortened identifier and the area under the ROC curve "
                            "in parenthesis.")

        o += b2latex.Section("ROC Plot")
        graphics = b2latex.Graphics()
        p = plotting.RejectionOverEfficiency()  # restored, see NOTE(review) above
        for identifier in identifier_abbreviations.values():
            p.add(test_probability, identifier, test_target[identifier] == 1, test_target[identifier] == 0)
        p.finish()
        p.axis.set_title("ROC Rejection Plot on independent data")
        p.save('roc_plot_test.pdf')
        graphics.add('roc_plot_test.pdf', width=1.0)
        o += graphics.finish()

        if train_probability:
            for identifier in identifiers:
                graphics = b2latex.Graphics()
                p = plotting.RejectionOverEfficiency()  # restored, see NOTE(review) above
                identifier_abbr = identifier_abbreviations[identifier]
                p.add(train_probability, identifier_abbr, train_target[identifier_abbr] == 1,
                      train_target[identifier_abbr] == 0, label='Train')
                p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                      test_target[identifier_abbr] == 0, label='Test')
                p.finish()
                p.axis.set_title(identifier)
                p.save(f'roc_test_{hash(identifier)}.pdf')
                graphics.add(f'roc_test_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

        o += b2latex.Section("Classification Results")

        for identifier in identifiers:
            identifier_abbr = identifier_abbreviations[identifier]
            o += b2latex.SubSection(format.string(identifier_abbr))
            graphics = b2latex.Graphics()
            p = plotting.Multiplot(plotting.PurityAndEfficiencyOverCut, 2)  # restored, see NOTE(review) above
            p.add(0, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=True)
            p.sub_plots[0].axis.set_title(f"Classification result in test data for {identifier}")

            p.add(1, test_probability, identifier_abbr, test_target[identifier_abbr] == 1,
                  test_target[identifier_abbr] == 0, normed=False)
            p.sub_plots[1].axis.set_title(f"Classification result in test data for {identifier}")
            p.finish()

            p.save(f'classification_result_{hash(identifier)}.pdf')
            graphics.add(f'classification_result_{hash(identifier)}.pdf', width=1)
            o += graphics.finish()

        o += b2latex.Section("Diagonal Plot")
        graphics = b2latex.Graphics()
        p = plotting.Diagonal()  # restored, see NOTE(review) above
        for identifier in identifiers:
            # Look up the abbreviation *before* using it in the subsection title;
            # otherwise the title shows the value left over from the previous loop.
            identifier_abbr = identifier_abbreviations[identifier]
            o += b2latex.SubSection(format.string(identifier_abbr))
            p.add(test_probability, identifier_abbr, test_target[identifier_abbr] == 1, test_target[identifier_abbr] == 0)
        p.finish()
        p.axis.set_title("Diagonal plot on independent data")
        p.save('diagonal_plot_test.pdf')
        graphics.add('diagonal_plot_test.pdf', width=1.0)
        o += graphics.finish()

        if train_probability:
            o += b2latex.SubSection("Overtraining Plot")
            for identifier in identifiers:
                identifier_abbr = identifier_abbreviations[identifier]
                # Concatenate training and test sample; train_mask marks which
                # entries of the combined arrays come from the training sample.
                probability = {identifier_abbr: np.r_[train_probability[identifier_abbr], test_probability[identifier_abbr]]}
                target = np.r_[train_target[identifier_abbr], test_target[identifier_abbr]]
                train_mask = np.r_[np.ones(len(train_target[identifier_abbr])), np.zeros(len(test_target[identifier_abbr]))]
                graphics = b2latex.Graphics()
                p = plotting.Overtraining()  # restored, see NOTE(review) above
                p.add(probability, identifier_abbr,
                      train_mask == 1, train_mask == 0,
                      target == 1, target == 0, )
                p.finish()
                p.axis.set_title(f"Overtraining check for {identifier}")
                p.save(f'overtraining_plot_{hash(identifier)}.pdf')
                graphics.add(f'overtraining_plot_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

        o += b2latex.Section("Spectators")
        o += b2latex.String("This section contains the distribution and dependence on the "
                            "classifier outputs of all spectator variables.")

        table = b2latex.LongTable(r"ll", "Abbreviations of spectators", "{name} & {abbr}", r"Spectator & Abbreviation")
        for s in spectators:
            table.add(name=format.string(s), abbr=format.string(spectator_abbreviations[s]))
        o += table.finish()

        for spectator in spectators:
            spectator_abbr = spectator_abbreviations[spectator]
            o += b2latex.SubSection(format.string(spectator))
            graphics = b2latex.Graphics()
            p = plotting.VerboseDistribution()  # restored, see NOTE(review) above
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 1, label="Signal")
            p.add(spectators_data, spectator_abbr, test_target[first_identifier_abbr] == 0, label="Background")
            p.finish()
            p.save(f'spectator_{hash(spectator)}.pdf')
            graphics.add(f'spectator_{hash(spectator)}.pdf', width=1.0)
            o += graphics.finish()

            for identifier in identifiers:
                o += b2latex.SubSubSection(format.string(spectator) + " with classifier " + format.string(identifier))
                identifier_abbr = identifier_abbreviations[identifier]
                data = {identifier_abbr: test_probability[identifier_abbr], spectator_abbr: spectators_data[spectator_abbr]}
                graphics = b2latex.Graphics()
                p = plotting.Correlation()  # restored, see NOTE(review) above
                p.add(data, spectator_abbr, identifier_abbr, list(range(10, 100, 10)),
                      test_target[identifier_abbr] == 1,
                      test_target[identifier_abbr] == 0)
                p.finish()
                p.save(f'correlation_plot_{hash(spectator)}_{hash(identifier)}.pdf')
                graphics.add(f'correlation_plot_{hash(spectator)}_{hash(identifier)}.pdf', width=1.0)
                o += graphics.finish()

        if args.compile:
            B2INFO(f"Creating a PDF file at {args.outputfile}. Please remove the '-c' switch if this fails.")
            o.save('latex.tex', compile=True)
        else:
            B2INFO(f"Creating a .zip archive containing plots and a TeX file at {args.outputfile}. "
                   f"Please unpack the archive and compile the latex.tex file with pdflatex.")
            o.save('latex.tex', compile=False)

        # Go back before copying, so a relative output file name refers to the
        # directory the script was started from. This has to happen inside the
        # `with` block, while the temporary directory still exists.
        os.chdir(old_cwd)
        if args.working_directory == '':
            working_directory = tempdir
        else:
            working_directory = args.working_directory

        if args.compile:
            shutil.copy(os.path.join(working_directory, 'latex.pdf'), args.outputfile)
        else:
            # make_archive appends '.zip' itself, so strip the extension. splitext only
            # looks at the last path component, hence dots in directory names
            # (e.g. '../out/report') are left alone.
            base_name = os.path.join(old_cwd, os.path.splitext(args.outputfile)[0])
            shutil.make_archive(base_name, 'zip', working_directory)
def chain2dict(chain, tree_columns, dict_columns=None)