Belle II Software development
basf2_mva_util.py
1
8
9import tempfile
10import numpy as np
11
12from basf2 import B2WARNING
13import basf2_mva
14
15
def chain2dict(chain, tree_columns, dict_columns=None, max_entries=None):
    """
    Convert a ROOT.TChain into a dictionary of np.arrays
    @param chain the ROOT.TChain
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary
                        (if None the names given in tree_columns are used)
    @param max_entries if not None and the chain holds more entries than this, only every
                       k-th entry is read, with k chosen such that at most max_entries
                       entries are kept (only applied when ROOT.RDataFrame is available)
    @return an empty dict if tree_columns is empty; otherwise a numpy record array
            (RDataFrame path) or a dict of np.arrays (fallback path), in both cases
            indexable by the names in dict_columns
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    try:
        from ROOT import RDataFrame
        rdf = RDataFrame(chain)
        if max_entries is not None:
            nEntries = rdf.Count().GetValue()
            if nEntries > max_entries:
                B2WARNING(
                    "basf2_mva_util (chain2dict): Number of entries in the chain is larger than the maximum allowed entries: " +
                    str(nEntries) +
                    " > " +
                    str(max_entries))
                # Ceiling division: with floor division the stride could be 1 (e.g. 19 entries,
                # max_entries=10), which would keep every entry and exceed the requested maximum.
                skip = int(-(-nEntries // max_entries))
                rdf = rdf.Filter("rdfentry_ % " + str(skip) + " == 0")

        d = np.column_stack(list(rdf.AsNumpy(tree_columns).values()))
        # np.rec is the public location of fromarrays; going through np.core is deprecated since NumPy 2.0
        d = np.rec.fromarrays(d.transpose(), names=dict_columns)
    except ImportError:
        # Fallback without RDataFrame: fill the arrays entry by entry (max_entries is not applied here)
        d = {column: np.zeros((chain.GetEntries(),)) for column in dict_columns}
        for iEvent, event in enumerate(chain):
            for dict_column, tree_column in zip(dict_columns, tree_columns):
                d[dict_column][iEvent] = getattr(event, tree_column)
    return d
50
51
def calculate_roc_auc(p, t):
    """
    Deprecated name of ``calculate_auc_efficiency_vs_purity``

    Emits a deprecation message through B2WARNING and then forwards the
    arguments (without weights) to ``calculate_auc_efficiency_vs_purity``.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @return the value returned by calculate_auc_efficiency_vs_purity(p, t)
    """
    deprecation_notice = "".join((
        "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n",
        "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n",
        "not the efficiency-background retention curve as expected by users.\n",
        "Please replace calculate_roc_auc with:\n\n",
        "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:",
        " the current definition of calculate_roc_auc\n",
        "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:",
        " what is commonly known as roc auc\n",
    ))
    B2WARNING(deprecation_notice)
    return calculate_auc_efficiency_vs_purity(p, t)
69
70
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve
    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights (None means unit weights)
    @return the absolute value of the trapezoidal integral of purity over efficiency
    """
    # Accept any array-like (e.g. plain lists); arrays are passed through unchanged
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    # Entries sorted by increasing classifier output; element i of the cumulative sums
    # describes the selection that rejects the i+1 lowest-scoring entries.
    index = np.argsort(p)
    signal_above = T - np.cumsum(wt[index])
    total_above = N - np.cumsum(w[index])

    efficiency = signal_above / float(T)
    # The last working point rejects everything, i.e. 0/0: this is expected and mapped to
    # purity 0 below, so the floating point warning is silenced for this division only.
    with np.errstate(divide='ignore', invalid='ignore'):
        purity = signal_above / total_above
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(integrate(purity, efficiency))
91
92
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)
    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights (None means unit weights)
    @return the absolute value of the trapezoidal integral of efficiency over background retention
    """
    # Accept any array-like (e.g. plain lists); arrays are passed through unchanged
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    # Entries sorted by increasing classifier output; element i of the cumulative sums
    # describes the selection that rejects the i+1 lowest-scoring entries.
    index = np.argsort(p)
    signal_above = T - np.cumsum(wt[index])
    background_above = N - T - np.cumsum((np.abs(1 - t) * w)[index])

    # The cumulative sums start after the lowest-scoring entry has already been rejected, so
    # the "no cut" working point (efficiency = background retention = 1) has to be added
    # explicitly. Without it the curve is truncated and e.g. a perfectly separating
    # classifier on a small sample does not reach an area of 1.
    efficiency = np.concatenate(([1.0], signal_above / float(T)))
    background_retention = np.concatenate(([1.0], background_above / float(N - T)))

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(integrate(efficiency, background_retention))
112
113
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    Both quantities are binned at their 0..100 percentiles. In every feature bin the
    cumulative distribution of the probability is compared with the cumulative
    distribution of the probability over all feature bins.

    @param f the feature values
    @param p the probability values
    @param w optional weights (passed to the histogram; the percentile binning itself is unweighted)
    @return the mean standard deviation between the local and global cut selection efficiency
    """
    percent_steps = np.arange(101)
    feature_edges = np.unique(np.percentile(f, q=percent_steps))
    probability_edges = np.unique(np.percentile(p, q=percent_steps))

    # A constant input yields a single unique percentile; fall back to one bin around the value
    if len(feature_edges) < 2:
        feature_edges = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(probability_edges) < 2:
        probability_edges = np.array([np.min(p) - 1, np.max(p) + 1])

    # Axis 0 runs over probability bins, axis 1 over feature bins
    counts, _ = np.histogramdd(np.column_stack((p, f)),
                               bins=[probability_edges, feature_edges],
                               weights=w)

    # Probability distribution integrated over the feature, normalised to unit sum
    inclusive = counts.sum(axis=1)
    inclusive = inclusive / inclusive.sum(axis=0)
    # Probability distribution inside each feature bin, each column normalised to unit sum
    per_feature_bin = counts / counts.sum(axis=0)

    local_cumulative = per_feature_bin.cumsum(axis=0)
    global_cumulative = inclusive.cumsum(axis=0)

    squared_difference = (local_cumulative.T - global_cumulative)**2
    return np.sqrt(squared_difference.sum() / (100 * 99))
139
140
class Method:
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf_mva_evaluation tools
    """

    #: Maps the method name stored in the general options to the name of the matching
    #: specific-options class in basf2_mva. Names (not classes) are stored so that an
    #: options class is only looked up when a weightfile actually uses that method.
    _specific_options_names = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
        "ONNX": "ONNXOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raise RuntimeError if the method named in the weightfile has no known specific options
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()

        #: Identifier of the method
        self.identifier = identifier

        #: Weightfile of the method
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)

        #: General options of the method
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # This piece of code should be correct but leads to random segmentation faults
        # inside python, llvm or pyroot, therefore we use the more dirty code below
        # Ideas why this is happening:
        # 1. Ownership of the unique_ptr returned by getOptions()
        # 2. Some kind of object slicing, although pyroot identifies the correct type
        # 3. Bug in pyroot
        # interfaces = ROOT.Belle2.MVA.AbstractInterface.getSupportedInterfaces()
        # self.interface = interfaces[self.general_options.m_method]
        # self.specific_options = self.interface.getOptions()

        method_name = str(self.general_options.m_method)
        options_name = self._specific_options_names.get(method_name)
        if options_name is None:
            raise RuntimeError("Unknown method " + method_name)

        #: Specific options of the method
        self.specific_options = getattr(basf2_mva, options_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        #: Dictionary of the variable importances calculated by the method
        self.importances = {k: importances[k] for k in variables}

        #: List of variables sorted by their importance (ascending)
        self.variables = list(sorted(variables, key=lambda v: self.importances.get(v, 0.0)))

        #: List of the variables sorted by their importance, but with ROOT compatible names
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]

        #: Dictionary of the variable importances, keyed by the ROOT compatible variable names
        self.root_importances = {k: importances[k] for k in self.root_variables}

        #: Description of the method as returned by basf2_mva.info, converted to a string
        self.description = str(basf2_mva.info(self.identifier))

        #: List of spectators
        self.spectators = [str(v) for v in self.general_options.m_spectators]

        #: List of spectators with ROOT compatible names
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype

        Note that the general options object that is used (the one passed in, or the one of
        this method if None is passed) is modified: its datafiles, treename and identifier
        are overwritten.

        @param datafiles the training datafiles (a single filename or a list of filenames)
        @param treename the name of the tree containing the training data
                        (if None the treename already stored in the general options is kept)
        @param general_options general options given to basf2_mva.teacher
                               (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher
                                (if None the options of this method are used)
        @return a new Method loaded from the freshly trained weightfile
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # The treename argument used to be silently ignored, so the training always read
            # the tree name stored in the prototype options.
            # NOTE(review): assumes GeneralOptions exposes m_treename - confirm.
            if treename is not None:
                general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            # The weightfile has to be loaded before the temporary directory is removed
            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single filename or a list of filenames)
        @param treename the name of the tree containing the data
        @return tuple (probability, target); for more than two classes the probability is a
                two-dimensional array with one column per class
        """
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        with tempfile.TemporaryDirectory() as tempdir:
            # The expert is run on a temporary copy of the weightfile; its output branches are
            # therefore named after the temporary identifier, not after self.identifier.
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            chain = ROOT.TChain("variables")
            chain.Add(rootfilename)

            expert_target = identifier + '_' + self.general_options.m_target_variable
            stripped_expert_target = self.identifier + '_' + self.general_options.m_target_variable

            n_classes = self.general_options.m_nClasses
            if n_classes > 2:
                # Multiclass: one output branch per class, suffixed with the class index
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [
                    ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier + f'_{i}')
                    for i in range(n_classes)
                ]
            else:
                output_names = [self.identifier]
                branch_names = [
                    ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier),
                ]

            # Read the temporary branch names but expose them under names based on self.identifier.
            # This has to happen before the temporary directory (and the ROOT file) is removed.
            d = chain2dict(
                chain,
                [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
                [*output_names, stripped_expert_target])

        target = d[stripped_expert_target]
        if self.general_options.m_nClasses <= 2:
            return d[str(self.identifier)], target
        return np.array([d[x] for x in output_names]).T, target
dict importances
Dictionary of the variable importances calculated by the method.
specific_options
Specific options of the method.
description
Description of the method as an XML string returned by basf2_mva.info.
__init__(self, identifier)
dict root_importances
Dictionary of the variable importances calculated by the method, but with ROOT-compatible variable names.
variables
List of variables sorted by their importance.
weightfile
Weightfile of the method.
list root_spectators
List of spectators with root compatible names.
list root_variables
List of the variables sorted by their importance, but with ROOT-compatible variable names.
list spectators
List of spectators.
train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
general_options
General options of the method.
apply_expert(self, datafiles, treename)
identifier
Identifier of the method.