Belle II Software light-2505-deimos
basf2_mva_util.py
1
8
9import tempfile
10import numpy as np
11
12from basf2 import B2WARNING
13import basf2_mva
14
15
def _subsample_stride(n_entries, max_entries):
    """
    Return the stride with which entries have to be taken from a chain so that at most
    max_entries of its n_entries entries are kept. Emits a B2WARNING if subsampling is needed.
    @param n_entries number of entries in the chain
    @param max_entries maximum number of entries to keep, or None for no limit
    @return 1 if all entries can be kept, otherwise ceil(n_entries / max_entries)
    """
    if max_entries is None or n_entries <= max_entries:
        return 1
    B2WARNING(
        "basf2_mva_util (chain2dict): Number of entries in the chain is larger than the maximum allowed entries: " +
        str(n_entries) +
        " > " +
        str(max_entries))
    # Ceiling division. A floor division would keep up to 2 * max_entries - 1 entries,
    # e.g. 199 entries with max_entries=100 gives stride 1, i.e. no subsampling at all.
    return -(-n_entries // max_entries)


def chain2dict(chain, tree_columns, dict_columns=None, max_entries=None):
    """
    Convert a ROOT.TChain into a dictionary of np.arrays
    @param chain the ROOT.TChain
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary (defaults to tree_columns)
    @param max_entries optional upper limit on the number of returned entries; if the chain has
                       more entries only every n-th entry is kept, with n chosen such that the
                       limit is respected
    @return a numpy record array if ROOT.RDataFrame is available, otherwise a dict of np.arrays;
            both are indexed by the names in dict_columns. An empty dict if tree_columns is empty.
    @raises ValueError if tree_columns and dict_columns have different lengths
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    if len(dict_columns) != len(tree_columns):
        raise ValueError(
            "basf2_mva_util (chain2dict): tree_columns and dict_columns must have the same length: " +
            str(len(tree_columns)) + " != " + str(len(dict_columns)))

    # Keep the try minimal: an ImportError raised while processing the data must not
    # silently switch to the slow python fallback below.
    try:
        from ROOT import RDataFrame
    except ImportError:
        RDataFrame = None

    if RDataFrame is not None:
        rdf = RDataFrame(chain)
        if max_entries is not None:
            stride = _subsample_stride(rdf.Count().GetValue(), max_entries)
            if stride > 1:
                rdf = rdf.Filter("rdfentry_ % " + str(stride) + " == 0")
        # Plain python strings are used both for booking and for the lookup, and the columns are
        # picked explicitly so that their order always matches dict_columns.
        column_names = [str(column) for column in tree_columns]
        arrays = rdf.AsNumpy(column_names)
        stacked = np.column_stack([arrays[name] for name in column_names])
        # np.rec is the public alias of the deprecated np.core.records module
        return np.rec.fromarrays(stacked.transpose(), names=[str(column) for column in dict_columns])

    # Fallback without RDataFrame: loop over the chain in python
    n_entries = chain.GetEntries()
    stride = _subsample_stride(n_entries, max_entries)
    n_selected = -(-n_entries // stride)
    d = {column: np.zeros((n_selected,)) for column in dict_columns}
    for i_event, event in enumerate(chain):
        if i_event % stride != 0:
            continue
        row = i_event // stride
        for dict_column, tree_column in zip(dict_columns, tree_columns):
            d[dict_column][row] = getattr(event, tree_column)
    return d
50
51
def calculate_roc_auc(p, t):
    """
    Deprecated alias of ``calculate_auc_efficiency_vs_purity``.

    Logs a deprecation message through B2WARNING and then forwards to
    ``calculate_auc_efficiency_vs_purity`` without weights, i.e. the returned value is the
    area under the efficiency-purity curve and not the usual ROC AUC.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @return the area under the efficiency-purity curve
    """
    deprecation_notice = (
        "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n"
        "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n"
        "not the efficiency-background retention curve as expected by users.\n"
        "Please replace calculate_roc_auc with:\n\n"
        "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:"
        " the current definition of calculate_roc_auc\n"
        "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:"
        " what is commonly known as roc auc\n"
    )
    B2WARNING(deprecation_notice)
    area = calculate_auc_efficiency_vs_purity(p, t)
    return area
69
70
def calculate_auc_efficiency_vs_purity(p, t, w=None):
    """
    Calculates the area under the efficiency-purity curve

    The curve is obtained by cutting on p: at every cut position the entries with a larger
    probability are kept. Efficiency is the kept fraction of the (weighted) signal, purity the
    (weighted) signal fraction among the kept entries. The curve starts at the uncut point
    (efficiency 1, purity = signal fraction of the whole sample); where nothing is kept the purity
    is undefined and set to 0. Entries with identical probabilities are always cut together, so the
    result does not depend on the order of tied entries.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the (absolute) area under the efficiency-purity curve, 0.0 for empty input
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)
    if p.size == 0:
        return 0.0

    wt = w * t

    N = np.sum(w)
    T = np.sum(wt)

    index = np.argsort(p, kind='stable')
    p_sorted = p[index]
    # A cut can only be placed between different probability values, i.e. after the last
    # entry of every run of equal values.
    cut_positions = np.append(p_sorted[1:] != p_sorted[:-1], True)
    # Prepend 0 removed entries: the uncut starting point of the curve
    kept_signal = T - np.concatenate(([0.0], np.cumsum(wt[index])[cut_positions]))
    kept_total = N - np.concatenate(([0.0], np.cumsum(w[index])[cut_positions]))
    # After the last cut nothing is kept; avoid rounding residue from the cumulative sums
    kept_signal[-1] = 0.0
    kept_total[-1] = 0.0

    efficiency = kept_signal / float(T)
    # 0/0 where nothing is kept is expected and mapped to a purity of 0 below
    with np.errstate(divide='ignore', invalid='ignore'):
        purity = kept_signal / kept_total
    purity = np.where(np.isnan(purity), 0, purity)

    # np.trapz was renamed to np.trapezoid in NumPy 2.0
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(purity, efficiency))
91
92
def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
    """
    Calculates the area under the efficiency-background_retention curve (AUC ROC)

    The curve is obtained by cutting on p: at every cut position the entries with a larger
    probability are kept. Efficiency is the kept fraction of the (weighted) signal, background
    retention the kept fraction of the (weighted) background. The curve runs from the uncut point
    (1, 1) to (0, 0). Entries with identical probabilities are always cut together, so ties
    contribute a straight segment and the result does not depend on the order of the input.

    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @param w None or np.array filled with weights
    @return the area under the curve: 1 for perfect separation, 0.5 if all probabilities are
            equal, 0.0 for empty input
    """
    p = np.asarray(p)
    t = np.asarray(t)
    w = np.ones(t.shape) if w is None else np.asarray(w)
    if p.size == 0:
        return 0.0

    wt = w * t
    wb = w * np.abs(1 - t)

    T = np.sum(wt)
    # Normalise the background with the same weights that are accumulated below
    B = np.sum(wb)

    index = np.argsort(p, kind='stable')
    p_sorted = p[index]
    # A cut can only be placed between different probability values, i.e. after the last
    # entry of every run of equal values.
    cut_positions = np.append(p_sorted[1:] != p_sorted[:-1], True)
    removed_signal = np.cumsum(wt[index])[cut_positions]
    removed_background = np.cumsum(wb[index])[cut_positions]

    # Prepend the uncut point (nothing removed) which was missing before
    efficiency = np.concatenate(([T], T - removed_signal)) / float(T)
    background_retention = np.concatenate(([B], B - removed_background)) / float(B)
    # After the last cut nothing is kept; avoid rounding residue from the cumulative sums
    efficiency[-1] = 0.0
    background_retention[-1] = 0.0

    # np.trapz was renamed to np.trapezoid in NumPy 2.0
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return np.abs(trapezoid(efficiency, background_retention))
112
113
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability

    The feature and the probability are binned in (unweighted) percentiles. For every feature bin
    the cumulative distribution of p (i.e. the selection efficiency of a cut on p) is compared with
    the one of the whole sample.

    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the root mean square difference between the local (per feature bin) and the global
            cut selection efficiency; nan if no feature bin has a non-zero total weight
    """
    quantiles = list(range(101))
    binning_feature = np.unique(np.percentile(f, q=quantiles))
    binning_probability = np.unique(np.percentile(p, q=quantiles))
    if len(binning_feature) < 2:
        binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(binning_probability) < 2:
        binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
    hist_n, _ = np.histogramdd(np.c_[p, f],
                               bins=[binning_probability, binning_feature],
                               weights=w)

    # Feature bins without entries (e.g. for discrete features, where interpolated percentiles
    # create bins between two neighbouring values) have no defined efficiency. Dividing by their
    # zero content would turn the whole result into nan, so they are left out.
    feature_totals = hist_n.sum(axis=0)
    filled = feature_totals != 0
    if not np.any(filled):
        return np.nan

    # Global (inclusive) cumulative efficiency
    hist_inc = hist_n.sum(axis=1)
    hist_inc = hist_inc / hist_inc.sum(axis=0)
    hist_inc = hist_inc.cumsum(axis=0)

    # Local cumulative efficiency per filled feature bin
    hist_n = hist_n[:, filled] / feature_totals[filled]
    hist_n = hist_n.cumsum(axis=0)

    diff = (hist_n.T - hist_inc)**2
    # The last cumulative probability bin is 1 for every curve and carries no information, so
    # only n_probability_bins - 1 rows count. With 100 filled bins in both variables this is the
    # former hard-coded 100 * 99 normalisation; it now follows the actual number of bins.
    n_probability_bins, n_feature_bins = hist_n.shape
    return np.sqrt(diff.sum() / (n_feature_bins * max(n_probability_bins - 1, 1)))
139
140
class Method:
    """
    Wrapper class providing an interface to the method stored under the given identifier.
    It loads the Options, can apply the expert and train new ones using the current as a prototype.
    This class is used by the basf2_mva_evaluation tools
    """

    ## Name of the basf2_mva class holding the specific options for every supported value of
    ## GeneralOptions.m_method. The classes are resolved by name only when they are needed.
    _specific_options_classes = {
        "FastBDT": "FastBDTOptions",
        "TMVAClassification": "TMVAOptionsClassification",
        "TMVARegression": "TMVAOptionsRegression",
        "FANN": "FANNOptions",
        "Python": "PythonOptions",
        "PDF": "PDFOptions",
        "Combination": "CombinationOptions",
        "Reweighter": "ReweighterOptions",
        "Trivial": "TrivialOptions",
    }

    def __init__(self, identifier):
        """
        Load a method stored under the given identifier
        @param identifier identifying the method
        @raises RuntimeError if the method stored in the weightfile is not supported
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        # Initialize all the available interfaces
        ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()

        ## Identifier of the method.
        self.identifier = identifier

        ## Weightfile of the method.
        self.weightfile = ROOT.Belle2.MVA.Weightfile.load(self.identifier)

        ## General options of the method.
        self.general_options = basf2_mva.GeneralOptions()
        self.general_options.load(self.weightfile.getXMLTree())

        # In principle the specific options could be obtained from the interface via
        #   ROOT.Belle2.MVA.AbstractInterface.getSupportedInterfaces()[m_method].getOptions()
        # but this leads to random segmentation faults inside python, llvm or pyroot
        # (suspected: ownership of the returned unique_ptr, object slicing or a pyroot bug).
        # Therefore the options class is chosen explicitly from the method name.
        method_name = str(self.general_options.m_method)
        options_class_name = self._specific_options_classes.get(method_name)
        if options_class_name is None:
            raise RuntimeError("Unknown method " + method_name)

        ## Specific options of the method.
        self.specific_options = getattr(basf2_mva, options_class_name)()
        self.specific_options.load(self.weightfile.getXMLTree())

        variables = [str(v) for v in self.general_options.m_variables]
        importances = self.weightfile.getFeatureImportance()

        ## Dictionary of the variable importances calculated by the method.
        self.importances = {k: importances[k] for k in variables}

        ## List of variables sorted by their importance (ascending).
        self.variables = sorted(variables, key=lambda v: self.importances.get(v, 0.0))

        ## List of the variables sorted by their importance, but with ROOT-compatible variable names.
        self.root_variables = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.variables]

        ## Dictionary of the variable importances keyed by ROOT-compatible variable names.
        # The importance map is keyed by the original variable names (see self.importances), so the
        # value is looked up via the original name rather than via the converted one.
        self.root_importances = {
            root_variable: self.importances[variable]
            for variable, root_variable in zip(self.variables, self.root_variables)
        }

        ## Description of the method as an XML string returned by basf2_mva.info.
        self.description = str(basf2_mva.info(self.identifier))

        ## List of spectators.
        self.spectators = [str(v) for v in self.general_options.m_spectators]

        ## List of spectators with ROOT-compatible names.
        self.root_spectators = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(v) for v in self.spectators]

    def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
        """
        Train a new method using this method as a prototype
        @param datafiles the training datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the training data
        @param general_options general options given to basf2_mva.teacher
                               (if None the options of this method are used)
        @param specific_options specific options given to basf2_mva.teacher
                                (if None the options of this method are used)
        @return a new Method loaded from the freshly trained weightfile
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        if general_options is None:
            general_options = self.general_options
        if specific_options is None:
            specific_options = self.specific_options

        with tempfile.TemporaryDirectory() as tempdir:
            identifier = tempdir + "/weightfile.xml"

            # NOTE(review): these assignments modify the given general_options object in place
            # (self.general_options if none was passed).
            general_options.m_datafiles = basf2_mva.vector(*datafiles)
            # The documented treename argument was previously ignored, so the tree name of the
            # prototype was always used.
            general_options.m_treename = treename
            general_options.m_identifier = identifier

            basf2_mva.teacher(general_options, specific_options)

            method = Method(identifier)
        return method

    def apply_expert(self, datafiles, treename):
        """
        Apply the expert of the method to data and return the calculated probability and the target
        @param datafiles the datafiles (a single file name or a list of file names)
        @param treename the name of the tree containing the data
        @return tuple (probability, target); for more than two classes probability has one
                column per class
        """
        # Always avoid the top-level 'import ROOT'.
        import ROOT  # noqa
        if isinstance(datafiles, str):
            datafiles = [datafiles]
        n_classes = self.general_options.m_nClasses
        target_variable = self.general_options.m_target_variable
        with tempfile.TemporaryDirectory() as tempdir:
            # The expert reads the weightfile from disk, so store a copy in the temporary directory
            identifier = tempdir + "/weightfile.xml"
            ROOT.Belle2.MVA.Weightfile.save(self.weightfile, identifier)

            rootfilename = tempdir + '/expert.root'
            basf2_mva.expert(basf2_mva.vector(identifier),
                             basf2_mva.vector(*datafiles),
                             treename,
                             rootfilename)
            chain = ROOT.TChain("variables")
            chain.Add(rootfilename)

            # Output branches are named after the temporary identifier; in the returned data they
            # are renamed to this method's identifier.
            expert_target = identifier + '_' + target_variable
            stripped_expert_target = self.identifier + '_' + target_variable

            if n_classes > 2:
                output_names = [self.identifier + f'_{i}' for i in range(n_classes)]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier + f'_{i}')
                                for i in range(n_classes)]
            else:
                output_names = [self.identifier]
                branch_names = [ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier)]

            d = chain2dict(
                chain,
                [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
                [*output_names, stripped_expert_target])

            if n_classes > 2:
                probability = np.array([d[name] for name in output_names]).T
            else:
                probability = d[str(self.identifier)]
            return probability, d[stripped_expert_target]
dict importances
Dictionary of the variable importances calculated by the method.
specific_options
Specific options of the method.
description
Description of the method as an XML string returned by basf2_mva.info.
__init__(self, identifier)
dict root_importances
Dictionary of the variable importances calculated by the method, keyed by ROOT-compatible variable names.
variables
List of variables sorted by increasing importance.
weightfile
Weightfile of the method.
list root_spectators
List of spectators with ROOT-compatible names.
list root_variables
List of the variables sorted by their importance, but with ROOT-compatible variable names.
list spectators
List of spectators.
train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
general_options
General options of the method.
apply_expert(self, datafiles, treename)
identifier
Identifier of the method.