Belle II Software  release-05-01-25
train.py
1 import basf2
2 import ROOT
3 
4 import numpy as np
5 
6 import tracking.validation.classification as classification
7 from tracking.root_utils import root_cd
8 
9 import logging
10 
11 
def get_logger():
    """Return the logger registered under this module's name."""
    module_logger = logging.getLogger(__name__)
    return module_logger
14 
15 
def get_truth_name(variable_names):
    """Select the shortest variable name containing the substring 'truth'.

    Args:
        variable_names (iterable of str): Candidate variable names.

    Returns:
        str: The shortest name containing 'truth'. If several names share the
            shortest length, the first of them in iteration order is returned.

    Raises:
        ValueError: If none of the names contains 'truth'.
    """
    truth_names = [name for name in variable_names if "truth" in name]

    if not truth_names:
        # Wrap the argument in a one-element tuple so that a tuple of names is
        # formatted as a whole instead of being unpacked by the % operator
        # (which would raise a TypeError and hide the intended ValueError).
        raise ValueError("variable_names='%s' does not contain a truth variable" % (variable_names,))

    # select the shortest
    return min(truth_names, key=len)
27 
28 
30  """Class to generate overview plots for the classification power of various variables from a TTree.
31 
32  Provides an overview of which variables from a diverse set, generated from a recording filter
33  or some other sort of validation, perform well in a classification task.
34 
35  """
36 
37  def __init__(self,
38  output_file_name,
39  truth_name=None,
40  select=[],
41  exclude=[],
42  groupbys=[],
43  auxiliaries=[],
44  filters=[]):
45  """Constructor"""
46 
47  self.output_file_name = output_file_name
48 
49  self.truth_name = truth_name
50 
51  self.select = select
52 
53  self.exclude = exclude
54 
55  self.groupbys = groupbys
56 
57  self.auxiliaries = auxiliaries
58 
59  self.filters = filters
60 
61 
62  self.classification_analyses = []
63 
64  def train(self, input_tree):
65  """Main method feed with a TTree containing the truth variable and the variables to be investigated.
66 
67  Branches that contain "truth" in the name are considered to directly contain information about
68  true classification target and are not analysed here.
69 
70  Args:
71  input_tree (ROOT.TTree) : Tree containing the variables to be investigated
72  as well as the classification target.
73  truth_name (str, optional) : Name of the branch of the classification target.
74  If not given the Branch with the shortest name containing "truth" is selected.
75  """
76 
77  if isinstance(self.output_file_name, str):
78  output_file = ROOT.TFile(self.output_file_name, "RECREATE")
79  else:
80  output_file = self.output_file_name
81 
82  output_file.cd()
83 
84  column_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]
85 
86  tree_name = input_tree.GetName()
87 
88  truth_name = self.truth_name
89 
90  if truth_name is None:
91  truth_name = get_truth_name(column_names)
92 
93  if truth_name not in column_names:
94  raise KeyError("Truth column {truth} not in tree {tree}".format(truth=truth_name,
95  tree=tree_name))
96  variable_names = [name for name in column_names if name != truth_name]
97 
98  exclude = self.exclude
99  select = self.select
100  groupbys = self.groupbys
101  auxiliaries = self.auxiliaries
102  filters = self.filters
103 
104  if select:
105  variable_names = [name for name in variable_names if name in select]
106 
107  if exclude:
108  variable_names = [name for name in variable_names if name not in exclude]
109 
110  if filters:
111  variable_names = [name for name in variable_names if name not in filters]
112 
113  # Remove the variables that have Monte Carlo truth information unless explicitly selected
114  variable_names = [name for name
115  in variable_names
116  if "truth" not in name or name in select]
117 
118  print("Truth name", truth_name)
119  print("Variable names", variable_names)
120 
121  import root_numpy
122  print("Loading tree")
123  branch_names = {*variable_names, truth_name, *groupbys, *auxiliaries, *filters}
124  branch_names = [name for name in branch_names if name]
125  input_array = root_numpy.tree2array(input_tree, branches=branch_names)
126  input_record_array = input_array.view(np.recarray)
127 
128  if filters:
129  for filter in filters:
130  filter_values = input_record_array[filter]
131  input_record_array = input_record_array[np.nonzero(filter_values)]
132 
133  print("Loaded tree")
134  truths = input_record_array[truth_name]
135 
136  if not groupbys:
137  groupbys = [None]
138 
139  for groupby in groupbys:
140  if groupby is None or groupby == "":
141  groupby_parts = [(None, slice(None))]
142  else:
143  groupby_parts = []
144  groupby_values = input_record_array[groupby]
145  unique_values, indices = np.unique(groupby_values, return_inverse=True)
146  for idx, value in enumerate(unique_values):
147  groupby_parts.append((value, indices == idx))
148 
149  for groupby_value, groupby_select in groupby_parts:
150  if groupby is None:
151  groupby_folder_name = '.'
152  else:
153  groupby_folder_name = "groupby_{name}_{value}".format(name=groupby, value=groupby_value)
154 
155  with root_cd(groupby_folder_name) as tdirectory:
156  for variable_name in variable_names:
157  print('Analyse', variable_name, 'groupby', groupby, '=', groupby_value)
158 
159  if variable_name == groupby:
160  continue
161 
162  # Get the truths as a numpy array
163  estimates = input_record_array[variable_name]
164  estimates[estimates == np.finfo(np.float32).max] = float("nan")
165  estimates[estimates == -np.finfo(np.float32).max] = -float("inf")
166  auxiliaries = {name: input_record_array[name][groupby_select] for name in self.auxiliaries}
167 
168  classification_analysis = classification.ClassificationAnalysis(
169  contact="",
170  quantity_name=variable_name,
171  outlier_z_score=5.0,
172  allow_discrete=True,
173  )
174  classification_analysis.analyse(
175  estimates[groupby_select],
176  truths[groupby_select],
177  auxiliaries=auxiliaries
178  )
179 
180  with root_cd(variable_name) as tdirectory:
181  classification_analysis.write(tdirectory)
182 
183  self.classification_analyses.append(classification_analysis)
184 
185  if isinstance(self.output_file_name, str):
186  output_file.Close()
187 
188  print("Saved overviews completely")
train.ClassificationOverview.auxiliaries
auxiliaries
cached auxiliary-specifier array
Definition: train.py:50
tracking.validation.classification
Definition: classification.py:1
train.ClassificationOverview
Definition: train.py:29
train.ClassificationOverview.truth_name
truth_name
cached truth name
Definition: train.py:42
train.ClassificationOverview.groupbys
groupbys
cached groupby-specifier array
Definition: train.py:48
train.ClassificationOverview.select
select
cached selection-specifier array
Definition: train.py:44
train.ClassificationOverview.classification_analyses
classification_analyses
array of classification analyses
Definition: train.py:55
train.ClassificationOverview.filters
filters
cached filter-specifier array
Definition: train.py:52
tracking.root_utils
Definition: root_utils.py:1
train.ClassificationOverview.__init__
def __init__(self, output_file_name, truth_name=None, select=[], exclude=[], groupbys=[], auxiliaries=[], filters=[])
Definition: train.py:37
Belle2::slice
std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
Slice the vector to contain only elements with indexes s .. e (included)
Definition: Splitter.h:89
train.ClassificationOverview.train
def train(self, input_tree)
Definition: train.py:64
train.ClassificationOverview.exclude
exclude
cached exclusion-specifier array
Definition: train.py:46
train.ClassificationOverview.output_file_name
output_file_name
cached output filename
Definition: train.py:40