Belle II Software  release-08-01-10
VXDQE_teacher.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 
11 
12 
17 
18 
19 import sys
20 import subprocess
21 import argparse
22 
23 import tracking.utilities as utilities
24 import tracking.root_utils as root_utils
25 
26 import logging
27 
28 
def _collect_excludes(user_excludes, truth):
    """Return the list of column names that must not be used as features.

    Args:
        user_excludes: Names given with ``--variable_excludes``, or ``None``.
        truth: Name of the target column given with ``--truth``.

    Returns:
        A new list holding the user's excludes followed by ``"truth"`` and the
        selected target column (each added only if not already present).
    """
    excludes = list(user_excludes) if user_excludes is not None else []
    # "truth" is always excluded (historic behaviour). The column actually
    # selected as the target must be excluded as well, otherwise the target
    # would be handed to the teacher as an input feature.
    for name in ("truth", truth):
        if name not in excludes:
            excludes.append(name)
    return excludes


def _split_weight_variable(variable_names):
    """Separate the weight column from the feature columns.

    Args:
        variable_names: Candidate feature names.

    Returns:
        Tuple ``(features, weight_variable)``. ``weight_variable`` is
        ``"weight"`` or ``"__weight__"`` when such a column is present (checked
        in that order) and ``""`` otherwise; the chosen column is removed from
        ``features``.
    """
    features = list(variable_names)
    for candidate in ("weight", "__weight__"):
        if candidate in features:
            features.remove(candidate)
            return features, candidate
    return features, ""


def _run_command(cmd):
    """Print and run ``cmd``; terminate the script if the command fails.

    Args:
        cmd: Command and arguments as a list of strings.

    Raises:
        SystemExit: With the child's exit code if it is non-zero.
    """
    print(cmd)
    returncode = subprocess.call(cmd)
    if returncode != 0:
        logging.error("Command '%s' failed with exit code %d", cmd[0], returncode)
        sys.exit(returncode)


def main():
    """Train an MVA method on a ROOT file of records via ``basf2_mva_teacher``.

    Parses the command line, determines the feature columns (all leaves of the
    input tree unless ``--variables`` is given), drops the excluded columns and
    the target column, picks up an optional weight column and invokes
    ``basf2_mva_teacher``. With ``--evaluate`` the result is afterwards passed
    to ``basf2_mva_evaluate.py``.

    Raises:
        SystemExit: If the training or the evaluation command returns a
            non-zero exit code.
    """
    argument_parser = utilities.DefaultHelpArgumentParser()

    argument_parser.add_argument(
        "records_file_path",
        help="ROOT file containing the TTree of records on which to train a boosted decision tree.",
    )

    argument_parser.add_argument(
        "-r",
        "--treename",
        default="tree",
        help="Name of the input TTree in the ROOT file",
    )

    argument_parser.add_argument(
        "-i",
        "--identifier",
        # SUPPRESS keeps the attribute absent when the option is not given, so
        # the default can be derived from the chosen method further below.
        default=argparse.SUPPRESS,
        help="Database identifier or name of weight file to be generated",
    )

    argument_parser.add_argument(
        "-t",
        "--truth",
        type=str,
        default="truth",
        help="Name of the column containing the truth information."
    )

    argument_parser.add_argument(
        "--variables",
        default=None,
        nargs='+',
        help="Name of the column containing the variables to be used."
    )

    argument_parser.add_argument(
        "-x",
        "--variable_excludes",
        default=None,
        nargs='+',
        help="Variables to be excluded"
    )

    argument_parser.add_argument(
        "--method",
        type=str,
        default="FastBDT",
        help="MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
    )

    argument_parser.add_argument(
        "-e",
        "--evaluate",
        action="store_true",
        help="Evaluate the method after the training is finished"
    )

    argument_parser.add_argument(
        "-n",
        "--fillnan",
        action="store_true",
        help="Fill nan and inf values with actual numbers in evaluation"
    )

    arguments = argument_parser.parse_args()

    records_file_path = arguments.records_file_path
    treename = arguments.treename
    feature_names = arguments.variables
    truth = arguments.truth
    method = arguments.method

    excludes = _collect_excludes(arguments.variable_excludes, truth)
    print('excludes: ', excludes)

    # Figure out feature variables: default to every leaf of the input tree
    if feature_names is None:
        with root_utils.root_open(records_file_path) as records_tfile:
            input_tree = records_tfile.Get(treename)
            feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    identifier = vars(arguments).get("identifier", method + ".weights.xml")

    # Remove the excluded columns, which always include the truth information
    truth_free_variable_names = [name for name in feature_names if name not in excludes]

    # A column called "weight" or "__weight__" is used as event weight, not as feature
    truth_free_variable_names, weight_variable = _split_weight_variable(truth_free_variable_names)

    cmd = [
        "basf2_mva_teacher",
        "--datafiles", records_file_path,
        "--treename", treename,
        "--identifier", identifier,
        "--target_variable", truth,
        "--method", method,
        "--variables", *truth_free_variable_names,
        "--weight_variable", weight_variable,
    ]
    # Stop here if the training failed instead of evaluating a missing weight file
    _run_command(cmd)

    if arguments.evaluate:
        # Name the report after the identifier with its last extension replaced by .pdf
        evaluation_pdf = identifier.rsplit(".", 1)[0] + ".pdf"
        cmd = [
            "basf2_mva_evaluate.py",
            "--identifier", identifier,
            "-d", records_file_path,
            "--treename", treename,
            "-o", evaluation_pdf
        ]
        if arguments.fillnan:
            cmd.append("-n")
        _run_command(cmd)
164 
165 
if __name__ == "__main__":
    # Send INFO-level (and above) log records to stdout as "LEVEL:message"
    # before handing control to the command line entry point.
    logging.basicConfig(
        format='%(levelname)s:%(message)s',
        level=logging.INFO,
        stream=sys.stdout,
    )
    main()
Definition: main.py:1
int main(int argc, char **argv)
Run all tests.
Definition: test_main.cc:91