Belle II Software  light-2212-foldex
basf2_mva_util.py
1 
8 
9 import tempfile
10 import numpy as np
11 
12 from basf2 import B2WARNING
13 import basf2_mva
14 
15 
def tree2dict(tree, tree_columns, dict_columns=None):
    """
    Convert a ROOT.TTree into a dictionary of np.arrays

    If ``root_numpy`` can be imported it is used to read all requested branches at once;
    the result is then a numpy structured array whose field names are ``dict_columns``
    (it can be indexed like a dictionary). Otherwise the tree is looped event by event
    through PyROOT and every column is stored in a float64 array.

    @param tree the ROOT.TTree
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary
                        (defaults to tree_columns)
    @return mapping from the dictionary column names to np.arrays (empty dict if no columns are requested)
    @raise ValueError if tree_columns and dict_columns have different lengths
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    # Without this check the fallback loop below would silently drop columns via zip(),
    # while the root_numpy path fails on the dtype rename; make both paths agree.
    if len(dict_columns) != len(tree_columns):
        raise ValueError(
            f"tree2dict: got {len(tree_columns)} tree columns but {len(dict_columns)} dictionary columns")

    try:
        import root_numpy
    except ImportError:
        # Pure PyROOT fallback: read the branches event by event.
        d = {column: np.zeros((tree.GetEntries(),)) for column in dict_columns}
        for iEvent, event in enumerate(tree):
            for dict_column, tree_column in zip(dict_columns, tree_columns):
                d[dict_column][iEvent] = getattr(event, tree_column)
        return d

    d = root_numpy.tree2array(tree, branches=tree_columns)
    # Rename the structured-array fields from the branch names to the requested names.
    d.dtype.names = tuple(dict_columns)
    return d
37 
38 
def calculate_roc_auc(p, t):
    """
    Deprecated alias of ``calculate_auc_efficiency_vs_purity``

    Emits a B2WARNING describing the replacement functions and then returns the
    (unweighted) area under the efficiency-purity curve.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    """
    deprecation_notice = (
        "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n"
        "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n"
        "not the efficiency-background retention curve as expected by users.\n"
        "Please replace calculate_roc_auc with:\n\n"
        "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:"
        " the current definition of calculate_roc_auc\n"
        "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:"
        " what is commonly known as roc auc\n"
    )
    B2WARNING(deprecation_notice)
    return calculate_auc_efficiency_vs_purity(p, t)
56 
57 
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve

    The curve is traced by cutting on the classifier output: the first point is the
    selection without any cut (efficiency 1, purity = weighted signal fraction), and
    every following point additionally removes the next event in order of increasing p.
    Where no event survives the cut the purity is undefined (0/0) and is set to 0.
    The area is integrated with the trapezoidal rule.

    @param p np.array (or array-like) filled with the probability output of a classifier
    @param t np.array (or array-like) filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the non-negative area under the curve (0.0 for empty input)
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    if p.size == 0:
        return 0.0

    index = np.argsort(p)
    # Weight removed by each successive cut; the leading 0 is the "no cut" point.
    removed_w = np.concatenate(([0.0], np.cumsum(w[index])))
    removed_wt = np.concatenate(([0.0], np.cumsum((w * t)[index])))
    # Take the totals from the cumulative sums rather than np.sum (which sums pairwise),
    # so that the "everything cut" point is exactly 0/0 instead of rounding noise.
    N = removed_w[-1]
    T = removed_wt[-1]

    efficiency = (T - removed_wt) / float(T)
    with np.errstate(invalid="ignore"):  # the final point is 0/0 by construction
        purity = (T - removed_wt) / (N - removed_w)
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return np.abs(trapezoid(purity, efficiency))
78 
79 
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)

    The curve is traced by cutting on the classifier output: it starts at the selection
    without any cut (efficiency 1, background retention 1), and every following point
    additionally removes the next event in order of increasing p. The area is
    integrated with the trapezoidal rule.

    @param p np.array (or array-like) filled with the probability output of a classifier
    @param t np.array (or array-like) filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the non-negative area under the curve (0.0 for empty input)
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    if p.size == 0:
        return 0.0

    index = np.argsort(p)
    # Signal / background weight removed by each successive cut; the leading 0 is the
    # "no cut" point (1, 1), without which the first segment of the curve is lost.
    removed_signal = np.concatenate(([0.0], np.cumsum((w * t)[index])))
    removed_background = np.concatenate(([0.0], np.cumsum((np.abs(1 - t) * w)[index])))
    T = removed_signal[-1]
    B = removed_background[-1]

    efficiency = (T - removed_signal) / float(T)
    background_retention = (B - removed_background) / float(B)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return np.abs(trapezoid(efficiency, background_retention))
99 
100 
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    Both the probability and the feature are binned at their 0th..100th percentiles
    (duplicate edges removed; a single [min - 1, max + 1] bin is used when fewer than
    two distinct edges remain). For every feature bin the cumulative distribution of
    the probability is compared with the inclusive cumulative distribution.

    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency
            (the summed squared deviation is always normalised by 100 * 99, independent of
            the number of bins actually used)
    """
    percentile_points = np.arange(101)

    def quantile_edges(values):
        # Percentile-based bin edges, falling back to one wide bin for constant input.
        edges = np.unique(np.percentile(values, q=percentile_points))
        if len(edges) >= 2:
            return edges
        return np.array([np.min(values) - 1, np.max(values) + 1])

    counts, _ = np.histogramdd(np.c_[p, f],
                               bins=[quantile_edges(p), quantile_edges(f)],
                               weights=w)
    # axis 0 runs over probability bins, axis 1 over feature bins
    inclusive = counts.sum(axis=1)
    inclusive_cdf = np.cumsum(inclusive / inclusive.sum(axis=0), axis=0)
    local_cdf = np.cumsum(counts / counts.sum(axis=0), axis=0)
    squared_deviation = np.square(local_cdf.T - inclusive_cdf)
    return np.sqrt(squared_deviation.sum() / (100 * 99))
126 
127 
class Method(object):
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf2_mva_evaluation tools
    """

    # Name of the basf2_mva specific-options class for each supported m_method.
    # Only the class matching the loaded method is looked up (via getattr), like the
    # original explicit if/elif chain.
    _SPECIFIC_OPTIONS_CLASSES = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raise RuntimeError if the weightfile uses a method without a known options class
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()

        #: Identifier of the method.
        self.identifier = identifier
        #: Weightfile of the method.
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)

        #: General options of the method.
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # Obtaining the options generically through
        # AbstractInterface.getSupportedInterfaces()[m_method].getOptions() leads to random
        # segmentation faults inside python, llvm or pyroot (suspected: ownership of the
        # returned unique_ptr, object slicing, or a pyroot bug), therefore the options class
        # is chosen explicitly from the method name instead.
        method_name = str(self.general_options.m_method)
        options_class_name = self._SPECIFIC_OPTIONS_CLASSES.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)
        #: Specific options of the method.
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method.
        self.importances = {k: importances[k] for k in variables}
        #: List of variables sorted by their importance.
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))
        #: List of the variables (same order as variables) with root compatible names.
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]
        #: Dictionary of the variable importances keyed by the root compatible variable names.
        # The weightfile stores the importances under the original variable names, so look each
        # one up through the original name, not the (possibly different) root compatible one.
        self.root_importances = {root_name: self.importances[name]
                                 for name, root_name in zip(self.variables, self.root_variables)}
        #: Description of the method as a xml string returned by basf2_mva.info.
        self.description = str(basf2_mva.info(self.identifier))
        #: List of spectators.
        self.spectators = [str(v) for v in self.general_options.m_spectators]
        #: List of spectators with root compatible names.
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles (a single filename or a list of filenames)
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher
               (if None a fresh copy of the options of this method is used; a given object
               is modified in place: m_datafiles, m_treename and m_identifier are overwritten)
        @param specific_options specific options given to basf2_mva.teacher
               (if None the options of this method are used)
        @return a Method wrapping the newly trained weightfile
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            # Load a private copy so that training does not overwrite m_datafiles / m_identifier
            # of this method's own general options.
            general_options = basf2_mva.GeneralOptions()
            general_options.load(self.weightfile.getXMLTree())
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # Load the result while the temporary weightfile still exists. The returned Method
            # holds the loaded Weightfile object; its identifier points into the removed directory.
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single filename or a list of filenames)
        @param treename the name of the tree containing the data
        @return tuple (probability, target); for m_nClasses <= 2 probability is a 1d array,
                otherwise a 2d array with one column per class
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        n_classes = self.general_options.m_nClasses
        target_variable = self.general_options.m_target_variable

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)

            # The expert output branches are named after the temporary identifier; they are
            # renamed to this method's identifier in the returned dictionary.
            expert_target = identifier + '_' + target_variable
            stripped_expert_target = self.identifier + '_' + target_variable

            if n_classes > 2:
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier + f'_{i}')
                                for i in range(n_classes)]
            else:
                output_names = [self.identifier]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier)]

            # The expert output is only read, and it must be closed before the temporary
            # directory holding it is removed.
            rootfile = ROOT.TFile(rootfilename, "READ")
            try:
                roottree = rootfile.Get("variables")
                d = tree2dict(
                    roottree,
                    [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
                    [*output_names, stripped_expert_target])
            finally:
                rootfile.Close()

        if n_classes > 2:
            probability = np.array([d[x] for x in output_names]).T
        else:
            probability = d[self.identifier]
        return probability, d[stripped_expert_target]
specific_options
Specific options of the method.
def apply_expert(self, datafiles, treename)
description
Description of the method as a xml string returned by basf2_mva.info.
importances
Dictionary of the variable importances calculated by the method.
root_importances
Dictionary of the variable importances calculated by the method, but with root compatible variable names.
variables
List of variables sorted by their importance.
def __init__(self, identifier)
weightfile
Weightfile of the method.
root_spectators
List of spectators with root compatible names.
spectators
List of spectators.
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
root_variables
List of the variables sorted by their importance, but with root compatible variable names.
general_options
General options of the method.
identifier
Identifier of the method.