Belle II Software  release-05-01-25
basf2_mva_variable_importance.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 import basf2_mva
5 import basf2_mva_util
6 from basf2_mva_evaluation import plotting
7 import argparse
8 import tempfile
9 
10 from ROOT import Belle2
11 import numpy as np
12 from B2Tools import b2latex
13 
14 import os
15 
16 
def getCommandLineOptions():
    """
    Parse the command line options of the tool and return the corresponding arguments.

    @return argparse.Namespace with identifiers, train/test datafiles, treename,
            outputfile and the importance-estimation mode flags
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-id', '--identifiers', dest='identifiers', type=str, required=True, action='append', nargs='+',
                        help='DB Identifier or weightfile. Does at the moment only work with one id.')
    parser.add_argument('-train', '--train_datafiles', dest='train_datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree used during training')
    parser.add_argument('-test', '--test_datafiles', dest='test_datafiles', type=str, required=True, action='append', nargs='+',
                        help='Data file containing ROOT TTree with independent test data')
    parser.add_argument('-tree', '--treename', dest='treename', type=str, default='tree', help='Treename in data file')
    parser.add_argument('-out', '--outputfile', dest='outputfile', type=str, default='output.pdf',
                        help='Name of the outputted pdf file')
    parser.add_argument('-weightfile', '--weightfile', dest='weightfile', action='store_true',
                        help='Read feature importances from weightfile')
    # Fix: the implicit string concatenations below were missing a separating
    # space ("iterativelyleaving" / "recursivelyremove" in the help output).
    parser.add_argument('-iterative', '--iterative', dest='iterative', action='store_true',
                        help='Improve the importance estimation by iteratively '
                             'leaving one variable out and retrain. Needs O(NFeatures) Trainings!')
    parser.add_argument('-recursive', '--recursive', dest='recursive', action='store_true',
                        help='Improve the importance estimation by recursively '
                             'remove the most important variable. Needs O(NFeatures**2) Trainings!')
    args = parser.parse_args()
    return args
39 
40 
def get_importances(method, train_datafiles, test_datafiles, treename, variables, global_auc):
    """
    Estimate the importance of every variable of a method by retraining the
    method once per variable with that variable left out, and reporting the
    drop in AUC relative to the training that used all variables.

    @param method the method object
    @param train_datafiles data used to retrain the method
    @param test_datafiles data used to evaluate the method and calculate the new auc
    @param treename the name of the tree containing the data
    @param variables list of variables which are considered for the trainings
    @param global_auc the auc of the training with all variables
    @return tuple (importances, classifiers): dicts keyed by the left-out variable
    """
    importances = {}
    classifiers = {}
    for left_out in variables:
        options = method.general_options
        kept = [v for v in variables if v != left_out]
        options.m_variables = basf2_mva.vector(*kept)
        classifier = method.train_teacher(train_datafiles, treename, options)
        probability, target = classifier.apply_expert(test_datafiles, treename)
        auc = basf2_mva_util.calculate_roc_auc(probability, target)
        # Importance = how much the AUC degrades without this variable.
        importances[left_out] = global_auc - auc
        classifiers[left_out] = classifier
    return importances, classifiers
62 
63 
def get_importances_recursive(method, train_datfiles, test_datafiles, treename, variables, global_auc):
    """
    Estimate variable importances recursively: compute the leave-one-out
    importances for all given variables, remove the most important one
    (the variable whose removal degrades the AUC the most), and repeat on
    the remaining variables with a correspondingly reduced baseline AUC.

    @param method the method object
    @param train_datfiles data used to retrain the method
    @param test_datafiles data used to evaluate the method and calculate the new auc
    @param treename the name of the tree containing the data
    @param variables list of variables which are considered for the trainings
    @param global_auc the auc of the training with all variables
    @return tuple (importances, classifiers): dicts keyed by the removed variable
    """
    imp, cla = get_importances(method, train_datfiles, test_datafiles, treename, variables, global_auc)
    best = max(imp, key=imp.get)
    remaining = [v for v in variables if v != best]

    # Stop once only one variable would be left to train on.
    if len(remaining) == 1:
        return imp, cla

    importances = {best: imp[best]}
    classifiers = {best: cla[best]}
    # Recurse without the most important variable; its importance is
    # subtracted from the baseline AUC for the next round.
    sub_imp, sub_cla = get_importances_recursive(method, train_datfiles, test_datafiles, treename,
                                                 remaining, global_auc - imp[best])
    importances.update(sub_imp)
    classifiers.update(sub_cla)
    return importances, classifiers
90 
91 
if __name__ == '__main__':
    # Deprecated driver script: trains the method repeatedly with variables
    # left out, plots the resulting importances and overtraining checks, and
    # compiles everything into a pdf report.

    print("WARNING This tool is deprecated, use mva/examples/advanced/variable_importance.py instead and adapt it to your needs.")
    print("In fact adapting the example is easier than using this general tool, and it is also easier to automatise")
    print("Therefore this tool will be removed in the future")

    # Remember where we started: the latex build runs inside a temporary
    # directory and the output path is interpreted relative to this one.
    old_cwd = os.getcwd()
    args = getCommandLineOptions()

    # argparse (action='append' + nargs='+') yields nested lists; flatten them.
    identifiers = sum(args.identifiers, [])
    train_datafiles = sum(args.train_datafiles, [])
    test_datafiles = sum(args.test_datafiles, [])

    methods = [basf2_mva_util.Method(identifier) for identifier in identifiers]

    labels = []
    importances = []
    iterative_classifiers = []
    recursive_classifiers = []
    all_variables = []
    cla_dict = {}
    # Fix: 'cla' is used after this loop (apply-expert stages below); without
    # this initialisation the tool crashed with a NameError when neither
    # --recursive nor --iterative was given (e.g. --weightfile only).
    cla = {}
    for method in methods:
        global_auc = basf2_mva_util.calculate_roc_auc(*method.apply_expert(test_datafiles, args.treename))
        print(" in method", method)
        for variable in method.variables:
            all_variables.append(variable)
        if args.recursive:
            imp, cla = get_importances_recursive(method, train_datafiles, test_datafiles,
                                                 args.treename, method.variables, global_auc)
            importances.append(imp)
            recursive_classifiers.append(cla)
            labels.append(method.identifier + '\n (recursive)')

            cla_dict[method.identifier + '_recursive'] = cla
        elif args.iterative:
            imp, cla = get_importances(method, train_datafiles, test_datafiles,
                                       args.treename, method.variables, global_auc)
            importances.append(imp)
            iterative_classifiers.append(cla)
            labels.append(method.identifier + '\n (iterative)')
            cla_dict[method.identifier + '_iterative'] = cla
        if args.weightfile:
            importances.append(method.importances)
            labels.append(method.identifier + '\n (weightfile)')

    # Order the variables by the first method's importances so the plot rows
    # are sorted from least to most important.
    all_variables = list(sorted(all_variables, key=lambda v: importances[0].get(v, 0.0)))

    importances_dict = {}
    for i, label in enumerate(labels):
        importances_dict[label] = np.array([importances[i].get(v, 0.0) for v in all_variables])

    # todo: distinguish between iterative & recursive
    print("Apply experts on independent data")
    test_probability = {}
    test_target = {}
    for method in methods:
        ps = {}
        ts = {}
        # NOTE(review): 'cla' holds the classifiers of the LAST method trained
        # above, not the ones belonging to 'method' -- this is only correct for
        # a single identifier (which the help text already states).
        for classifier in cla:
            p, t = cla[classifier].apply_expert(test_datafiles, args.treename)
            ps[classifier] = p
            ts[classifier] = t
        test_probability[method.identifier] = ps
        test_target[method.identifier] = ts

    print("Apply experts on training data")
    train_probability = {}
    train_target = {}
    if args.train_datafiles is not None:
        train_datafiles = sum(args.train_datafiles, [])
        for method in methods:
            ps = {}
            ts = {}
            for classifier in cla:
                p, t = cla[classifier].apply_expert(train_datafiles, args.treename)
                ps[classifier] = p
                ts[classifier] = t
            train_probability[method.identifier] = ps
            train_target[method.identifier] = ts

    # Change working directory after experts run, because they might want to access
    # a localdb in the current working directory
    import shutil

    with tempfile.TemporaryDirectory() as tempdir:
        os.chdir(tempdir)

        o = b2latex.LatexFile()
        o += b2latex.TitlePage(title='Automatic Feature Importance Report',
                               authors=[r'Thomas Keck', 'Markus Prim', 'Moritz Gelb'],
                               abstract='Feature importance calculation by leaving one variable out and retrain.',
                               add_table_of_contents=False,
                               clearpage=False).finish()
        o += b2latex.Section("General Feature Importance")
        graphics = b2latex.Graphics()
        p = plotting.Importance()
        # Variable names were made ROOT compatible for training; invert that
        # transformation so the report shows the original names.
        read_root_var = [Belle2.invertMakeROOTCompatible(v) for v in all_variables]
        print("\n")
        print(importances_dict)
        print(labels)
        print(read_root_var)
        print("\n")
        p.add(importances_dict, labels, read_root_var)
        p.finish()
        p.save('importance.png')
        graphics.add('importance.png', width=1.0)
        o += graphics.finish()

        for identifier in identifiers:
            if args.recursive:
                o += b2latex.Section("Recursive Feature Importance")
                o += b2latex.String("""
                Calculate the importance of the variables of a method by retraining the method without
                one of the variables at a time. Then the best variable (the one which leads to the lowest
                area under the curve if it is left out) is removed and the importance of the remaining
                variables is calculated recursively.
                """)
            if args.iterative:
                o += b2latex.Section("Iterative Feature Importance")
                o += b2latex.String("""
                Calculate the importance of the variables of a method by retraining the method
                without one of the variables at a time and comparing the auc to the global
                area under the curve.
                """)
            if args.weightfile:
                o += b2latex.Section("Feature Importance")
                o += b2latex.String("""
                Read feature importances from weightfile.
                """)

            for variable in reversed(all_variables):
                if args.weightfile:
                    pass
                for classifier in cla:
                    if classifier == variable:
                        if args.recursive:
                            o += b2latex.SubSection(
                                "Without variable {} (and variables above)".format(Belle2.invertMakeROOTCompatible(classifier)))
                        elif args.iterative:
                            o += b2latex.SubSection(
                                "Without variable {}".format(Belle2.invertMakeROOTCompatible(classifier)))
                        # Concatenate train and test samples; 'train_mask'
                        # remembers which entries came from the training data.
                        probability = {classifier: np.r_[train_probability[identifier][classifier],
                                                         test_probability[identifier][classifier]]}
                        target = np.r_[train_target[identifier][classifier], test_target[identifier][classifier]]
                        train_mask = np.r_[np.ones(len(train_target[identifier][classifier])),
                                           np.zeros(len(test_target[identifier][classifier]))]
                        graphics = b2latex.Graphics()
                        # Fix: instantiate the Overtraining plot. Previously 'p'
                        # was still the Importance plot from above, so p.add()
                        # was called with the wrong signature.
                        p = plotting.Overtraining()
                        p.add(probability, classifier, train_mask == 1, train_mask == 0, target == 1, target == 0, )
                        p.finish()

                        p.axis.set_title(
                            "Overtraining check for {} without variable {} ".format(
                                identifier, Belle2.invertMakeROOTCompatible(classifier)))
                        p.save('overtraining_plot_{}_wo_{}.png'.format(hash(identifier), classifier))
                        graphics.add('overtraining_plot_{}_wo_{}.png'.format(hash(identifier), classifier), width=1.0)
                        o += graphics.finish()

        o.save('latex.tex', compile=True)
        os.chdir(old_cwd)
        shutil.copy(tempdir + '/latex.pdf', args.outputfile)
basf2_mva_util.calculate_roc_auc
def calculate_roc_auc(p, t)
Definition: basf2_mva_util.py:39
Belle2::invertMakeROOTCompatible
std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
Definition: MakeROOTCompatible.cc:89
basf2_mva_util.Method
Definition: basf2_mva_util.py:81
plotting.Importance
Definition: plotting.py:1130
plotting.Overtraining
Definition: plotting.py:813
basf2_mva_util.Method.apply_expert
def apply_expert(self, datafiles, treename)
Definition: basf2_mva_util.py:180