20 return logging.getLogger(__name__)
def get_truth_name(variable_names):
    """Select the shortest variable name that contains the substring "truth".

    Parameters
    ----------
    variable_names : iterable of str
        Candidate variable (branch) names.

    Returns
    -------
    str
        The shortest name containing "truth". On ties, the first such name
        in iteration order is returned (``min`` keeps the first minimum).

    Raises
    ------
    ValueError
        If no name contains "truth".
    """
    truth_names = [name for name in variable_names if "truth" in name]

    # Check explicitly instead of relying on min() raising on an empty
    # sequence, so the reported error is the meaningful one.
    if not truth_names:
        # Wrap in a 1-tuple: a bare tuple on the right of % would be unpacked
        # as the format arguments and raise TypeError instead of ValueError.
        raise ValueError(
            "variable_names='%s' does not contain a truth variable" % (variable_names,))

    return min(truth_names, key=len)
37 """Class to generate overview plots for the classification power of various variables from a TTree.
39 In order to get an overview, which variables from a diverse set, generated from a recording filter
40 or some other sort of validation, perform well in classification task.
72 """Main method feed with a TTree containing the truth variable and the variables to be investigated.
74 Branches that contain "truth" in the name are considered to directly contain information about
75 true classification target and are not analysed here.
78 input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79 as well as the classification target.
80 truth_name (str, optional) : Name of the branch of the classification target.
81 If not given the Branch with the shortest name containing "truth" is selected.
91 column_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
93 tree_name = input_tree.GetName()
97 if truth_name
is None:
98 truth_name = get_truth_name(column_names)
100 if truth_name
not in column_names:
101 raise KeyError(
"Truth column {truth} not in tree {tree}".format(truth=truth_name,
103 variable_names = [name
for name
in column_names
if name != truth_name]
106 select = self.
selectselect
112 variable_names = [name
for name
in variable_names
if name
in select]
115 variable_names = [name
for name
in variable_names
if name
not in exclude]
118 variable_names = [name
for name
in variable_names
if name
not in filters]
121 variable_names = [name
for name
123 if "truth" not in name
or name
in select]
125 print(
"Truth name", truth_name)
126 print(
"Variable names", variable_names)
129 print(
"Loading tree")
130 branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
131 branch_names = [name
for name
in branch_names
if name]
132 input_array = root_numpy.tree2array(input_tree, branches=branch_names)
133 input_record_array = input_array.view(np.recarray)
136 for filter
in filters:
137 filter_values = input_record_array[filter]
138 input_record_array = input_record_array[np.nonzero(filter_values)]
141 truths = input_record_array[truth_name]
146 for groupby
in groupbys:
147 if groupby
is None or groupby ==
"":
148 groupby_parts = [(
None,
slice(
None))]
151 groupby_values = input_record_array[groupby]
152 unique_values, indices = np.unique(groupby_values, return_inverse=
True)
153 for idx, value
in enumerate(unique_values):
154 groupby_parts.append((value, indices == idx))
156 for groupby_value, groupby_select
in groupby_parts:
158 groupby_folder_name =
'.'
160 groupby_folder_name =
"groupby_{name}_{value}".format(name=groupby, value=groupby_value)
162 with root_cd(groupby_folder_name)
as tdirectory:
163 for variable_name
in variable_names:
164 print(
'Analyse', variable_name,
'groupby', groupby,
'=', groupby_value)
166 if variable_name == groupby:
170 estimates = input_record_array[variable_name]
171 estimates[estimates == np.finfo(np.float32).max] = float(
"nan")
172 estimates[estimates == -np.finfo(np.float32).max] = -float(
"inf")
173 auxiliaries = {name: input_record_array[name][groupby_select]
for name
in self.
auxiliariesauxiliaries}
175 classification_analysis = classification.ClassificationAnalysis(
177 quantity_name=variable_name,
181 classification_analysis.analyse(
182 estimates[groupby_select],
183 truths[groupby_select],
184 auxiliaries=auxiliaries
187 with root_cd(variable_name)
as tdirectory:
188 classification_analysis.write(tdirectory)
195 print(
"Saved overviews completely")
truth_name
cached truth name
def train(self, input_tree)
classification_analyses
array of classification analyses
groupbys
cached groupby-specifier array
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
output_file_name
cached output filename
filters
cached filter-specifier array
exclude
cached exclusion-specifier array
select
cached selection-specifier array
auxiliaries
cached auxiliary-specifier array
std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
Slice the vector to contain only the elements with indices s through e (inclusive)