29 argument_parser = utilities.DefaultHelpArgumentParser()
31 argument_parser.add_argument(
33 help=
"ROOT file containing the TTree of records on which to train a boosted decision tree.",
36 argument_parser.add_argument(
40 help=
"Name of the input TTree in the ROOT file",
43 argument_parser.add_argument(
46 default=argparse.SUPPRESS,
47 help=
"Database identifier or name of weight file to be generated",
50 argument_parser.add_argument(
55 help=
"Name of the column containing the truth information."
58 argument_parser.add_argument(
62 help=
"Name of the column containing the variables to be used."
65 argument_parser.add_argument(
67 "--variable_excludes",
70 help=
"Variables to be excluded"
73 argument_parser.add_argument(
77 help=
"MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
80 argument_parser.add_argument(
84 help=
"Evaluate the method after the training is finished"
87 argument_parser.add_argument(
91 help=
"Fill nan and inf values with actual numbers in evaluation"
94 arguments = argument_parser.parse_args()
96 records_file_path = arguments.records_file_path
97 treename = arguments.treename
98 feature_names = arguments.variables
100 excludes = arguments.variable_excludes
103 elif "truth" not in excludes:
104 excludes.append(
"truth")
106 print(
'excludes: ', excludes)
109 if feature_names
is None:
110 with root_utils.root_open(records_file_path)
as records_tfile:
111 input_tree = records_tfile.Get(treename)
112 feature_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
114 truth = arguments.truth
115 method = arguments.method
117 identifier = vars(arguments).get(
"identifier", method +
".weights.xml")
120 truth_free_variable_names = [name
for name
122 if name
not in excludes]
125 if "weight" in truth_free_variable_names:
126 truth_free_variable_names.remove(
"weight")
127 weight_variable =
"weight"
129 elif "__weight__" in truth_free_variable_names:
130 truth_free_variable_names.remove(
"__weight__")
131 weight_variable =
"__weight__"
138 "--datafiles", records_file_path,
139 "--treename", treename,
140 "--identifier", identifier,
141 "--target_variable", truth,
143 "--variables", *truth_free_variable_names,
144 "--weight_variable", weight_variable,
150 if arguments.evaluate:
151 evaluation_pdf = identifier.rsplit(
".", 1)[0] +
".pdf"
153 "basf2_mva_evaluate.py",
154 "--identifier", identifier,
155 "-d", records_file_path,
156 "--treename", treename,
159 if arguments.fillnan:
165if __name__ ==
"__main__":
166 logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=
'%(levelname)s:%(message)s')