Belle II Software  release-08-01-10
train.py
1 
8 
9 import ROOT
10 
11 import numpy as np
12 
13 import tracking.validation.classification as classification
14 from tracking.root_utils import root_cd
15 
16 import logging
17 
18 
def get_logger():
    """Return the logger belonging to this module (named after ``__name__``)."""
    module_logger = logging.getLogger(__name__)
    return module_logger
21 
22 
def get_truth_name(variable_names):
    """Select the shortest variable name containing 'truth'.

    Args:
        variable_names (iterable of str): Candidate variable / branch names.

    Returns:
        str: The shortest name that contains the substring "truth". If several
        names share the minimal length, the first one encountered is returned
        (behaviour of ``min``).

    Raises:
        ValueError: If none of the names contains "truth".
    """
    truth_names = [name for name in variable_names if "truth" in name]

    if not truth_names:
        # str.format instead of '%': with '%' a tuple argument would be unpacked
        # into the format arguments and raise a TypeError instead of this ValueError.
        raise ValueError("variable_names='{}' does not contain a truth variable".format(variable_names))

    # select the shortest
    return min(truth_names, key=len)
34 
35 
    """Class to generate overview plots for the classification power of various variables from a TTree.

    Use it to get an overview of which variables from a diverse set - for example generated by a
    recording filter or some other sort of validation - perform well in a classification task.

    Each call to train() runs one classification analysis per variable (and per groupby value),
    writes it into the output ROOT file and keeps the analysis objects on the instance.
    """
43 
44  def __init__(self,
45  output_file_name,
46  truth_name=None,
47  select=[],
48  exclude=[],
49  groupbys=[],
50  auxiliaries=[],
51  filters=[]):
52  """Constructor"""
53 
54  self.output_file_nameoutput_file_name = output_file_name
55 
56  self.truth_nametruth_name = truth_name
57 
58  self.selectselect = select
59 
60  self.excludeexclude = exclude
61 
62  self.groupbysgroupbys = groupbys
63 
64  self.auxiliariesauxiliaries = auxiliaries
65 
66  self.filtersfilters = filters
67 
68 
69  self.classification_analysesclassification_analyses = []
70 
71  def train(self, input_tree):
72  """Main method feed with a TTree containing the truth variable and the variables to be investigated.
73 
74  Branches that contain "truth" in the name are considered to directly contain information about
75  true classification target and are not analysed here.
76 
77  Args:
78  input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79  as well as the classification target.
80  truth_name (str, optional) : Name of the branch of the classification target.
81  If not given the Branch with the shortest name containing "truth" is selected.
82  """
83 
84  if isinstance(self.output_file_nameoutput_file_name, str):
85  output_file = ROOT.TFile(self.output_file_nameoutput_file_name, "RECREATE")
86  else:
87  output_file = self.output_file_nameoutput_file_name
88 
89  output_file.cd()
90 
91  column_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]
92 
93  tree_name = input_tree.GetName()
94 
95  truth_name = self.truth_nametruth_name
96 
97  if truth_name is None:
98  truth_name = get_truth_name(column_names)
99 
100  if truth_name not in column_names:
101  raise KeyError("Truth column {truth} not in tree {tree}".format(truth=truth_name,
102  tree=tree_name))
103  variable_names = [name for name in column_names if name != truth_name]
104 
105  exclude = self.excludeexclude
106  select = self.selectselect
107  groupbys = self.groupbysgroupbys
108  auxiliaries = self.auxiliariesauxiliaries
109  filters = self.filtersfilters
110 
111  if select:
112  variable_names = [name for name in variable_names if name in select]
113 
114  if exclude:
115  variable_names = [name for name in variable_names if name not in exclude]
116 
117  if filters:
118  variable_names = [name for name in variable_names if name not in filters]
119 
120  # Remove the variables that have Monte Carlo truth information unless explicitly selected
121  variable_names = [name for name
122  in variable_names
123  if "truth" not in name or name in select]
124 
125  print("Truth name", truth_name)
126  print("Variable names", variable_names)
127 
128  print("Loading tree")
129  branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
130  branch_names = [name for name in branch_names if name]
131  if isinstance(self.output_file_nameoutput_file_name, str):
132  rdf = ROOT.RDataFrame(input_tree, self.output_file_nameoutput_file_name)
133  else:
134  rdf = ROOT.RDataFrame(input_tree, self.output_file_nameoutput_file_name.GetName())
135  input_array = np.column_stack(list(rdf.AsNumpy(branch_names).values()))
136  input_record_array = input_array.view(np.recarray)
137 
138  if filters:
139  for filter in filters:
140  filter_values = input_record_array[filter]
141  input_record_array = input_record_array[np.nonzero(filter_values)]
142 
143  print("Loaded tree")
144  truths = input_record_array[truth_name]
145 
146  if not groupbys:
147  groupbys = [None]
148 
149  for groupby in groupbys:
150  if groupby is None or groupby == "":
151  groupby_parts = [(None, slice(None))]
152  else:
153  groupby_parts = []
154  groupby_values = input_record_array[groupby]
155  unique_values, indices = np.unique(groupby_values, return_inverse=True)
156  for idx, value in enumerate(unique_values):
157  groupby_parts.append((value, indices == idx))
158 
159  for groupby_value, groupby_select in groupby_parts:
160  if groupby is None:
161  groupby_folder_name = '.'
162  else:
163  groupby_folder_name = "groupby_{name}_{value}".format(name=groupby, value=groupby_value)
164 
165  with root_cd(groupby_folder_name) as tdirectory:
166  for variable_name in variable_names:
167  print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)
168 
169  if variable_name == groupby:
170  continue
171 
172  # Get the truths as a numpy array
173  estimates = input_record_array[variable_name]
174  estimates[estimates == np.finfo(np.float32).max] = float("nan")
175  estimates[estimates == -np.finfo(np.float32).max] = -float("inf")
176  auxiliaries = {name: input_record_array[name][groupby_select] for name in self.auxiliariesauxiliaries}
177 
178  classification_analysis = classification.ClassificationAnalysis(
179  contact="",
180  quantity_name=variable_name,
181  outlier_z_score=5.0,
182  allow_discrete=True,
183  )
184  classification_analysis.analyse(
185  estimates[groupby_select],
186  truths[groupby_select],
187  auxiliaries=auxiliaries
188  )
189 
190  with root_cd(variable_name) as tdirectory:
191  classification_analysis.write(tdirectory)
192 
193  self.classification_analysesclassification_analyses.append(classification_analysis)
194 
195  if isinstance(self.output_file_nameoutput_file_name, str):
196  output_file.Close()
197 
198  print("Saved overviews completely")
truth_name
cached truth name
Definition: train.py:56
def train(self, input_tree)
Definition: train.py:71
classification_analyses
array of classification analyses
Definition: train.py:69
groupbys
cached groupby-specifier array
Definition: train.py:62
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
Definition: train.py:51
output_file_name
cached output filename
Definition: train.py:54
filters
cached filter-specifier array
Definition: train.py:66
exclude
cached exclusion-specifier array
Definition: train.py:60
select
cached selection-specifier array
Definition: train.py:58
auxiliaries
cached auxiliary-specifier array
Definition: train.py:64
std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
Slice the vector to contain only elements with indexes s .. e (included)
Definition: Splitter.h:85