Belle II Software  release-08-01-10
basf2_mva_util.py
1 
8 
9 import tempfile
10 import numpy as np
11 
12 from basf2 import B2WARNING
13 import basf2_mva
14 
15 
def chain2dict(chain, tree_columns, dict_columns=None):
    """
    Convert a ROOT.TChain into a dictionary of np.arrays

    If ROOT provides ``RDataFrame`` the requested columns are read in bulk and returned
    as a numpy record array, otherwise the chain is looped over entry by entry and a plain
    dictionary of arrays is returned. Both results are indexed by the dictionary column name.

    @param chain the ROOT.TChain
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary
                        (if None the tree column names are used)
    @return an empty dict if no columns are requested, otherwise a mapping from
            the dictionary column name to the np.array holding that column
    @raise ValueError if tree_columns and dict_columns differ in length
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    if len(dict_columns) != len(tree_columns):
        # zip() below would silently drop the surplus names otherwise
        raise ValueError(
            f"Got {len(tree_columns)} tree columns but {len(dict_columns)} dictionary columns")

    # Only the import decides which implementation is used, so only the import is guarded
    try:
        from ROOT import RDataFrame
    except ImportError:
        RDataFrame = None

    if RDataFrame is not None:
        rdf = RDataFrame(chain)
        d = np.column_stack(list(rdf.AsNumpy(tree_columns).values()))
        # np.rec is the public home of fromarrays; the np.core namespace is deprecated in NumPy 2
        return np.rec.fromarrays(d.transpose(), names=dict_columns)

    # Fallback without RDataFrame: fill pre-allocated arrays entry by entry
    n_entries = chain.GetEntries()
    d = {column: np.zeros((n_entries,)) for column in dict_columns}
    for iEvent, event in enumerate(chain):
        for dict_column, tree_column in zip(dict_columns, tree_columns):
            d[dict_column][iEvent] = getattr(event, tree_column)
    return d
38 
39 
def calculate_roc_auc(p, t):
    """
    Deprecated name of ``calculate_auc_efficiency_vs_purity``

    Issues a B2WARNING pointing to the two replacement functions and then
    forwards the arguments to ``calculate_auc_efficiency_vs_purity``.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @return the result of ``calculate_auc_efficiency_vs_purity(p, t)``
    """
    notice_parts = (
        "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n",
        "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n",
        "not the efficiency-background retention curve as expected by users.\n",
        "Please replace calculate_roc_auc with:\n\n",
        "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:",
        " the current definition of calculate_roc_auc\n",
        "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:",
        " what is commonly known as roc auc\n",
    )
    B2WARNING("".join(notice_parts))
    area = calculate_auc_efficiency_vs_purity(p, t)
    return area
57 
58 
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve

    The entries are ordered by ascending classifier output. Point i of the curve describes
    the sample that remains once the i+1 lowest-scored entries have been removed, so the
    point without any cut is not part of the integrated curve.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the absolute value of the trapezoidal integral of the purity over the efficiency
    """
    # Accept any array-like input and guarantee positional indexing below
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    remaining_signal = T - np.cumsum(wt[index])
    remaining_total = N - np.cumsum(w[index])

    efficiency = remaining_signal / float(T)
    # Once everything has been cut away the purity is 0/0; that NaN is expected and
    # replaced by 0 right below, so do not emit a floating point warning for it
    with np.errstate(invalid='ignore'):
        purity = remaining_signal / remaining_total
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever is available
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz
    return np.abs(integrate(purity, efficiency))
79 
80 
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)

    The entries are ordered by ascending classifier output. Point i of the curve describes
    the sample that remains once the i+1 lowest-scored entries have been removed, so the
    point without any cut is not part of the integrated curve.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the absolute value of the trapezoidal integral of the efficiency
            over the background retention
    """
    # Accept any array-like input and guarantee positional indexing below
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    efficiency = (T - np.cumsum(wt[index])) / float(T)
    # Weight of the entries whose target is not 1, in the same ordering as the efficiency
    background_weight = (np.abs(1 - t) * w)[index]
    background_retention = (N - T - np.cumsum(background_weight)) / float(N - T)

    # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever is available
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz
    return np.abs(integrate(efficiency, background_retention))
100 
101 
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    Feature and probability are binned at their unique 0..100 percentiles. The cumulative,
    normalised probability distribution inside every feature bin is compared with the
    cumulative, normalised distribution summed over all feature bins.

    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency
    """
    percent_steps = list(range(101))

    def _bin_edges(values):
        # Unique percentile edges; a degenerate distribution gets one bin enclosing all values
        edges = np.unique(np.percentile(values, q=percent_steps))
        if len(edges) < 2:
            edges = np.array([np.min(values) - 1, np.max(values) + 1])
        return edges

    feature_edges = _bin_edges(f)
    probability_edges = _bin_edges(p)

    # Axis 0 runs over the probability bins, axis 1 over the feature bins
    counts, _ = np.histogramdd(np.c_[p, f],
                               bins=[probability_edges, feature_edges],
                               weights=w)

    inclusive = counts.sum(axis=1)
    inclusive_cumulative = np.cumsum(inclusive / inclusive.sum(axis=0), axis=0)
    local_cumulative = np.cumsum(counts / counts.sum(axis=0), axis=0)

    squared_difference = (local_cumulative.T - inclusive_cumulative) ** 2
    return np.sqrt(squared_difference.sum() / (100 * 99))
127 
128 
class Method:
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf2_mva_evaluation tools
    """

    #: Maps the method name stored in the general options onto the name of the matching
    #: specific options class in basf2_mva. The class is looked up by name only when it is
    #: needed, so only the options class of the loaded method is ever accessed.
    _specific_options_classes = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raise RuntimeError if the method named in the weightfile is not known
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()
        #: Identifier of the method
        self.identifier = identifier
        #: Weightfile of the method
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)
        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # This piece of code should be correct but leads to random segmentation faults
        # inside python, llvm or pyroot, therefore we use the more dirty code below
        # Ideas why this is happening:
        # 1. Ownership of the unique_ptr returned by getOptions()
        # 2. Some kind of object slicing, although pyroot identifies the correct type
        # 3. Bug in pyroot
        # interfaces = ROOT.Belle2.MVA.AbstractInterface.getSupportedInterfaces()
        # self.interface = interfaces[self.general_options.m_method]
        # self.specific_options = self.interface.getOptions()

        #: Specific options of the method
        self.specific_options = None
        method_name = str(self.general_options.m_method)
        options_class_name = self._specific_options_classes.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}
        #: List of variables sorted by their importance
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))
        #: List of variables sorted by their importance but with root compatible variable names
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]
        #: Dictionary of the variable importances calculated by the method,
        #: but with root compatible variable names
        # NOTE(review): the lookup uses the root compatible name as key of the importance
        # map returned by the weightfile - confirm that the map is keyed that way
        self.root_importances = {k: importances[k] for k in self.root_variables}
        #: Description of the method as a xml string returned by basf2_mva.info
        self.description = str(basf2_mva.info(self.identifier))
        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]
        #: List of spectators with root compatible names
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype

        The general options object that is used (the given one, or the one of this method)
        is modified in place: its datafiles, treename and identifier are overwritten.

        @param datafiles the training datafiles
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher
               (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher
               (if None the options of this method are used)
        @return a new Method loaded from the weightfile written by the teacher
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # The documented treename argument was ignored before; pass it on to the teacher
            if treename is not None:
                general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # The weightfile lives in the temporary directory, so load it before leaving the block
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles
        @param treename the name of the tree containing the data
        @return tuple of the expert output and the target; for more than two classes the
                expert output is an array with one column per class
        """
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        n_classes = self.general_options.m_nClasses
        target_variable = self.general_options.m_target_variable

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            chain = ROOT.TChain("variables")
            chain.Add(rootfilename)

            # The branches written by the expert are named after the temporary weightfile,
            # the returned columns are named after the identifier of this method
            expert_target = identifier + '_' + target_variable
            stripped_expert_target = self.identifier + '_' + target_variable

            if n_classes > 2:
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [
                    ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier + f'_{i}')
                    for i in range(n_classes)
                ]
            else:
                output_names = [self.identifier]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier)]

            d = chain2dict(
                chain,
                [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
                [*output_names, stripped_expert_target])

        if n_classes <= 2:
            probability = d[self.identifier]
        else:
            probability = np.array([d[x] for x in output_names]).T
        return probability, d[stripped_expert_target]
specific_options
Specific options of the method.
def apply_expert(self, datafiles, treename)
description
Description of the method as a xml string returned by basf2_mva.info.
importances
Dictionary of the variable importances calculated by the method.
root_importances
Dictionary of the variable importances calculated by the method, but with root compatible variable names.
variables
List of variables sorted by their importance.
def __init__(self, identifier)
weightfile
Weightfile of the method.
root_spectators
List of spectators with root compatible names.
spectators
List of spectators.
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
root_variables
List of variables sorted by their importance, but with root compatible variable names.
general_options
General options of the method.
identifier
Identifier of the method.