Belle II Software development
train.py
1
8
9import ROOT
10
11import numpy as np
12
13import tracking.validation.classification as classification
14from tracking.root_utils import root_cd
15
16import logging
17
18
def get_logger():
    """Return the logger registered under this module's name."""
    module_logger = logging.getLogger(__name__)
    return module_logger
21
22
def get_truth_name(variable_names):
    """Return the shortest entry of ``variable_names`` that contains 'truth'.

    Args:
        variable_names: Iterable of variable name strings to search.

    Returns:
        The shortest name containing the substring 'truth'.

    Raises:
        ValueError: If no name contains the substring 'truth'.
    """
    candidates = list(filter(lambda name: "truth" in name, variable_names))

    # Nothing to choose from: report the names that were searched.
    if not candidates:
        raise ValueError(f"variable_names='{variable_names}' does not contain a truth variable")

    # min() keeps the earliest candidate when several share the shortest length.
    return min(candidates, key=len)
34
35
37 """Class to generate overview plots for the classification power of various variables from a TTree.
38
39 Gives an overview of which variables from a diverse set, generated from a recording filter
40 or some other sort of validation, perform well in a classification task.
41
42 """
43
44 def __init__(self,
45 output_file_name,
46 truth_name=None,
47 select=[],
48 exclude=[],
49 groupbys=[],
50 auxiliaries=[],
51 filters=[]):
52 """Constructor"""
53
54 self.output_file_name = output_file_name
55
56 self.truth_name = truth_name
57
58 self.select = select
59
60 self.exclude = exclude
61
62 self.groupbys = groupbys
63
64 self.auxiliaries = auxiliaries
65
66 self.filters = filters
67
68
70
71 def train(self, input_tree):
72 """Main method feed with a TTree containing the truth variable and the variables to be investigated.
73
74 Branches that contain "truth" in the name are considered to directly contain information about
75 true classification target and are not analysed here.
76
77 Args:
78 input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79 as well as the classification target.
80 truth_name (str, optional) : Name of the branch of the classification target.
81 If not given the Branch with the shortest name containing "truth" is selected.
82 """
83
84 if isinstance(self.output_file_name, str):
85 output_file = ROOT.TFile(self.output_file_name, "RECREATE")
86 else:
87 output_file = self.output_file_name
88
89 output_file.cd()
90
91 column_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]
92
93 tree_name = input_tree.GetName()
94
95 truth_name = self.truth_name
96
97 if truth_name is None:
98 truth_name = get_truth_name(column_names)
99
100 if truth_name not in column_names:
101 raise KeyError(f"Truth column {truth_name} not in tree {tree_name}")
102 variable_names = [name for name in column_names if name != truth_name]
103
104 exclude = self.exclude
105 select = self.select
106 groupbys = self.groupbys
107 auxiliaries = self.auxiliaries
108 filters = self.filters
109
110 if select:
111 variable_names = [name for name in variable_names if name in select]
112
113 if exclude:
114 variable_names = [name for name in variable_names if name not in exclude]
115
116 if filters:
117 variable_names = [name for name in variable_names if name not in filters]
118
119 # Remove the variables that have Monte Carlo truth information unless explicitly selected
120 variable_names = [name for name
121 in variable_names
122 if "truth" not in name or name in select]
123
124 print("Truth name", truth_name)
125 print("Variable names", variable_names)
126
127 print("Loading tree")
128 branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
129 branch_names = [name for name in branch_names if name]
130 if isinstance(self.output_file_name, str):
131 rdf = ROOT.RDataFrame(input_tree, self.output_file_name)
132 else:
133 rdf = ROOT.RDataFrame(input_tree, self.output_file_name.GetName())
134 input_array = np.column_stack(list(rdf.AsNumpy(branch_names).values()))
135 input_record_array = input_array.view(np.recarray)
136
137 if filters:
138 for filter in filters:
139 filter_values = input_record_array[filter]
140 input_record_array = input_record_array[np.nonzero(filter_values)]
141
142 print("Loaded tree")
143 truths = input_record_array[truth_name]
144
145 if not groupbys:
146 groupbys = [None]
147
148 for groupby in groupbys:
149 if groupby is None or groupby == "":
150 groupby_parts = [(None, slice(None))]
151 else:
152 groupby_parts = []
153 groupby_values = input_record_array[groupby]
154 unique_values, indices = np.unique(groupby_values, return_inverse=True)
155 for idx, value in enumerate(unique_values):
156 groupby_parts.append((value, indices == idx))
157
158 for groupby_value, groupby_select in groupby_parts:
159 if groupby is None:
160 groupby_folder_name = '.'
161 else:
162 groupby_folder_name = f"groupby_{groupby}_{groupby_value}"
163
164 with root_cd(groupby_folder_name) as tdirectory:
165 for variable_name in variable_names:
166 print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)
167
168 if variable_name == groupby:
169 continue
170
171 # Get the truths as a numpy array
172 estimates = input_record_array[variable_name]
173 estimates[estimates == np.finfo(np.float32).max] = float("nan")
174 estimates[estimates == -np.finfo(np.float32).max] = -float("inf")
175 auxiliaries = {name: input_record_array[name][groupby_select] for name in self.auxiliaries}
176
177 classification_analysis = classification.ClassificationAnalysis(
178 contact="",
179 quantity_name=variable_name,
180 outlier_z_score=5.0,
181 allow_discrete=True,
182 )
183 classification_analysis.analyse(
184 estimates[groupby_select],
185 truths[groupby_select],
186 auxiliaries=auxiliaries
187 )
188
189 with root_cd(variable_name) as tdirectory:
190 classification_analysis.write(tdirectory)
191
192 self.classification_analyses.append(classification_analysis)
193
194 if isinstance(self.output_file_name, str):
195 output_file.Close()
196
197 print("Saved overviews completely")
truth_name
cached truth name
Definition: train.py:56
classification_analyses
array of classification analyses
Definition: train.py:69
groupbys
cached groupby-specifier array
Definition: train.py:62
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
Definition: train.py:51
output_file_name
cached output filename
Definition: train.py:54
filters
cached filter-specifier array
Definition: train.py:66
exclude
cached exclusion-specifier array
Definition: train.py:60
select
cached selection-specifier array
Definition: train.py:58
auxiliaries
cached auxiliary-specifier array
Definition: train.py:64
Definition: train.py:1