12from basf2
import B2WARNING
16def chain2dict(chain, tree_columns, dict_columns=None, max_entries=None):
18 Convert a ROOT.TChain into a dictionary of np.arrays
19 @param chain the ROOT.TChain
20 @param tree_columns the column (or branch) names in the tree
21 @param dict_columns the corresponding column names in the dictionary
23 if len(tree_columns) == 0:
25 if dict_columns
is None:
26 dict_columns = tree_columns
28 from ROOT
import RDataFrame
29 rdf = RDataFrame(chain)
30 if max_entries
is not None:
31 nEntries = rdf.Count().GetValue()
32 if nEntries > max_entries:
34 "basf2_mva_util (chain2dict): Number of entries in the chain is larger than the maximum allowed entries: " +
38 skip = nEntries // max_entries
39 rdf_subset = rdf.Filter(
"rdfentry_ % " + str(skip) +
" == 0")
42 d = np.column_stack(list(rdf.AsNumpy(tree_columns).values()))
43 d = np.core.records.fromarrays(d.transpose(), names=dict_columns)
45 d = {column: np.zeros((chain.GetEntries(),))
for column
in dict_columns}
46 for iEvent, event
in enumerate(chain):
47 for dict_column, tree_column
in zip(dict_columns, tree_columns):
48 d[dict_column][iEvent] = getattr(event, tree_column)
52def calculate_roc_auc(p, t):
54 Deprecated name of ``calculate_auc_efficiency_vs_purity``
56 @param p np.array filled with the probability output of a classifier
57 @param t np.array filled with the target (0 or 1)
60 "\033[93mcalculate_roc_auc\033[00m has been deprecated and will be removed in future.\n"
61 "This change has been made as calculate_roc_auc returned the area under the efficiency-purity curve\n"
62 "not the efficiency-background retention curve as expected by users.\n"
63 "Please replace calculate_roc_auc with:\n\n"
64 "\033[96mcalculate_auc_efficiency_vs_purity(probability, target[, weight])\033[00m:"
65 " the current definition of calculate_roc_auc\n"
66 "\033[96mcalculate_auc_efficiency_vs_background_retention(probability, target[, weight])\033[00m:"
67 " what is commonly known as roc auc\n")
68 return calculate_auc_efficiency_vs_purity(p, t)
71def calculate_auc_efficiency_vs_purity(p, t, w=None):
73 Calculates the area under the efficiency-purity curve
74 @param p np.array filled with the probability output of a classifier
75 @param t np.array filled with the target (0 or 1)
76 @param w None or np.array filled with weights
87 efficiency = (T - np.cumsum(wt[index])) / float(T)
88 purity = (T - np.cumsum(wt[index])) / (N - np.cumsum(w[index]))
89 purity = np.where(np.isnan(purity), 0, purity)
90 return np.abs(np.trapz(purity, efficiency))
93def calculate_auc_efficiency_vs_background_retention(p, t, w=None):
95 Calculates the area under the efficiency-background_retention curve (AUC ROC)
96 @param p np.array filled with the probability output of a classifier
97 @param t np.array filled with the target (0 or 1)
98 @param w None or np.array filled with weights
108 index = np.argsort(p)
109 efficiency = (T - np.cumsum(wt[index])) / float(T)
110 background_retention = (N - T - np.cumsum((np.abs(1 - t) * w)[index])) / float(N - T)
111 return np.abs(np.trapz(efficiency, background_retention))
114def calculate_flatness(f, p, w=None):
116 Calculates the flatness of a feature under cuts on a signal probability
117 @param f the feature values
118 @param p the probability values
119 @param w optional weights
120 @return the mean standard deviation between the local and global cut selection efficiency
122 quantiles = list(range(101))
123 binning_feature = np.unique(np.percentile(f, q=quantiles))
124 binning_probability = np.unique(np.percentile(p, q=quantiles))
125 if len(binning_feature) < 2:
126 binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
127 if len(binning_probability) < 2:
128 binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
129 hist_n, _ = np.histogramdd(np.c_[p, f],
130 bins=[binning_probability, binning_feature],
132 hist_inc = hist_n.sum(axis=1)
133 hist_inc /= hist_inc.sum(axis=0)
134 hist_n /= hist_n.sum(axis=0)
135 hist_n = hist_n.cumsum(axis=0)
136 hist_inc = hist_inc.cumsum(axis=0)
137 diff = (hist_n.T - hist_inc)**2
138 return np.sqrt(diff.sum() / (100 * 99))
143 Wrapper class providing an interface to the method stored under the given identifier.
144 It loads the Options, can apply the expert and train new ones using the current as a prototype.
145 This class is used by the basf_mva_evaluation tools
150 Load a method stored under the given identifier
151 @param identifier identifying the method
156 ROOT.Belle2.MVA.AbstractInterface.initSupportedInterfaces()
203 importances = self.
weightfile.getFeatureImportance()
220 def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
222 Train a new method using this method as a prototype
223 @param datafiles the training datafiles
224 @param treename the name of the tree containing the training data
225 @param general_options general options given to basf2_mva.teacher
226 (if None the options of this method are used)
227 @param specific_options specific options given to basf2_mva.teacher
228 (if None the options of this method are used)
232 if isinstance(datafiles, str):
233 datafiles = [datafiles]
234 if general_options
is None:
236 if specific_options
is None:
239 with tempfile.TemporaryDirectory()
as tempdir:
240 identifier = tempdir +
"/weightfile.xml"
242 general_options.m_datafiles = basf2_mva.vector(*datafiles)
243 general_options.m_identifier = identifier
245 basf2_mva.teacher(general_options, specific_options)
247 method =
Method(identifier)
252 Apply the expert of the method to data and return the calculated probability and the target
253 @param datafiles the datafiles
254 @param treename the name of the tree containing the data
257 if isinstance(datafiles, str):
258 datafiles = [datafiles]
259 with tempfile.TemporaryDirectory()
as tempdir:
260 identifier = tempdir +
"/weightfile.xml"
261 ROOT.Belle2.MVA.Weightfile.save(self.
weightfile, identifier)
263 rootfilename = tempdir +
'/expert.root'
264 basf2_mva.expert(basf2_mva.vector(identifier),
265 basf2_mva.vector(*datafiles),
268 chain = ROOT.TChain(
"variables")
269 chain.Add(rootfilename)
271 expert_target = identifier +
'_' + self.
general_options.m_target_variable
276 ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(identifier),
281 ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(
283 f
'_{i}')
for i
in range(
288 [*branch_names, ROOT.Belle2.MakeROOTCompatible.makeROOTCompatible(expert_target)],
289 [*output_names, stripped_expert_target])
292 for x
in output_names]).T), d[stripped_expert_target]
dict importances
Dictionary of the variable importances calculated by the method.
specific_options
Specific options of the method.
description
Description of the method as an XML string returned by basf2_mva.info.
__init__(self, identifier)
dict root_importances
Dictionary of the variables sorted by their importance, but with ROOT-compatible variable names.
variables
List of variables sorted by their importance.
weightfile
Weightfile of the method.
list root_spectators
List of spectators with root compatible names.
list root_variables
List of the variables sorted by the importances calculated by the method, but with ROOT-compatible variable names.
list spectators
List of spectators.
train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
general_options
General options of the method.
apply_expert(self, datafiles, treename)
identifier
Identifier of the method.