Belle II Software  release-06-00-14
train.py
1 
8 
9 import ROOT
10 
11 import numpy as np
12 
13 import tracking.validation.classification as classification
14 from tracking.root_utils import root_cd
15 
16 import logging
17 
18 
def get_logger():
    """Return the logger registered under this module's name."""
    module_logger = logging.getLogger(__name__)
    return module_logger
21 
22 
def get_truth_name(variable_names):
    """Select the shortest variable name containing the substring 'truth'.

    Args:
        variable_names (iterable of str): Candidate variable names.

    Returns:
        str: The shortest name that contains "truth". If several names share
            the shortest length, the first of them in iteration order is returned.

    Raises:
        ValueError: If none of the names contains "truth".
    """
    truth_names = [name for name in variable_names if "truth" in name]

    if not truth_names:
        # Wrap the argument in a one-element tuple: a bare tuple on the right-hand side
        # of % would be unpacked as format arguments and raise TypeError instead.
        raise ValueError("variable_names='%s' does not contain a truth variable" % (variable_names,))

    # select the shortest
    return min(truth_names, key=len)
34 
35 
37  """Class to generate overview plots for the classification power of various variables from a TTree.
38 
39  Gives an overview of which variables from a diverse set, generated by a recording filter
40  or some other sort of validation, perform well in a classification task.
41 
42  """
43 
44  def __init__(self,
45  output_file_name,
46  truth_name=None,
47  select=[],
48  exclude=[],
49  groupbys=[],
50  auxiliaries=[],
51  filters=[]):
52  """Constructor"""
53 
54  self.output_file_nameoutput_file_name = output_file_name
55 
56  self.truth_nametruth_name = truth_name
57 
58  self.selectselect = select
59 
60  self.excludeexclude = exclude
61 
62  self.groupbysgroupbys = groupbys
63 
64  self.auxiliariesauxiliaries = auxiliaries
65 
66  self.filtersfilters = filters
67 
68 
69  self.classification_analysesclassification_analyses = []
70 
71  def train(self, input_tree):
72  """Main method feed with a TTree containing the truth variable and the variables to be investigated.
73 
74  Branches that contain "truth" in the name are considered to directly contain information about
75  true classification target and are not analysed here.
76 
77  Args:
78  input_tree (ROOT.TTree) : Tree containing the variables to be investigated
79  as well as the classification target.
80  truth_name (str, optional) : Name of the branch of the classification target.
81  If not given the Branch with the shortest name containing "truth" is selected.
82  """
83 
84  if isinstance(self.output_file_nameoutput_file_name, str):
85  output_file = ROOT.TFile(self.output_file_nameoutput_file_name, "RECREATE")
86  else:
87  output_file = self.output_file_nameoutput_file_name
88 
89  output_file.cd()
90 
91  column_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]
92 
93  tree_name = input_tree.GetName()
94 
95  truth_name = self.truth_nametruth_name
96 
97  if truth_name is None:
98  truth_name = get_truth_name(column_names)
99 
100  if truth_name not in column_names:
101  raise KeyError("Truth column {truth} not in tree {tree}".format(truth=truth_name,
102  tree=tree_name))
103  variable_names = [name for name in column_names if name != truth_name]
104 
105  exclude = self.excludeexclude
106  select = self.selectselect
107  groupbys = self.groupbysgroupbys
108  auxiliaries = self.auxiliariesauxiliaries
109  filters = self.filtersfilters
110 
111  if select:
112  variable_names = [name for name in variable_names if name in select]
113 
114  if exclude:
115  variable_names = [name for name in variable_names if name not in exclude]
116 
117  if filters:
118  variable_names = [name for name in variable_names if name not in filters]
119 
120  # Remove the variables that have Monte Carlo truth information unless explicitly selected
121  variable_names = [name for name
122  in variable_names
123  if "truth" not in name or name in select]
124 
125  print("Truth name", truth_name)
126  print("Variable names", variable_names)
127 
128  import root_numpy
129  print("Loading tree")
130  branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
131  branch_names = [name for name in branch_names if name]
132  input_array = root_numpy.tree2array(input_tree, branches=branch_names)
133  input_record_array = input_array.view(np.recarray)
134 
135  if filters:
136  for filter in filters:
137  filter_values = input_record_array[filter]
138  input_record_array = input_record_array[np.nonzero(filter_values)]
139 
140  print("Loaded tree")
141  truths = input_record_array[truth_name]
142 
143  if not groupbys:
144  groupbys = [None]
145 
146  for groupby in groupbys:
147  if groupby is None or groupby == "":
148  groupby_parts = [(None, slice(None))]
149  else:
150  groupby_parts = []
151  groupby_values = input_record_array[groupby]
152  unique_values, indices = np.unique(groupby_values, return_inverse=True)
153  for idx, value in enumerate(unique_values):
154  groupby_parts.append((value, indices == idx))
155 
156  for groupby_value, groupby_select in groupby_parts:
157  if groupby is None:
158  groupby_folder_name = '.'
159  else:
160  groupby_folder_name = "groupby_{name}_{value}".format(name=groupby, value=groupby_value)
161 
162  with root_cd(groupby_folder_name) as tdirectory:
163  for variable_name in variable_names:
164  print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)
165 
166  if variable_name == groupby:
167  continue
168 
169  # Get the truths as a numpy array
170  estimates = input_record_array[variable_name]
171  estimates[estimates == np.finfo(np.float32).max] = float("nan")
172  estimates[estimates == -np.finfo(np.float32).max] = -float("inf")
173  auxiliaries = {name: input_record_array[name][groupby_select] for name in self.auxiliariesauxiliaries}
174 
175  classification_analysis = classification.ClassificationAnalysis(
176  contact="",
177  quantity_name=variable_name,
178  outlier_z_score=5.0,
179  allow_discrete=True,
180  )
181  classification_analysis.analyse(
182  estimates[groupby_select],
183  truths[groupby_select],
184  auxiliaries=auxiliaries
185  )
186 
187  with root_cd(variable_name) as tdirectory:
188  classification_analysis.write(tdirectory)
189 
190  self.classification_analysesclassification_analyses.append(classification_analysis)
191 
192  if isinstance(self.output_file_nameoutput_file_name, str):
193  output_file.Close()
194 
195  print("Saved overviews completely")
truth_name
cached truth name
Definition: train.py:56
def train(self, input_tree)
Definition: train.py:71
classification_analyses
array of classification analyses
Definition: train.py:69
groupbys
cached groupby-specifier array
Definition: train.py:62
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
Definition: train.py:51
output_file_name
cached output filename
Definition: train.py:54
filters
cached filter-specifier array
Definition: train.py:66
exclude
cached exclusion-specifier array
Definition: train.py:60
select
cached selection-specifier array
Definition: train.py:58
auxiliaries
cached auxiliary-specifier array
Definition: train.py:64
std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
Slice the vector to contain only elements with indexes s .. e (included)
Definition: Splitter.h:85