Belle II Software development
basf2_mva_util.py
1
8
9import tempfile
10import numpy as np
11
12from basf2 import B2WARNING
13import basf2_mva
14
15
def chain2dict(chain, tree_columns, dict_columns=None):
    """
    Convert a ROOT.TChain into a dictionary of np.arrays
    @param chain the ROOT.TChain
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary
           (if None the names in tree_columns are used)
    @return an empty dict if no columns are requested. Otherwise a np.recarray
            (if ROOT.RDataFrame can be imported) or a dict of np.arrays (fallback);
            both can be indexed with the names given in dict_columns
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    # Only the import is guarded: the slow event loop is a fallback for a missing
    # RDataFrame, not for errors raised while actually reading the chain.
    try:
        from ROOT import RDataFrame
    except ImportError:
        d = {column: np.zeros((chain.GetEntries(),)) for column in dict_columns}
        for iEvent, event in enumerate(chain):
            for dict_column, tree_column in zip(dict_columns, tree_columns):
                d[dict_column][iEvent] = getattr(event, tree_column)
    else:
        rdf = RDataFrame(chain)
        d = np.column_stack(list(rdf.AsNumpy(tree_columns).values()))
        # np.rec is the public spelling of np.core.records (np.core is private since NumPy 2)
        d = np.rec.fromarrays(d.transpose(), names=dict_columns)
    return d
38
39
def calculate_roc_auc(p, t):
    """
    Deprecated alias of ``calculate_auc_efficiency_vs_purity``.
    Emits a deprecation warning and forwards to the new function (without weights).

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @return the area under the efficiency-purity curve
    """
    # The escape sequences colour the function names in the terminal output.
    deprecation_notice = (
        "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n"
        "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n"
        "not the efficiency-background retention curve as expected by users.\n"
        "Please replace calculate_roc_auc with:\n\n"
        "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:"
        " the current definition of calculate_roc_auc\n"
        "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:"
        " what is commonly known as roc auc\n"
    )
    B2WARNING(deprecation_notice)
    return calculate_auc_efficiency_vs_purity(p, t)
57
58
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve
    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the absolute area under the curve, integrated with the trapezoidal rule
    """
    if w is None:
        w = np.ones(t.shape)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    # Weight removed by each successive cut, in order of increasing probability.
    # The leading 0 is the "no cut applied" working point (efficiency 1, purity T/N);
    # without it the first segment of the curve is missing from the integral.
    removed_signal = np.concatenate(([0.0], np.cumsum(wt[index])))
    removed_total = np.concatenate(([0.0], np.cumsum(w[index])))

    efficiency = (T - removed_signal) / float(T)
    # The last point removes everything and therefore evaluates 0/0;
    # this is expected and mapped to purity 0 below, so silence the warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        purity = (T - removed_signal) / (N - removed_total)
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever is available.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(purity, efficiency))
79
80
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)
    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the absolute area under the curve, integrated with the trapezoidal rule
    """
    if w is None:
        w = np.ones(t.shape)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    # Weight removed by each successive cut, in order of increasing probability.
    # The leading 0 is the "no cut applied" working point (1, 1); without it the curve
    # starts one entry too late and e.g. a perfect classifier does not reach an area of 1.
    removed_signal = np.concatenate(([0.0], np.cumsum(wt[index])))
    removed_background = np.concatenate(([0.0], np.cumsum((np.abs(1 - t) * w)[index])))

    efficiency = (T - removed_signal) / float(T)
    background_retention = (N - T - removed_background) / float(N - T)

    # np.trapz was renamed to np.trapezoid in NumPy 2; use whichever is available.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(efficiency, background_retention))
100
101
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability.
    Both quantities are binned at their percentiles; in every feature bin the cumulative
    distribution of the probability is compared to the one of the whole sample.
    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency
            (the normalisation is fixed to 100 * 99, independent of the number of bins actually used)
    """
    percent_steps = np.arange(101)
    feature_edges = np.unique(np.percentile(f, q=percent_steps))
    probability_edges = np.unique(np.percentile(p, q=percent_steps))

    # A constant input collapses to a single edge; fall back to one bin around the value
    if len(feature_edges) < 2:
        feature_edges = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(probability_edges) < 2:
        probability_edges = np.array([np.min(p) - 1, np.max(p) + 1])

    # First axis: probability bins, second axis: feature bins
    counts, _ = np.histogramdd(np.c_[p, f],
                               bins=[probability_edges, feature_edges],
                               weights=w)

    # Probability distribution of the whole sample and its cumulative sum
    global_counts = counts.sum(axis=1)
    global_cdf = np.cumsum(global_counts / global_counts.sum(axis=0), axis=0)

    # Probability distribution inside each feature bin and its cumulative sum
    local_cdf = np.cumsum(counts / counts.sum(axis=0), axis=0)

    squared_deviation = (local_cdf.T - global_cdf) ** 2
    return np.sqrt(squared_deviation.sum() / (100 * 99))
127
128
class Method:
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf_mva_evaluation tools
    """

    #: Name of the specific options class in basf2_mva for every supported method name.
    #: Only names are stored so that just the class of the requested method is looked up.
    _SPECIFIC_OPTIONS_CLASS_NAMES = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()

        #: Identifier of the method
        self.identifier = identifier

        #: Weightfile of the method
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)

        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # This piece of code should be correct but leads to random segmentation faults
        # inside python, llvm or pyroot, therefore we use the more dirty code below
        # Ideas why this is happening:
        # 1. Ownership of the unique_ptr returned by getOptions()
        # 2. Some kind of object slicing, although pyroot identifies the correct type
        # 3. Bug in pyroot
        # interfaces = ROOT.Belle2.MVA.AbstractInterface.getSupportedInterfaces()
        # self.interface = interfaces[self.general_options.m_method]
        # self.specific_options = self.interface.getOptions()

        # Convert to a python string so that the dictionary lookup does not depend on
        # how the string type coming from ROOT is hashed.
        method_name = str(self.general_options.m_method)
        options_class_name = self._SPECIFIC_OPTIONS_CLASS_NAMES.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)

        #: Specific options of the method
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}

        #: List of variables sorted by their importance
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))

        #: List of variables sorted by their importance but with root compatible variable names
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]

        #: Dictionary of the variable importances calculated by the method, but with the root compatible
        #: variable names. The values are taken from self.importances: the importance map is indexed with
        #: the original variable names above, so indexing it with the root compatible names does not
        #: address the same entries whenever the two names differ.
        self.root_importances = {
            root_variable: self.importances[variable]
            for variable, root_variable in zip(self.variables, self.root_variables)
        }

        #: Description of the method as a xml string returned by basf2_mva.info
        self.description = str(basf2_mva.info(self.identifier))

        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]

        #: List of spectators with root compatible names
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher
                (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher
                (if None the options of this method are used)
        @return the newly trained Method
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        # NOTE: the options object is modified in place, this includes self.general_options
        # if no general_options were passed.
        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # Honour the documented treename argument instead of silently
            # reusing whatever tree name is stored in the options.
            general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # The weightfile has to be loaded before the temporary directory is removed
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles
        @param treename the name of the tree containing the data
        @return tuple of the expert output and the target. For more than two classes the output
                has one column per class.
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]

        n_classes = self.general_options.m_nClasses
        target_variable = self.general_options.m_target_variable

        with tempfile.TemporaryDirectory() as tempdir:
            # The expert is applied using a temporary copy of the weightfile, therefore the
            # branches in the output file are named after the temporary identifier.
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            chain = ROOT.TChain("variables")
            chain.Add(rootfilename)

            expert_target = identifier + '_' + target_variable
            # Name under which the target is returned: based on the identifier of this method
            stripped_expert_target = self.identifier + '_' + target_variable

            if n_classes > 2:
                # One output branch per class
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [
                    ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier + f'_{i}')
                    for i in range(n_classes)
                ]
            else:
                output_names = [self.identifier]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier)]

            d = chain2dict(
                chain,
                [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
                [*output_names, stripped_expert_target])

        if n_classes <= 2:
            probability = d[self.identifier]
        else:
            probability = np.array([d[x] for x in output_names]).T
        return probability, d[stripped_expert_target]
specific_options
Specific options of the method.
def apply_expert(self, datafiles, treename)
description
Description of the method as a xml string returned by basf2_mva.info.
importances
Dictionary of the variable importances calculated by the method.
root_importances
Dictionary of the variable importances calculated by the method, but with root compatible variable names.
variables
List of variables sorted by their importance.
def __init__(self, identifier)
weightfile
Weightfile of the method.
root_spectators
List of spectators with root compatible names.
spectators
List of spectators.
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
root_variables
List of variables sorted by their importance, but with root compatible variable names.
general_options
General options of the method.
identifier
Identifier of the method.
Definition: tools.py:1