Belle II Software  release-06-01-15
basf2_mva_util.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 import basf2_mva
13 
14 import tempfile
15 
16 import numpy as np
17 
18 import ROOT
19 from ROOT import Belle2
20 
21 
def tree2dict(tree, tree_columns, dict_columns=None):
    """
    Convert a ROOT.TTree into a dictionary of np.arrays
    @param tree the ROOT.TTree
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary (defaults to tree_columns)
    @return an object indexable by the dict column names, yielding one np.array entry per tree entry:
            a numpy structured array if root_numpy is available, otherwise a plain dict of np.arrays
            (an empty dict if no tree columns are requested)
    @raise ValueError if tree_columns and dict_columns have different lengths
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    if len(dict_columns) != len(tree_columns):
        # Without this check zip() below would silently drop the surplus columns
        raise ValueError("tree_columns and dict_columns must have the same length, got %d and %d"
                         % (len(tree_columns), len(dict_columns)))
    try:
        import root_numpy
    except ImportError:
        # Slow pure-PyROOT fallback: read the tree entry by entry
        d = {column: np.zeros((tree.GetEntries(),)) for column in dict_columns}
        for iEvent, event in enumerate(tree):
            for dict_column, tree_column in zip(dict_columns, tree_columns):
                d[dict_column][iEvent] = getattr(event, tree_column)
        return d
    d = root_numpy.tree2array(tree, branches=tree_columns)
    # Rename the structured-array fields from the tree names to the requested names
    d.dtype.names = dict_columns
    return d
43 
44 
def calculate_roc_auc(p, t):
    """
    Calculates the area under the receiver operating characteristic curve (AUC ROC)

    The curve integrated here is the purity versus the signal efficiency obtained
    by cutting on the classifier output p at each of its sorted values.
    @param p array-like filled with the probability output of a classifier
    @param t array-like filled with the target (0 or 1)
    @return the absolute value of the area under the purity-vs-efficiency curve
    """
    p = np.asarray(p)
    t = np.asarray(t)
    N = len(t)
    T = np.sum(t)
    index = np.argsort(p)
    # Signal events remaining after cutting away everything up to each sorted prediction
    remaining_signal = T - np.cumsum(t[index])
    efficiency = remaining_signal / float(T)
    # The last cut keeps no events at all (0/0 -> NaN); that point is mapped to purity 0
    # below, so the expected division warning is silenced instead of emitted on every call.
    with np.errstate(divide='ignore', invalid='ignore'):
        purity = remaining_signal / (N - np.cumsum(np.ones(N)))
    purity = np.where(np.isnan(purity), 0, purity)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return np.abs(trapezoid(purity, efficiency))
58 
59 
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    Both f and p are binned at their (unique) percentiles 0..100, i.e. into at most
    100 bins each. For every feature bin the cumulative selection efficiency of cuts on p
    is compared to the global (feature-inclusive) one. Feature bins without any (weighted)
    entries have no defined efficiency and are ignored.
    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency
    """
    quantiles = list(range(101))
    binning_feature = np.unique(np.percentile(f, q=quantiles))
    binning_probability = np.unique(np.percentile(p, q=quantiles))
    if len(binning_feature) < 2:
        binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(binning_probability) < 2:
        binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
    hist_n, _ = np.histogramdd(np.c_[p, f],
                               bins=[binning_probability, binning_feature],
                               weights=w)
    hist_inc = hist_n.sum(axis=1)
    hist_inc /= hist_inc.sum(axis=0)
    # Interpolated percentile edges (or discrete features) can produce empty feature bins;
    # normalising those would be 0/0 and turn the whole result into NaN, so drop them.
    column_sums = hist_n.sum(axis=0)
    occupied = column_sums != 0
    hist_n = hist_n[:, occupied] / column_sums[occupied]
    hist_n = hist_n.cumsum(axis=0)
    hist_inc = hist_inc.cumsum(axis=0)
    diff = (hist_n.T - hist_inc)**2
    # The last cumulative probability row is 1 for every bin and never contributes,
    # hence (n_probability_bins - 1); this equals the former fixed 100 * 99 for full binning.
    n_probability_bins, n_feature_bins = hist_n.shape
    normalisation = max((n_probability_bins - 1) * n_feature_bins, 1)
    return np.sqrt(diff.sum() / normalisation)
85 
86 
class Method(object):
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf_mva_evaluation tools
    """

    #: Maps GeneralOptions.m_method to the name of the matching basf2_mva specific options class.
    #: The classes are resolved with getattr only when a Method is constructed.
    _SPECIFIC_OPTIONS_CLASSES = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raise RuntimeError if the method stored in the weightfile is not one of the supported methods
        """
        #: Identifier of the method
        self.identifier = identifier
        #: Weightfile of the method
        self.weightfile = basf2_mva.Weightfile.load(self.identifier)
        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # The generic way
        #   interfaces = basf2_mva.AbstractInterface.getSupportedInterfaces()
        #   self.specific_options = interfaces[self.general_options.m_method].getOptions()
        # should be correct but leads to random segmentation faults inside python, llvm or pyroot,
        # therefore the options class is picked by name instead. Suspected causes: ownership of the
        # unique_ptr returned by getOptions(), object slicing (although pyroot identifies the
        # correct type), or a bug in pyroot.
        method_name = str(self.general_options.m_method)
        options_class_name = self._SPECIFIC_OPTIONS_CLASSES.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)
        #: Specific options of the method
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}
        #: List of variables sorted by their importance (ascending, least important first)
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))
        #: List of the variables with root compatible names, in the same order as variables
        self.root_variables = [Belle2.makeROOTCompatible(v) for v in self.variables]
        # The weightfile importances are looked up by the original variable names above, so
        # re-key that dictionary instead of looking the root compatible names up in the weightfile.
        #: Dictionary of the variable importances calculated by the method, keyed by root compatible names
        self.root_importances = {Belle2.makeROOTCompatible(k): v for k, v in self.importances.items()}
        #: Description of the method as a xml string returned by basf2_mva.info
        self.description = str(basf2_mva.info(self.identifier))
        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]
        #: List of spectators with root compatible names
        self.root_spectators = [Belle2.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher (if None a fresh copy of the
               options of this method is used); m_datafiles, m_treename and m_identifier are overwritten
        @param specific_options specific options given to basf2_mva.teacher (if None the options of this method are used)
        @return a new Method loaded from the freshly trained weightfile
        """
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            # Work on a copy: overwriting m_datafiles / m_identifier on self.general_options would
            # leave this method pointing at a temporary weightfile that is deleted below.
            general_options = basf2_mva.GeneralOptions()
            general_options.load(self.weightfile.getXMLTree())
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # Load the result while the temporary weightfile still exists
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the data
        @return tuple (expert output, target) as np.arrays
        """
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"
            basf2_mva.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            # The expert output is only read, and the file must be closed before the
            # temporary directory holding it is removed.
            rootfile = ROOT.TFile(rootfilename, "READ")
            try:
                roottree = rootfile.Get("variables")

                expert_target = identifier + '_' + self.general_options.m_target_variable
                stripped_expert_target = self.identifier + '_' + self.general_options.m_target_variable
                d = tree2dict(roottree,
                              [Belle2.makeROOTCompatible(identifier), Belle2.makeROOTCompatible(expert_target)],
                              [self.identifier, stripped_expert_target])
            finally:
                rootfile.Close()
        return d[self.identifier], d[stripped_expert_target]
specific_options
Specific options of the method.
def apply_expert(self, datafiles, treename)
description
Description of the method as a xml string returned by basf2_mva.info.
importances
Dictionary of the variable importances calculated by the method.
root_importances
Dictionary of the variable importances calculated by the method, but with root compatible variable names.
variables
List of variables sorted by their importance.
def __init__(self, identifier)
weightfile
Weightfile of the method.
root_spectators
List of spectators with root compatible names.
spectators
List of spectators.
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
root_variables
List of variables sorted by their importance, but with root compatible variable names.
general_options
General options of the method.
identifier
Identifier of the method.
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.