13 return logging.getLogger(__name__)
16 def get_truth_name(variable_names):
17 """Selects the shortest variable name form containing the 'truth'."""
18 truth_names = [name
for name
in variable_names
if "truth" in name]
22 truth_name = min(truth_names, key=len)
24 raise ValueError(
"variable_names='%s' does not contain a truth variable" % variable_names)
30 """Class to generate overview plots for the classification power of various variables from a TTree.
32 In order to get an overview, which variables from a diverse set, generated from a recording filter
33 or some other sort of validation, perform well in classification task.
65 """Main method feed with a TTree containing the truth variable and the variables to be investigated.
67 Branches that contain "truth" in the name are considered to directly contain information about
68 true classification target and are not analysed here.
71 input_tree (ROOT.TTree) : Tree containing the variables to be investigated
72 as well as the classification target.
73 truth_name (str, optional) : Name of the branch of the classification target.
74 If not given the Branch with the shortest name containing "truth" is selected.
84 column_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
86 tree_name = input_tree.GetName()
90 if truth_name
is None:
91 truth_name = get_truth_name(column_names)
93 if truth_name
not in column_names:
94 raise KeyError(
"Truth column {truth} not in tree {tree}".format(truth=truth_name,
96 variable_names = [name
for name
in column_names
if name != truth_name]
105 variable_names = [name
for name
in variable_names
if name
in select]
108 variable_names = [name
for name
in variable_names
if name
not in exclude]
111 variable_names = [name
for name
in variable_names
if name
not in filters]
114 variable_names = [name
for name
116 if "truth" not in name
or name
in select]
118 print(
"Truth name", truth_name)
119 print(
"Variable names", variable_names)
122 print(
"Loading tree")
123 branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
124 branch_names = [name
for name
in branch_names
if name]
125 input_array = root_numpy.tree2array(input_tree, branches=branch_names)
126 input_record_array = input_array.view(np.recarray)
129 for filter
in filters:
130 filter_values = input_record_array[filter]
131 input_record_array = input_record_array[np.nonzero(filter_values)]
134 truths = input_record_array[truth_name]
139 for groupby
in groupbys:
140 if groupby
is None or groupby ==
"":
141 groupby_parts = [(
None,
slice(
None))]
144 groupby_values = input_record_array[groupby]
145 unique_values, indices = np.unique(groupby_values, return_inverse=
True)
146 for idx, value
in enumerate(unique_values):
147 groupby_parts.append((value, indices == idx))
149 for groupby_value, groupby_select
in groupby_parts:
151 groupby_folder_name =
'.'
153 groupby_folder_name =
"groupby_{name}_{value}".format(name=groupby, value=groupby_value)
155 with root_cd(groupby_folder_name)
as tdirectory:
156 for variable_name
in variable_names:
157 print(
'Analyse', variable_name,
'groupby', groupby,
'=', groupby_value)
159 if variable_name == groupby:
163 estimates = input_record_array[variable_name]
164 estimates[estimates == np.finfo(np.float32).max] = float(
"nan")
165 estimates[estimates == -np.finfo(np.float32).max] = -float(
"inf")
166 auxiliaries = {name: input_record_array[name][groupby_select]
for name
in self.
auxiliaries}
168 classification_analysis = classification.ClassificationAnalysis(
170 quantity_name=variable_name,
174 classification_analysis.analyse(
175 estimates[groupby_select],
176 truths[groupby_select],
177 auxiliaries=auxiliaries
180 with root_cd(variable_name)
as tdirectory:
181 classification_analysis.write(tdirectory)
188 print(
"Saved overviews completely")