def train(self, input_tree):
    """Analyse the variables of a tree against the classification truth.

    The requested branches are loaded into memory, optionally filtered and
    split into groups. For every (group, variable) pair one
    ``ClassificationAnalysis`` is run, written into the output ROOT file and
    appended to ``self.classification_analyses``.

    Branches that contain "truth" in the name are considered to directly
    contain information about the true classification target and are not
    analysed here, unless they are explicitly listed in ``self.select``.

    The configuration is read from the instance:

    * ``self.output_file_name``: either a file name (the file is then
      recreated here and closed again at the end) or an already opened
      object providing ``cd()``, which is left open for the caller.
    * ``self.truth_name``: name of the branch of the classification target.
      If ``None`` it is determined by ``get_truth_name`` from the leaf names
      (documented so far as the shortest name containing "truth").
    * ``self.select``: if non-empty, only these variables are analysed.
    * ``self.exclude``: variables that are never analysed.
    * ``self.filters``: branches used as row filters; only rows where every
      filter branch is non-zero are kept. Filter branches are not analysed.
    * ``self.groupbys``: branches whose distinct values split the rows into
      groups that are analysed separately; ``None`` means "no grouping".
    * ``self.auxiliaries``: branches handed to each analysis as auxiliaries.

    Args:
        input_tree (ROOT.TTree) : Tree containing the variables to be investigated
            as well as the classification target.

    Raises:
        KeyError: If the truth branch is not among the leaves of the tree.
    """
    # A plain string means the file is created here, so it must also be closed here.
    owns_output_file = isinstance(self.output_file_name, str)
    if owns_output_file:
        output_file = ROOT.TFile(self.output_file_name, "RECREATE")
    else:
        output_file = self.output_file_name

    try:
        output_file.cd()

        column_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]
        tree_name = input_tree.GetName()

        truth_name = self.truth_name
        if truth_name is None:
            truth_name = get_truth_name(column_names)

        if truth_name not in column_names:
            raise KeyError(f"Truth column {truth_name} not in tree {tree_name}")

        # Normalise unset (None) settings to empty lists so that the membership
        # tests and the unpacking below cannot fail on None.
        select = self.select or []
        exclude = self.exclude or []
        groupbys = self.groupbys or []
        auxiliaries = self.auxiliaries or []
        filters = self.filters or []

        variable_names = [name for name in column_names if name != truth_name]

        if select:
            variable_names = [name for name in variable_names if name in select]

        if exclude:
            variable_names = [name for name in variable_names if name not in exclude]

        if filters:
            variable_names = [name for name in variable_names if name not in filters]

        # Truth-like branches leak the target and are only kept on explicit request.
        variable_names = [name for name in variable_names
                          if "truth" not in name or name in select]

        print("Truth name", truth_name)
        print("Variable names", variable_names)

        print("Loading tree")
        # The set removes duplicates; empty / None group-by entries are no branches.
        branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
        branch_names = [name for name in branch_names if name]

        # The data frame only needs the tree; the output file plays no role in reading.
        rdf = ROOT.RDataFrame(input_tree)
        columns = rdf.AsNumpy(branch_names)
        loaded_names = list(columns)

        # column_stack brings all columns to one common dtype. The columns are then
        # re-assembled into a record array with named fields so that they can be
        # looked up by branch name below.
        stacked = np.column_stack([columns[name] for name in loaded_names])
        input_record_array = np.rec.fromarrays(list(stacked.T), names=loaded_names)

        # Keep only the rows in which every filter branch is non-zero.
        for filter_name in filters:
            keep_rows = np.nonzero(input_record_array[filter_name])
            input_record_array = input_record_array[keep_rows]

        print("Loaded tree")
        truths = input_record_array[truth_name]

        if not groupbys:
            groupbys = [None]

        for groupby in groupbys:
            if groupby is None or groupby == "":
                # No grouping: one part covering all rows.
                groupby_parts = [(None, slice(None))]
            else:
                groupby_values = input_record_array[groupby]
                unique_values, indices = np.unique(groupby_values, return_inverse=True)
                # One boolean row mask per distinct value of the group-by branch.
                groupby_parts = [(value, indices == idx)
                                 for idx, value in enumerate(unique_values)]

            for groupby_value, groupby_select in groupby_parts:
                # NOTE(review): an empty-string group-by is analysed ungrouped like
                # None, but is written to "groupby__None" instead of the current
                # directory - confirm whether this is intended.
                if groupby is None:
                    groupby_folder_name = '.'
                else:
                    groupby_folder_name = f"groupby_{groupby}_{groupby_value}"

                with root_cd(groupby_folder_name):
                    for variable_name in variable_names:
                        print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)

                        # Analysing a variable within groups of itself is pointless.
                        if variable_name == groupby:
                            continue

                        # Field access yields a view, so the sentinel replacement is
                        # applied in place on the loaded table and is also seen by
                        # later reads of the same column (e.g. as an auxiliary).
                        estimates = input_record_array[variable_name]
                        # +/- float32 max are mapped to NaN / -inf
                        # (presumably "not available" markers - TODO confirm).
                        estimates[estimates == np.finfo(np.float32).max] = float("nan")
                        estimates[estimates == -np.finfo(np.float32).max] = -float("inf")

                        auxiliary_values = {
                            name: input_record_array[name][groupby_select]
                            for name in auxiliaries
                        }

                        classification_analysis = classification.ClassificationAnalysis(
                            contact="",
                            quantity_name=variable_name,
                            outlier_z_score=5.0,
                            allow_discrete=True,
                        )
                        classification_analysis.analyse(
                            estimates[groupby_select],
                            truths[groupby_select],
                            auxiliaries=auxiliary_values
                        )

                        with root_cd(variable_name) as variable_directory:
                            classification_analysis.write(variable_directory)

                        self.classification_analyses.append(classification_analysis)
    finally:
        # Never leave a file opened by this method dangling, even on errors.
        if owns_output_file:
            output_file.Close()

    print("Saved overviews completely")