Belle II Software  release-05-02-19
basf2_mva_util.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2016
5 
6 import basf2_mva
7 
8 import tempfile
9 
10 import numpy as np
11 
12 import ROOT
13 from ROOT import Belle2
14 
15 
def tree2dict(tree, tree_columns, dict_columns=None):
    """
    Read selected branches of a ROOT.TTree into numpy arrays.

    root_numpy is used when it can be imported; otherwise the tree is read by
    a slow event-by-event loop.
    @param tree the ROOT.TTree to read
    @param tree_columns the column (or branch) names to read from the tree
    @param dict_columns the names under which the columns are returned (defaults to tree_columns)
    @return a numpy structured array (root_numpy available) or a dict of np.arrays, indexable by dict_columns
    """
    if not len(tree_columns):
        return {}
    names = tree_columns if dict_columns is None else dict_columns
    try:
        import root_numpy
        result = root_numpy.tree2array(tree, branches=tree_columns)
        # rename the structured-array fields to the requested output names
        result.dtype.names = names
    except ImportError:
        n_entries = tree.GetEntries()
        result = {}
        for name in names:
            result[name] = np.zeros((n_entries,))
        name_branch_pairs = list(zip(names, tree_columns))
        for i, event in enumerate(tree):
            for name, branch in name_branch_pairs:
                result[name][i] = getattr(event, branch)
    return result
37 
38 
def calculate_roc_auc(p, t):
    """
    Calculates the area under the receiver operating characteristic curve (AUC ROC)

    The curve is built from purity vs. signal efficiency obtained by cutting on p
    at every event; the undefined purity of the final point (no events left above
    the cut, i.e. 0/0) is set to 0.
    @param p array-like filled with the probability output of a classifier
    @param t array-like filled with the target (0 or 1)
    @return absolute value of the trapezoidal integral of purity over efficiency
    """
    # accept plain sequences as well: t is indexed with an index array below
    p = np.asarray(p)
    t = np.asarray(t)
    N = len(t)
    T = np.sum(t)
    index = np.argsort(p)
    # number of signal events remaining above each cut position
    signal_above = T - np.cumsum(t[index])
    # number of events remaining above each cut position (the last one is 0)
    events_above = N - np.cumsum(np.ones(N))
    # the last point is always 0/0; silence the resulting RuntimeWarning
    with np.errstate(divide='ignore', invalid='ignore'):
        efficiency = signal_above / float(T)
        purity = signal_above / events_above
    purity = np.where(np.isnan(purity), 0, purity)
    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(purity, efficiency))
52 
53 
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    Both f and p are binned at their percentiles (0..100). For every feature bin,
    the cumulative distribution over probability bins is compared with the
    inclusive cumulative distribution.
    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency
    """
    f = np.asarray(f)
    p = np.asarray(p)
    quantiles = list(range(101))
    binning_feature = np.unique(np.percentile(f, q=quantiles))
    binning_probability = np.unique(np.percentile(p, q=quantiles))
    # fewer than two distinct edges cannot form a bin: fall back to one wide bin
    if len(binning_feature) < 2:
        binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(binning_probability) < 2:
        binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
    # rows: probability bins, columns: feature bins
    hist_n, _ = np.histogramdd(np.c_[p, f],
                               bins=[binning_probability, binning_feature],
                               weights=w)
    hist_inc = hist_n.sum(axis=1)
    hist_inc /= hist_inc.sum(axis=0)
    # Feature bins without entries would be normalised as 0/0 = nan and turn the
    # whole result into nan; they carry no information, so drop them.
    column_sums = hist_n.sum(axis=0)
    filled = column_sums != 0
    hist_n = hist_n[:, filled] / column_sums[filled]
    hist_n = hist_n.cumsum(axis=0)
    hist_inc = hist_inc.cumsum(axis=0)
    diff = (hist_n.T - hist_inc)**2
    # NOTE(review): the normalisation is fixed to 100 * 99, independent of the
    # actual number of bins left after np.unique — confirm this is intended.
    return np.sqrt(diff.sum() / (100 * 99))
79 
80 
class Method(object):
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf2_mva_evaluation tools
    """

    #: Method name (as stored in GeneralOptions.m_method) -> name of the basf2_mva
    #: class holding its specific options. Resolved lazily with getattr, so only
    #: the class actually needed is accessed.
    _SPECIFIC_OPTIONS_CLASSES = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raises RuntimeError if the method type stored in the weightfile is unknown
        """

        #: Identifier of the method
        self.identifier = identifier

        #: Weightfile of the method
        self.weightfile = basf2_mva.Weightfile.load(self.identifier)

        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # The generic lookup below should be correct but leads to random segmentation faults
        # inside python, llvm or pyroot, therefore the explicit name -> class table is used instead.
        # Ideas why this is happening:
        # 1. Ownership of the unique_ptr returned by getOptions()
        # 2. Some kind of object slicing, although pyroot identifies the correct type
        # 3. Bug in pyroot
        # interfaces = basf2_mva.AbstractInterface.getSupportedInterfaces()
        # self.interface = interfaces[self.general_options.m_method]
        # self.specific_options = self.interface.getOptions()

        #: Specific options of the method
        self.specific_options = None
        # str() so the lookup works whether pyroot returns a Python str or a std::string proxy
        method_name = str(self.general_options.m_method)
        options_class_name = self._SPECIFIC_OPTIONS_CLASSES.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unkown method " + method_name)
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}
        #: List of variables sorted by increasing importance
        self.variables = list(sorted(variables, key=lambda v: self.importances.get(v, 0.0)))
        #: List of the variables (same order), but with root compatible variable names
        self.root_variables = [str(Belle2.makeROOTCompatible(v)) for v in self.variables]
        #: Dictionary of the variable importances keyed by root compatible variable names
        # NOTE(review): `importances` is looked up with the root compatible name here;
        # if the weightfile stores importances under the original names this may
        # not find them — confirm.
        self.root_importances = {k: importances[k] for k in self.root_variables}
        #: Description of the method as a xml string returned by basf2_mva.info
        self.description = str(basf2_mva.info(self.identifier))
        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]
        #: List of spectators with root compatible names
        self.root_spectators = [str(Belle2.makeROOTCompatible(v)) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the training data (if None the tree name
                        already stored in the general options is kept)
        @param general_options general options given to basf2_mva.teacher (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher (if None the options of this method are used)
        @return a new Method loaded from the freshly trained weightfile
        """
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            # NOTE(review): this aliases self.general_options, so the assignments
            # below also modify this method's own options.
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # honour the documented treename argument instead of silently ignoring it
            if treename is not None:
                general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # load the result while the temporary weightfile still exists
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the data
        @return tuple (expert output, target) as read back from the expert output file
        """
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"
            basf2_mva.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            # the output is only read; close it before the temporary directory is removed
            rootfile = ROOT.TFile(rootfilename, "READ")
            try:
                roottree = rootfile.Get("variables")

                # branches are named after the temporary identifier; return them
                # under the names derived from this method's own identifier
                expert_target = identifier + '_' + self.general_options.m_target_variable
                stripped_expert_target = self.identifier + '_' + self.general_options.m_target_variable
                d = tree2dict(roottree,
                              [Belle2.makeROOTCompatible(identifier), Belle2.makeROOTCompatible(expert_target)],
                              [self.identifier, stripped_expert_target])
            finally:
                rootfile.Close()
        return d[self.identifier], d[stripped_expert_target]
basf2_mva_util.Method.identifier
identifier
Identifier of the method.
Definition: basf2_mva_util.py:94
basf2_mva_util.Method.__init__
def __init__(self, identifier)
Definition: basf2_mva_util.py:88
basf2_mva_util.Method.root_spectators
root_spectators
List of spectators with root compatible names.
Definition: basf2_mva_util.py:152
basf2_mva_util.Method.root_importances
root_importances
Dictionary of the variable importances, keyed by the root compatible variable names.
Definition: basf2_mva_util.py:146
basf2_mva_util.Method.spectators
spectators
List of spectators.
Definition: basf2_mva_util.py:150
basf2_mva_util.Method
Definition: basf2_mva_util.py:81
basf2_mva_util.Method.root_variables
root_variables
List of the variables sorted by their importance, but with root compatible variable names.
Definition: basf2_mva_util.py:144
Belle2::makeROOTCompatible
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Definition: MakeROOTCompatible.cc:74
basf2_mva_util.Method.train_teacher
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
Definition: basf2_mva_util.py:154
basf2_mva_util.Method.variables
variables
List of variables sorted by their importance.
Definition: basf2_mva_util.py:142
basf2_mva_util.Method.general_options
general_options
General options of the method.
Definition: basf2_mva_util.py:98
basf2_mva_util.Method.specific_options
specific_options
Specific options of the method.
Definition: basf2_mva_util.py:112
basf2_mva_util.Method.importances
importances
Dictionary of the variable importances calculated by the method.
Definition: basf2_mva_util.py:140
basf2_mva_util.Method.apply_expert
def apply_expert(self, datafiles, treename)
Definition: basf2_mva_util.py:180
basf2_mva_util.Method.weightfile
weightfile
Weightfile of the method.
Definition: basf2_mva_util.py:96
basf2_mva_util.Method.description
description
Description of the method as a xml string returned by basf2_mva.info.
Definition: basf2_mva_util.py:148