30 argument_parser = utilities.DefaultHelpArgumentParser()
32 argument_parser.add_argument(
34 help=
"ROOT file containing the TTree of records on which to train a boosted decision tree.",
37 argument_parser.add_argument(
41 help=
"Name of the input TTree in the ROOT file",
44 argument_parser.add_argument(
47 default=argparse.SUPPRESS,
48 help=
"Database identifier or name of weight file to be generated",
51 argument_parser.add_argument(
56 help=
"Name of the column containing the truth information."
59 argument_parser.add_argument(
63 help=
"Name of the column containing the variables to be used."
66 argument_parser.add_argument(
68 "--variable_excludes",
71 help=
"Variables to be excluded"
74 argument_parser.add_argument(
78 help=
"MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
81 argument_parser.add_argument(
85 help=
"Evaluate the method after the training is finished"
88 argument_parser.add_argument(
92 help=
"Fill nan and inf values with actual numbers in evaluation"
95 arguments = argument_parser.parse_args()
97 records_file_path = arguments.records_file_path
98 treename = arguments.treename
99 feature_names = arguments.variables
101 excludes = arguments.variable_excludes
104 elif "truth" not in excludes:
105 excludes.append(
"truth")
107 print(
'excludes: ', excludes)
110 if feature_names
is None:
111 with root_utils.root_open(records_file_path)
as records_tfile:
112 input_tree = records_tfile.Get(treename)
113 feature_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
115 truth = arguments.truth
116 method = arguments.method
118 identifier = vars(arguments).get(
"identifier", method +
".weights.xml")
121 truth_free_variable_names = [name
for name
123 if name
not in excludes]
126 if "weight" in truth_free_variable_names:
127 truth_free_variable_names.remove(
"weight")
128 weight_variable =
"weight"
130 elif "__weight__" in truth_free_variable_names:
131 truth_free_variable_names.remove(
"__weight__")
132 weight_variable =
"__weight__"
139 "--datafiles", records_file_path,
140 "--treename", treename,
141 "--identifier", identifier,
142 "--target_variable", truth,
144 "--variables", *truth_free_variable_names,
145 "--weight_variable", weight_variable,
151 if arguments.evaluate:
152 evaluation_pdf = identifier.rsplit(
".", 1)[0] +
".pdf"
154 "basf2_mva_evaluate.py",
155 "--identifier", identifier,
156 "-d", records_file_path,
157 "--treename", treename,
160 if arguments.fillnan:
166 if __name__ ==
"__main__":
167 logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=
'%(levelname)s:%(message)s')
int main(int argc, char **argv)
Run all tests.