20 return logging.getLogger(__name__)
def get_truth_name(variable_names):
    """Select the shortest variable name that contains the substring 'truth'.

    Parameters
    ----------
    variable_names : iterable of str
        Candidate variable (branch) names.

    Returns
    -------
    str
        The shortest name containing "truth". If several candidates share the
        minimal length, the first one encountered is returned.

    Raises
    ------
    ValueError
        If none of the names contains "truth".
    """
    truth_names = [name for name in variable_names if "truth" in name]

    if not truth_names:
        # Wrap in a one-tuple: a bare tuple on the right of `%` would be
        # interpreted as the argument list and raise TypeError instead.
        raise ValueError(
            "variable_names='%s' does not contain a truth variable" % (variable_names,))

    # min() returns the first of equally short names.
    return min(truth_names, key=len)
37 """Class to generate overview plots for the classification power of various variables from a TTree.
39 In order to get an overview, which variables from a diverse set, generated from a recording filter
40 or some other sort of validation, perform well in classification task.
72 """Main method feed with a TTree containing the truth variable and the variables to be investigated.
74 Branches that contain "truth" in the name are considered to directly contain information about
75 true classification target and are not analysed here.
78 input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79 as well as the classification target.
80 truth_name (str, optional) : Name of the branch of the classification target.
81 If not given the Branch with the shortest name containing "truth" is selected.
91 column_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
93 tree_name = input_tree.GetName()
97 if truth_name
is None:
98 truth_name = get_truth_name(column_names)
100 if truth_name
not in column_names:
101 raise KeyError(
"Truth column {truth} not in tree {tree}".format(truth=truth_name,
103 variable_names = [name
for name
in column_names
if name != truth_name]
106 select = self.
selectselect
112 variable_names = [name
for name
in variable_names
if name
in select]
115 variable_names = [name
for name
in variable_names
if name
not in exclude]
118 variable_names = [name
for name
in variable_names
if name
not in filters]
121 variable_names = [name
for name
123 if "truth" not in name
or name
in select]
125 print(
"Truth name", truth_name)
126 print(
"Variable names", variable_names)
128 print(
"Loading tree")
129 branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
130 branch_names = [name
for name
in branch_names
if name]
134 rdf = ROOT.RDataFrame(input_tree, self.
output_file_nameoutput_file_name.GetName())
135 input_array = np.column_stack(list(rdf.AsNumpy(branch_names).values()))
136 input_record_array = input_array.view(np.recarray)
139 for filter
in filters:
140 filter_values = input_record_array[filter]
141 input_record_array = input_record_array[np.nonzero(filter_values)]
144 truths = input_record_array[truth_name]
149 for groupby
in groupbys:
150 if groupby
is None or groupby ==
"":
151 groupby_parts = [(
None,
slice(
None))]
154 groupby_values = input_record_array[groupby]
155 unique_values, indices = np.unique(groupby_values, return_inverse=
True)
156 for idx, value
in enumerate(unique_values):
157 groupby_parts.append((value, indices == idx))
159 for groupby_value, groupby_select
in groupby_parts:
161 groupby_folder_name =
'.'
163 groupby_folder_name =
"groupby_{name}_{value}".format(name=groupby, value=groupby_value)
165 with root_cd(groupby_folder_name)
as tdirectory:
166 for variable_name
in variable_names:
167 print(
'Analyse', variable_name,
'groupby', groupby,
'=', groupby_value)
169 if variable_name == groupby:
173 estimates = input_record_array[variable_name]
174 estimates[estimates == np.finfo(np.float32).max] = float(
"nan")
175 estimates[estimates == -np.finfo(np.float32).max] = -float(
"inf")
176 auxiliaries = {name: input_record_array[name][groupby_select]
for name
in self.
auxiliariesauxiliaries}
178 classification_analysis = classification.ClassificationAnalysis(
180 quantity_name=variable_name,
184 classification_analysis.analyse(
185 estimates[groupby_select],
186 truths[groupby_select],
187 auxiliaries=auxiliaries
190 with root_cd(variable_name)
as tdirectory:
191 classification_analysis.write(tdirectory)
198 print(
"Saved overviews completely")
truth_name
cached truth name
def train(self, input_tree)
classification_analyses
array of classification analyses
groupbys
cached groupby-specifier array
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
output_file_name
cached output filename
filters
cached filter-specifier array
exclude
cached exclusion-specifier array
select
cached selection-specifier array
auxiliaries
cached auxiliary-specifier array
std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
Slice the vector to contain only the elements with indices s through e (inclusive).