19 from ROOT
import Belle2
def tree2dict(tree, tree_columns, dict_columns=None):
    """
    Convert a ROOT.TTree into a dictionary of np.arrays
    @param tree the ROOT.TTree
    @param tree_columns the column (or branch) names in the tree
    @param dict_columns the corresponding column names in the dictionary
                        (defaults to tree_columns if None)
    @return an empty dict if tree_columns is empty; otherwise the array returned by
            root_numpy.tree2array with its fields renamed to dict_columns, or, if
            root_numpy cannot be imported, a dict mapping every dictionary column name
            to a np.array holding one value per tree entry
    @raise ValueError if dict_columns and tree_columns differ in length
    """
    if len(tree_columns) == 0:
        return dict()
    if dict_columns is None:
        dict_columns = tree_columns
    # zip() below would silently drop the surplus names and leave columns unfilled
    if len(dict_columns) != len(tree_columns):
        raise ValueError("dict_columns and tree_columns must have the same length")

    # Only the import itself decides which code path is taken
    try:
        import root_numpy
    except ImportError:
        root_numpy = None

    if root_numpy is not None:
        d = root_numpy.tree2array(tree, branches=tree_columns)
        d.dtype.names = dict_columns
        return d

    # Fallback without root_numpy: read the tree entry by entry
    d = {column: np.zeros((tree.GetEntries(),)) for column in dict_columns}
    for iEvent, event in enumerate(tree):
        for dict_column, tree_column in zip(dict_columns, tree_columns):
            d[dict_column][iEvent] = getattr(event, tree_column)
    return d
def calculate_roc_auc(p, t):
    """
    Calculates the area under the receiver operating characteristic curve (AUC ROC)
    @param p np.array filled with the probability output of a classifier
    @param t np.array filled with the target (0 or 1)
    @return the absolute value of the trapezoidal integral of the purity over the efficiency,
            both evaluated for cuts placed after each entry sorted by ascending probability

    NOTE(review): the integrand is purity as a function of efficiency; confirm that this
    is the curve meant by "ROC" here.
    """
    p = np.asarray(p)
    t = np.asarray(t)
    N = len(t)
    T = np.sum(t)
    index = np.argsort(p)
    # Number of signal entries / of all entries remaining above each cut
    signal_above_cut = T - np.cumsum(t[index])
    total_above_cut = N - np.cumsum(np.ones(N))
    efficiency = signal_above_cut / float(T)
    # The last cut leaves nothing above it, so that bin is 0/0; it is set to 0 below
    with np.errstate(invalid='ignore'):
        purity = signal_above_cut / total_above_cut
    purity = np.where(np.isnan(purity), 0, purity)
    # np.trapz is deprecated in recent numpy releases in favour of np.trapezoid
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return np.abs(integrate(purity, efficiency))
def calculate_flatness(f, p, w=None):
    """
    Calculates the flatness of a feature under cuts on a signal probability
    @param f the feature values
    @param p the probability values
    @param w optional weights
    @return the mean standard deviation between the local and global cut selection efficiency

    NOTE(review): a feature bin whose (weighted) content is zero leads to 0/0 in the
    per-bin normalisation below and therefore to a NaN result.
    """
    # Bin edges at every percentile; identical edges are merged by np.unique
    quantiles = list(range(101))
    binning_feature = np.unique(np.percentile(f, q=quantiles))
    binning_probability = np.unique(np.percentile(p, q=quantiles))
    # A constant input yields a single edge; fall back to one bin enclosing all values
    if len(binning_feature) < 2:
        binning_feature = np.array([np.min(f) - 1, np.max(f) + 1])
    if len(binning_probability) < 2:
        binning_probability = np.array([np.min(p) - 1, np.max(p) + 1])
    # 2D histogram with axis 0 = probability and axis 1 = feature
    hist_n, _ = np.histogramdd(np.c_[p, f],
                               bins=[binning_probability, binning_feature],
                               weights=w)
    # Probability distribution integrated over the feature (global) ...
    hist_inc = hist_n.sum(axis=1)
    hist_inc /= hist_inc.sum(axis=0)
    # ... and separately normalised within each feature bin (local)
    hist_n /= hist_n.sum(axis=0)
    # Cumulative distributions along the probability axis
    hist_n = hist_n.cumsum(axis=0)
    hist_inc = hist_inc.cumsum(axis=0)
    diff = (hist_n.T - hist_inc)**2
    # Fixed normalisation; presumably the nominal number of feature bins times cuts
    return np.sqrt(diff.sum() / (100 * 99))
89 Wrapper class providing an interface to the method stored under the given identifier.
90 It loads the Options, can apply the expert and train new ones using the current as a prototype.
91 This class is used by the basf_mva_evaluation tools
96 Load a method stored under the given identifier
97 @param identifier identifying the method
121 elif self.
general_optionsgeneral_options.m_method ==
"TMVAClassification":
122 self.
specific_optionsspecific_options = basf2_mva.TMVAOptionsClassification()
138 raise RuntimeError(
"Unknown method " + self.
general_optionsgeneral_options.m_method)
142 variables = [str(v)
for v
in self.
general_optionsgeneral_options.m_variables]
143 importances = self.
weightfileweightfile.getFeatureImportance()
146 self.
importancesimportances = {k: importances[k]
for k
in variables}
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None):
    """
    Train a new method using this method as a prototype
    @param datafiles the training datafiles
    @param treename the name of the tree containing the training data
    @param general_options general options given to basf2_mva.teacher (if None the options of this method are used)
    @param specific_options specific options given to basf2_mva.teacher (if None the options of this method are used)
    @return the Method loaded from the newly trained weightfile

    NOTE(review): treename is accepted but not used anywhere in this method.
    NOTE(review): m_datafiles and m_identifier are set on the options object that is
    passed in; if none is passed this is the options object of this method itself.
    """
    # Accept a single filename as well as a list of filenames
    if isinstance(datafiles, str):
        datafiles = [datafiles]
    if general_options is None:
        general_options = self.general_options
    if specific_options is None:
        specific_options = self.specific_options

    with tempfile.TemporaryDirectory() as tempdir:
        identifier = tempdir + "/weightfile.xml"

        general_options.m_datafiles = basf2_mva.vector(*datafiles)
        general_options.m_identifier = identifier

        basf2_mva.teacher(general_options, specific_options)

        # The weightfile must be loaded before the temporary directory is removed
        method = Method(identifier)
    return method
188 Apply the expert of the method to data and return the calculated probability and the target
189 @param datafiles the datafiles
190 @param treename the name of the tree containing the data
192 if isinstance(datafiles, str):
193 datafiles = [datafiles]
194 with tempfile.TemporaryDirectory()
as tempdir:
195 identifier = tempdir +
"/weightfile.xml"
196 basf2_mva.Weightfile.save(self.
weightfileweightfile, identifier)
198 rootfilename = tempdir +
'/expert.root'
199 basf2_mva.expert(basf2_mva.vector(identifier),
200 basf2_mva.vector(*datafiles),
203 rootfile = ROOT.TFile(rootfilename,
"UPDATE")
204 roottree = rootfile.Get(
"variables")
206 expert_target = identifier +
'_' + self.
general_optionsgeneral_options.m_target_variable
208 d = tree2dict(roottree,
210 [self.
identifieridentifier, stripped_expert_target])
211 return d[self.
identifieridentifier], d[stripped_expert_target]
specific_options
Specific options of the method.
def apply_expert(self, datafiles, treename)
description
Description of the method as an XML string returned by basf2_mva.info.
importances
Dictionary of the variable importances calculated by the method.
root_importances
Dictionary of the variables sorted by their importance but with root compatible variable names.
variables
List of variables sorted by their importance.
def __init__(self, identifier)
weightfile
Weightfile of the method.
root_spectators
List of spectators with root compatible names.
spectators
List of spectators.
def train_teacher(self, datafiles, treename, general_options=None, specific_options=None)
root_variables
List of the variable importances calculated by the method, but with the root compatible variable name...
general_options
General options of the method.
identifier
Identifier of the method.
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.