# Review notes (excerpt): train() picks the variables to analyse from the
# leaves of `input_tree`, loads those branches through ROOT.RDataFrame into
# NumPy, and runs one classification.ClassificationAnalysis per
# (groupby value, variable) pair, storing each in self.classification_analyses.
# NOTE(review): this listing is incomplete -- several source lines are missing
# (e.g. the `else:` of the output-file branch, the assignment of `select`, the
# closing of the docstring and some constructor arguments), so the notes below
# describe only what the visible lines show.
71 def train(self, input_tree):
72 """Main method feed with a TTree containing the truth variable and the variables to be investigated.
74 Branches that contain "truth" in the name are considered to directly contain information about
75 true classification target and are not analysed here.
78 input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79 as well as the classification target.
80 truth_name (str, optional) : Name of the branch of the classification target.
81 If not given the Branch with the shortest name containing "truth" is selected.
Note: truth_name is not an argument of train(); it is read from self.truth_name,
and get_truth_name(column_names) is used when that attribute is None.
# Open a fresh output file when a path string is configured; otherwise the
# attribute is used as-is (presumably an already-open ROOT file -- confirm).
84 if isinstance(self.output_file_name, str):
85 output_file = ROOT.TFile(self.output_file_name, "RECREATE")
87 output_file = self.output_file_name
# Every leaf name of the input tree is a candidate column.
91 column_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]
93 tree_name = input_tree.GetName()
# Resolve the truth branch and fail early if the tree does not contain it.
95 truth_name = self.truth_name
97 if truth_name is None:
98 truth_name = get_truth_name(column_names)
100 if truth_name not in column_names:
101 raise KeyError(f"Truth column {truth_name} not in tree {tree_name}")
102 variable_names = [name for name in column_names if name != truth_name]
# NOTE(review): `select`, used below, is not assigned in the visible lines;
# presumably `select = self.select` on a missing line -- confirm.
104 exclude = self.exclude
106 groupbys = self.groupbys
107 auxiliaries = self.auxiliaries
108 filters = self.filters
# Narrow the variables: keep selected names, drop excluded names, and drop
# the filter columns (the conditions guarding these three statements, if
# any, are on lines not shown here).
111 variable_names = [name for name in variable_names if name in select]
114 variable_names = [name for name in variable_names if name not in exclude]
117 variable_names = [name for name in variable_names if name not in filters]
119 # Remove the variables that have Monte Carlo truth information unless explicitly selected
# NOTE(review): `name in select` is only evaluated for names containing
# "truth"; if `select` can be None this raises TypeError -- confirm the default.
120 variable_names = [name for name
122 if "truth" not in name or name in select]
124 print("Truth name", truth_name)
125 print("Variable names", variable_names)
127 print("Loading tree")
# Branches to load: variables, truth, groupbys, auxiliaries and filters. The
# set removes duplicates and the comprehension drops falsy names (None, "").
# Set iteration order is arbitrary, so the loaded column order is not fixed.
128 branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
129 branch_names = [name for name in branch_names if name]
# NOTE(review): both branches pass the *output* file name to RDataFrame
# alongside input_tree. In the str case that file was just opened with
# "RECREATE" (truncated), so it cannot hold the input data, and it is unclear
# which constructor overload this resolves to. ROOT.RDataFrame(input_tree)
# looks like what was intended -- confirm.
130 if isinstance(self.output_file_name, str):
131 rdf = ROOT.RDataFrame(input_tree, self.output_file_name)
133 rdf = ROOT.RDataFrame(input_tree, self.output_file_name.GetName())
# NOTE(review): np.column_stack returns one 2-D array with a single common
# dtype, and .view(np.recarray) on it has no named fields, so the string
# indexing below (input_record_array[filter], [truth_name], [variable_name])
# looks like it will raise. np.rec.fromarrays(list(cols.values()),
# names=list(cols.keys())) on the AsNumpy dict would keep the names -- confirm.
134 input_array = np.column_stack(list(rdf.AsNumpy(branch_names).values()))
135 input_record_array = input_array.view(np.recarray)
# Keep only the rows where each filter column is non-zero.
# NOTE(review): the loop variable `filter` shadows the builtin.
138 for filter in filters:
139 filter_values = input_record_array[filter]
140 input_record_array = input_record_array[np.nonzero(filter_values)]
143 truths = input_record_array[truth_name]
# Split the rows per groupby: an empty/None groupby yields one part covering
# all rows; otherwise one boolean mask per unique value of the column.
148 for groupby in groupbys:
149 if groupby is None or groupby == "":
150 groupby_parts = [(None, slice(None))]
153 groupby_values = input_record_array[groupby]
154 unique_values, indices = np.unique(groupby_values, return_inverse=True)
155 for idx, value in enumerate(unique_values):
156 groupby_parts.append((value, indices == idx))
158 for groupby_value, groupby_select in groupby_parts:
# '.' for the ungrouped part, otherwise groupby_<column>_<value> (the
# condition choosing between these is on a line not shown).
160 groupby_folder_name = '.'
162 groupby_folder_name = f"groupby_{groupby}_{groupby_value}"
164 with root_cd(groupby_folder_name) as tdirectory:
165 for variable_name in variable_names:
166 print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)
# The groupby column itself is not analysed within its own grouping
# (presumably `continue` on a missing line).
168 if variable_name == groupby:
171 # Get the estimates of this variable as a numpy array
# Map the float32 extreme values: +max becomes NaN and -max becomes -inf.
# NOTE(review): if this is a field view, the assignments write back into
# input_record_array and are repeated for every groupby part; NaN/inf cannot
# be stored in an integer column -- confirm the intended sentinel handling.
172 estimates = input_record_array[variable_name]
173 estimates[estimates == np.finfo(np.float32).max] = float("nan")
174 estimates[estimates == -np.finfo(np.float32).max] = -float("inf")
# NOTE(review): rebinds the local `auxiliaries` (the list from
# self.auxiliaries above) to a dict of the auxiliary columns for this part.
175 auxiliaries = {name: input_record_array[name][groupby_select] for name in self.auxiliaries}
177 classification_analysis = classification.ClassificationAnalysis(
179 quantity_name=variable_name,
# Run the analysis on the rows of the current groupby part only.
183 classification_analysis.analyse(
184 estimates[groupby_select],
185 truths[groupby_select],
186 auxiliaries=auxiliaries
# Store the analysis in a sub-directory named after the variable and keep
# a reference on the instance.
189 with root_cd(variable_name) as tdirectory:
190 classification_analysis.write(tdirectory)
192 self.classification_analyses.append(classification_analysis)
# The body of this branch is not shown; presumably it closes the output file
# that was opened above -- confirm.
194 if isinstance(self.output_file_name, str):
197 print("Saved overviews completely")