Belle II Software  release-05-01-25
VXDQE_teacher.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 
13 import sys
14 import subprocess
15 import argparse
16 
17 import tracking.utilities as utilities
18 import tracking.root_utils as root_utils
19 
20 import logging
21 
22 
def _run_or_exit(cmd):
    """Print ``cmd``, run it, and terminate the script if it returns a non-zero status."""
    print(cmd)
    return_code = subprocess.call(cmd)
    if return_code != 0:
        logging.error("%s failed with exit code %d", cmd[0], return_code)
        # A negative code means the child was killed by a signal; still report failure.
        sys.exit(return_code if return_code > 0 else 1)


def main():
    """Train an MVA classifier on a TTree of records with ``basf2_mva_teacher``.

    Command line (see ``--help``):
      * ``records_file_path``: ROOT file holding the training TTree.
      * ``-r/--treename``: name of that TTree (default ``tree``).
      * ``-i/--identifier``: database identifier or weight file name; defaults
        to ``<method>.weights.xml``.
      * ``-t/--truth``: target column (default ``truth``).
      * ``--variables``: feature columns; defaults to all leaves of the tree.
      * ``-x/--variable_excludes``: columns to drop from the features.
      * ``--method``: MVA method passed to ``basf2_mva_teacher`` (default ``FastBDT``).
      * ``-e/--evaluate``: afterwards run ``basf2_mva_evaluate.py`` and write
        ``<identifier without extension>.pdf``.
      * ``-n/--fillnan``: forward ``-n`` to the evaluation.

    The selected target column and the literal ``truth`` column are never used
    as features. A ``weight`` column (or, if absent, ``__weight__``) is passed
    as ``--weight_variable`` instead of as a feature.

    Exits the process with a non-zero status if the input tree cannot be found
    or if training or evaluation fails.
    """
    import os

    argument_parser = utilities.DefaultHelpArgumentParser()

    argument_parser.add_argument(
        "records_file_path",
        help="ROOT file containing the TTree of records on which to train a boosted decision tree.",
    )

    argument_parser.add_argument(
        "-r",
        "--treename",
        default="tree",
        help="Name of the input TTree in the ROOT file",
    )

    argument_parser.add_argument(
        "-i",
        "--identifier",
        default=argparse.SUPPRESS,
        help="Database identifier or name of weight file to be generated",
    )

    argument_parser.add_argument(
        "-t",
        "--truth",
        type=str,
        default="truth",
        help="Name of the column containing the truth information."
    )

    argument_parser.add_argument(
        "--variables",
        default=None,
        nargs='+',
        help="Name of the column containing the variables to be used."
    )

    argument_parser.add_argument(
        "-x",
        "--variable_excludes",
        default=None,
        nargs='+',
        help="Variables to be excluded"
    )

    argument_parser.add_argument(
        "--method",
        type=str,
        default="FastBDT",
        help="MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
    )

    argument_parser.add_argument(
        "-e",
        "--evaluate",
        action="store_true",
        help="Evaluate the method after the training is finished"
    )

    argument_parser.add_argument(
        "-n",
        "--fillnan",
        action="store_true",
        help="Fill nan and inf values with actual numbers in evaluation"
    )

    arguments = argument_parser.parse_args()

    records_file_path = arguments.records_file_path
    treename = arguments.treename
    truth = arguments.truth
    method = arguments.method

    # Columns that must never be used as features. The literal "truth" column is
    # always excluded (as before); the selected target column is excluded as well
    # so that a custom --truth name does not leak the label into the features.
    excludes = list(arguments.variable_excludes or [])
    for excluded_name in ("truth", truth):
        if excluded_name not in excludes:
            excludes.append(excluded_name)

    print('excludes: ', excludes)

    # Figure out feature variables: default to every leaf of the input tree.
    feature_names = arguments.variables
    if feature_names is None:
        with root_utils.root_open(records_file_path) as records_tfile:
            input_tree = records_tfile.Get(treename)
            # Get() yields a null (falsy) object when the key is missing.
            if not input_tree:
                sys.exit("Could not find TTree {!r} in {!r}".format(treename, records_file_path))
            feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    identifier = vars(arguments).get("identifier", method + ".weights.xml")

    # Drop excluded columns (including the target). This also applies to an
    # explicitly given --variables list.
    truth_free_variable_names = [name for name in feature_names if name not in excludes]

    # A weight column is used as per-event weight rather than as a feature;
    # "weight" takes precedence over "__weight__". Empty string means no weight.
    weight_variable = ""
    for candidate in ("weight", "__weight__"):
        if candidate in truth_free_variable_names:
            truth_free_variable_names.remove(candidate)
            weight_variable = candidate
            break

    _run_or_exit([
        "basf2_mva_teacher",
        "--datafiles", records_file_path,
        "--treename", treename,
        "--identifier", identifier,
        "--target_variable", truth,
        "--method", method,
        "--variables", *truth_free_variable_names,
        "--weight_variable", weight_variable,
    ])

    if arguments.evaluate:
        # splitext only strips a real extension, also when directories contain dots.
        evaluation_pdf = os.path.splitext(identifier)[0] + ".pdf"
        evaluate_cmd = [
            "basf2_mva_evaluate.py",
            "--identifier", identifier,
            "-d", records_file_path,
            "--treename", treename,
            "-o", evaluation_pdf
        ]
        if arguments.fillnan:
            evaluate_cmd.append("-n")
        _run_or_exit(evaluate_cmd)
158 
159 
if __name__ == "__main__":
    # Script entry point: send log records to stdout at INFO level, then train.
    log_format = '%(levelname)s:%(message)s'
    logging.basicConfig(
        format=log_format,
        level=logging.INFO,
        stream=sys.stdout,
    )
    main()
tracking.utilities
Definition: utilities.py:1
main
def main()
Parse the command line, train an MVA method with basf2_mva_teacher and optionally evaluate it with basf2_mva_evaluate.py.
Definition: VXDQE_teacher.py:23
tracking.root_utils
Definition: root_utils.py:1