Belle II Software development
basf2_mva_util.py
1
8
9import tempfile
10import numpy as np
11
12from basf2 import B2WARNING
13import basf2_mva
14
15
16def chain2dict(chain, tree_columns, dict_columns=None, max_entries=None):
17 """
18 Convert a ROOT.TChain into a dictionary of np.arrays
19 @param chain the ROOT.TChain
20 @param tree_columns the column (or branch) names in the tree
21 @param dict_columns the corresponding column names in the dictionary
22 """
23 if len(tree_columns) == 0:
24 return dict()
25 if dict_columns is None:
26 dict_columns = tree_columns
27 try:
28 from ROOT import RDataFrame
29 rdf = RDataFrame(chain)
30 if max_entries is not None:
31 nEntries = rdf.Count().GetValue()
32 if nEntries > max_entries:
33 B2WARNING(
34 "basf2_mva_util (chain2dict): Number of entries in the chain is larger than the maximum allowed entries: " +
35 str(nEntries) +
36 " > " +
37 str(max_entries))
38 skip = nEntries // max_entries
39 rdf_subset = rdf.Filter("rdfentry_ % " + str(skip) + " == 0")
40 rdf = rdf_subset
41
42 d = np.column_stack(list(rdf.AsNumpy(tree_columns).values()))
43 d = np.core.records.fromarrays(d.transpose(), names=dict_columns)
44 except ImportError:
45 d = {column: np.zeros((chain.GetEntries(),)) for column in dict_columns}
46 for iEvent, event in enumerate(chain):
47 for dict_column, tree_column in zip(dict_columns, tree_columns):
48 d[dict_column][iEvent] = getattr(event, tree_column)
49 return d
50
51
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve
    @param p np.array (or sequence) filled with the probability output of a classifier
    @param t np.array (or sequence) filled with the target (0 or 1)
    @param w None or np.array (or sequence) filled with weights
    @return the absolute area under purity as a function of the signal efficiency, using the
            "no cut" working point and one working point per distinct value of p
            (purity is set to 0 where it is undefined); nan if the total signal weight is zero
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    p_sorted = p[index]
    # A cut cannot separate events with equal classifier output, so tied events are removed
    # together: only the last position of every group of equal values is a working point.
    is_last_of_group = np.ones(len(p_sorted), dtype=bool)
    is_last_of_group[:-1] = p_sorted[1:] != p_sorted[:-1]
    cut_positions = np.flatnonzero(is_last_of_group)
    # Weight removed by each cut; the leading 0 is the "no cut" working point (efficiency 1).
    removed_signal = np.concatenate(([0.0], np.cumsum(wt[index])[cut_positions]))
    removed_total = np.concatenate(([0.0], np.cumsum(w[index])[cut_positions]))

    # The last working point removes everything (0/0), which is expected and handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        efficiency = (T - removed_signal) / float(T)
        purity = (T - removed_signal) / (N - removed_total)
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(purity, efficiency))
72
73
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)
    @param p np.array (or sequence) filled with the probability output of a classifier
    @param t np.array (or sequence) filled with the target (0 or 1)
    @param w None or np.array (or sequence) filled with weights
    @return the absolute area under the signal efficiency as a function of the background
            retention, using the "no cut" working point (1, 1) and one working point per
            distinct value of p; nan if there is no signal or no background weight
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t
    wb = np.abs(1 - t) * w

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p)
    p_sorted = p[index]
    # A cut cannot separate events with equal classifier output, so tied events are removed
    # together (this yields the usual diagonal ROC segment for ties).
    is_last_of_group = np.ones(len(p_sorted), dtype=bool)
    is_last_of_group[:-1] = p_sorted[1:] != p_sorted[:-1]
    cut_positions = np.flatnonzero(is_last_of_group)
    # Weight removed by each cut; the leading 0 is the "no cut" working point. Without it the
    # segment between (1, 1) and the first cut is lost, which biases small samples heavily.
    removed_signal = np.concatenate(([0.0], np.cumsum(wt[index])[cut_positions]))
    removed_background = np.concatenate(([0.0], np.cumsum(wb[index])[cut_positions]))

    with np.errstate(divide="ignore", invalid="ignore"):
        efficiency = (T - removed_signal) / float(T)
        background_retention = (N - T - removed_background) / float(N - T)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(efficiency, background_retention))
93
94
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability
    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the root mean square difference between the local (per feature bin) and the
            global cut selection efficiency, taken over all non-empty feature bins and all
            non-trivial cuts on the probability
    """
    quantiles = list(range(101))
    binning_feature = np.unique(np.percentile(f, q=quantiles))
    binning_probability = np.unique(np.percentile(p, q=quantiles))
    if len(binning_feature) < 2:
        binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(binning_probability) < 2:
        binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
    hist_n, _ = np.histogramdd(np.c_[p, f],
                               bins=[binning_probability, binning_feature],
                               weights=w)

    # Global (feature-integrated) probability distribution
    hist_inc = hist_n.sum(axis=1)
    hist_inc /= hist_inc.sum(axis=0)

    # Percentile edges are interpolated, so a feature bin can contain no entries at all.
    # Normalizing such a bin gives 0/0 = nan and would turn the whole result into nan,
    # therefore only non-empty feature bins are compared.
    feature_bin_sums = hist_n.sum(axis=0)
    nonempty = feature_bin_sums != 0
    hist_n = hist_n[:, nonempty] / feature_bin_sums[nonempty]

    # Cumulative fraction below each probability bin edge (= 1 - efficiency of a cut there)
    hist_n = hist_n.cumsum(axis=0)
    hist_inc = hist_inc.cumsum(axis=0)
    diff = (hist_n.T - hist_inc)**2

    # The last cumulative entry is 1 for both distributions, so it is not counted;
    # for fully populated binnings (100 x 100 bins) this is a normalization of 100 * 99.
    n_feature_bins, n_probability_bins = diff.shape
    return np.sqrt(diff.sum() / (n_feature_bins * max(n_probability_bins - 1, 1)))
120
121
class Method:
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf_mva_evaluation tools
    """

    #: Name of the basf2_mva specific-options class for each supported GeneralOptions.m_method
    _SPECIFIC_OPTIONS_CLASSES = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
        "ONNX": "ONNXOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raises RuntimeError if the method stored in the weightfile is not supported
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()

        #: Identifier of the method
        self.identifier = identifier
        #: Weightfile of the method
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)

        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # Obtaining the options through the interface, i.e.
        #     interfaces = ROOT.Belle2.MVA.AbstractInterface.getSupportedInterfaces()
        #     self.interface = interfaces[self.general_options.m_method]
        #     self.specific_options = self.interface.getOptions()
        # should be correct but leads to random segmentation faults inside python, llvm or pyroot.
        # Suspected causes: ownership of the unique_ptr returned by getOptions(), some kind of
        # object slicing (although pyroot identifies the correct type), or a pyroot bug.
        # Therefore the options class is looked up by name in basf2_mva instead.
        method_name = self.general_options.m_method
        options_class_name = self._SPECIFIC_OPTIONS_CLASSES.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)
        #: Specific options of the method
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}
        #: List of variables sorted by their importance (least important first)
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))
        #: List of variables sorted by their importance, but with ROOT-compatible variable names
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]
        #: Dictionary of the variable importances, keyed by the ROOT-compatible variable names
        # self.importances looks the weightfile importances up by the original variable names,
        # so translate through the matching original name instead of the converted one.
        self.root_importances = {
            root_name: self.importances[name] for name, root_name in zip(self.variables, self.root_variables)
        }

        #: Description of the method as a string returned by basf2_mva.info
        self.description = str(basf2_mva.info(self.identifier))
        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]
        #: List of spectators with ROOT-compatible names
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher
               (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher
               (if None the options of this method are used)
        @return a new Method loaded from the trained weightfile
        @note the general options object that is used (the given one or this method's) is modified
              in place: its m_datafiles, m_treename and m_identifier members are overwritten
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # Train on the requested tree rather than the one stored in the prototype options.
            general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # Load while the temporary weightfile still exists
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the data
        @return tuple (probability, target); for m_nClasses > 2 the probability has one column per class
        """
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        with tempfile.TemporaryDirectory() as tempdir:
            # The expert reads the method from a file, so store the in-memory weightfile temporarily.
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            chain = ROOT.TChain("variables")
            chain.Add(rootfilename)

            make_compatible = ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible
            n_classes = self.general_options.m_nClasses

            # Branches of the expert output are named after the temporary identifier; the returned
            # columns are renamed to use this method's identifier instead.
            expert_target = identifier + '_' + self.general_options.m_target_variable
            stripped_expert_target = self.identifier + '_' + self.general_options.m_target_variable

            if n_classes > 2:
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [make_compatible(identifier + f'_{i}') for i in range(n_classes)]
            else:
                output_names = [self.identifier]
                branch_names = [make_compatible(identifier)]

            d = chain2dict(
                chain,
                [*branch_names, make_compatible(expert_target)],
                [*output_names, stripped_expert_target])

            if n_classes > 2:
                probability = np.array([d[x] for x in output_names]).T
            else:
                probability = d[str(self.identifier)]
            return probability, d[stripped_expert_target]
274
275
def create_onnx_mva_weightfile(onnx_model_path, **kwargs):
    """
    Create an MVA Weightfile for ONNX

    Parameters
    ----------
    onnx_model_path :
        path to the ONNX model file; its content is added to the weightfile
        under the name "ONNX_Modelfile"
    kwargs :
        keyword arguments to set the options in the weightfile. They are
        directly mapped to member variable names of the option classes with "m_"
        added automatically. First, GeneralOptions are tried and the remaining
        arguments are passed to ONNXOptions. Lists or tuples given for
        GeneralOptions members (e.g. ``variables``) are converted with
        ``basf2_mva.vector``.

    Returns
    -------
    weightfile :
        Weightfile object containing the ONNX model and options

    Raises
    ------
    AttributeError
        if a keyword matches neither a GeneralOptions nor an ONNXOptions member

    Examples
    --------
    >>> weightfile = create_onnx_mva_weightfile(
    ...     "model.onnx",
    ...     outputName="probabilities",
    ...     variables=["variable1", "variable2"],
    ...     target_variable="isSignal"
    ... )
    >>> weightfile.save("model.root")
    """
    general_options = basf2_mva.GeneralOptions()
    onnx_options = basf2_mva.ONNXOptions()
    general_options.m_method = onnx_options.getMethod()

    # Fill everything that exists in the general options; collect the rest.
    remaining = {}
    for k, v in kwargs.items():
        m_k = f"m_{k}"
        if hasattr(general_options, m_k):
            if isinstance(v, (list, tuple)):
                # Vector members are assigned basf2_mva.vector objects throughout this module
                # (e.g. m_datafiles), so convert plain Python sequences the same way.
                v = basf2_mva.vector(*v)
            setattr(general_options, m_k, v)
        else:
            remaining[k] = v

    # For the rest try to set members of the specific options.
    for k, v in remaining.items():
        m_k = f"m_{k}"
        if not hasattr(onnx_options, m_k):
            raise AttributeError(f"No member named {m_k} in ONNXOptions.")
        setattr(onnx_options, m_k, v)

    w = basf2_mva.Weightfile()
    w.addOptions(general_options)
    w.addOptions(onnx_options)
    w.addFile("ONNX_Modelfile", str(onnx_model_path))
    return w
dict importances
Dictionary of the variable importances calculated by the method.
specific_options
Specific options of the method.
description
Description of the method as an XML string returned by basf2_mva.info.
__init__(self, identifier)
dict root_importances
Dictionary of the variable importances, keyed by the ROOT-compatible variable names and ordered by importance.
variables
List of variables sorted by their importance (least important first).
weightfile
Weightfile of the method.
list root_spectators
List of spectators with ROOT-compatible names.
list root_variables
List of variables sorted by their importance, but with ROOT-compatible variable names.
list spectators
List of spectators.
train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
general_options
General options of the method.
apply_expert(self, datafiles, treename)
identifier
Identifier of the method.