Belle II Software development
VXDQE_teacher.py
1#!/usr/bin/env python3
2
3
10
11
16
17
18import sys
19import subprocess
20import argparse
21
22import tracking.utilities as utilities
23import tracking.root_utils as root_utils
24
25import logging
26
27
def _create_argument_parser():
    """Build the command line parser for this training script."""
    argument_parser = utilities.DefaultHelpArgumentParser()

    argument_parser.add_argument(
        "records_file_path",
        help="ROOT file containing the TTree of records on which to train a boosted decision tree.",
    )

    argument_parser.add_argument(
        "-r",
        "--treename",
        default="tree",
        help="Name of the input TTree in the ROOT file",
    )

    # SUPPRESS: the attribute is absent unless given, so main() can derive a
    # default from the chosen method.
    argument_parser.add_argument(
        "-i",
        "--identifier",
        default=argparse.SUPPRESS,
        help="Database identifier or name of weight file to be generated",
    )

    argument_parser.add_argument(
        "-t",
        "--truth",
        type=str,
        default="truth",
        help="Name of the column containing the truth information."
    )

    argument_parser.add_argument(
        "--variables",
        default=None,
        nargs='+',
        help="Name of the column containing the variables to be used."
    )

    argument_parser.add_argument(
        "-x",
        "--variable_excludes",
        default=None,
        nargs='+',
        help="Variables to be excluded"
    )

    argument_parser.add_argument(
        "--method",
        type=str,
        default="FastBDT",
        help="MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
    )

    argument_parser.add_argument(
        "-e",
        "--evaluate",
        action="store_true",
        help="Evaluate the method after the training is finished"
    )

    argument_parser.add_argument(
        "-n",
        "--fillnan",
        action="store_true",
        help="Fill nan and inf values with actual numbers in evaluation"
    )

    return argument_parser


def _read_leaf_names(records_file_path, treename, argument_parser):
    """Return the names of all leaves of TTree *treename* in ROOT file *records_file_path*.

    Reports a command line error through *argument_parser* (which exits the
    program) when the tree cannot be retrieved from the file.
    """
    with root_utils.root_open(records_file_path) as records_tfile:
        input_tree = records_tfile.Get(treename)
        # A missing key yields a null object. Calling methods on it fails with an
        # obscure error, so report a clear one instead.
        # NOTE(review): relies on PyROOT null objects evaluating as False.
        if not input_tree:
            argument_parser.error(
                "Could not read TTree '%s' from file '%s'" % (treename, records_file_path)
            )
        return [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]


def main():
    """Train an MVA method on a TTree of records by running ``basf2_mva_teacher``.

    The feature variables are either given with ``--variables`` or default to
    all leaves of the input tree. Excluded variables and the truth (target)
    column are removed from them. A column named ``weight`` (or, failing that,
    ``__weight__``) is used as the weight variable instead of as a feature.
    With ``--evaluate``, ``basf2_mva_evaluate.py`` is run afterwards, but only
    if the training succeeded.

    Returns:
        int: exit code of the last external command that was run (0 on success).
    """
    argument_parser = _create_argument_parser()
    arguments = argument_parser.parse_args()

    records_file_path = arguments.records_file_path
    treename = arguments.treename
    truth = arguments.truth
    method = arguments.method
    # --identifier is suppressed when not given; default to a weight file named after the method.
    identifier = vars(arguments).get("identifier", method + ".weights.xml")

    # Copy so the parsed arguments are not mutated. The literal "truth" column is
    # always excluded, as before. The configured target column must be excluded
    # too; otherwise the classifier would be trained on its own label.
    excludes = list(arguments.variable_excludes or [])
    for target_name in ("truth", truth):
        if target_name not in excludes:
            excludes.append(target_name)

    print('excludes: ', excludes)

    # Figure out feature variables
    feature_names = arguments.variables
    if feature_names is None:
        feature_names = _read_leaf_names(records_file_path, treename, argument_parser)

    # Remove the variables that have Monte Carlo truth information unless explicitly selected
    truth_free_variable_names = [name for name in feature_names if name not in excludes]

    # The first weight-like column found is used as event weight, not as a feature.
    weight_variable = ""
    for weight_candidate in ("weight", "__weight__"):
        if weight_candidate in truth_free_variable_names:
            truth_free_variable_names.remove(weight_candidate)
            weight_variable = weight_candidate
            break

    cmd = [
        "basf2_mva_teacher",
        "--datafiles", records_file_path,
        "--treename", treename,
        "--identifier", identifier,
        "--target_variable", truth,
        "--method", method,
        "--variables", *truth_free_variable_names,
        "--weight_variable", weight_variable,
    ]

    print(cmd)
    return_code = subprocess.call(cmd)
    if return_code != 0:
        # Evaluating a weight file that was never written would only produce
        # confusing follow-up errors, so stop here.
        logging.error("basf2_mva_teacher failed with exit code %d", return_code)
        return return_code

    if arguments.evaluate:
        # "FastBDT.weights.xml" -> "FastBDT.weights.pdf"; identifiers without a dot get ".pdf" appended.
        evaluation_pdf = identifier.rsplit(".", 1)[0] + ".pdf"
        cmd = [
            "basf2_mva_evaluate.py",
            "--identifier", identifier,
            "-d", records_file_path,
            "--treename", treename,
            "-o", evaluation_pdf
        ]
        if arguments.fillnan:
            cmd.append("-n")
        print(cmd)
        return_code = subprocess.call(cmd)

    return return_code
164
if __name__ == "__main__":
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(levelname)s:%(message)s')
    # Propagate the exit code of the external MVA tools to the calling shell,
    # so batch systems and scripts can detect a failed training.
    # sys.exit(None) exits with status 0.
    sys.exit(main())
Definition: main.py:1