24 argument_parser = utilities.DefaultHelpArgumentParser()
26 argument_parser.add_argument(
28 help=
"ROOT file containing the TTree of records on which to train a boosted decision tree.",
31 argument_parser.add_argument(
35 help=
"Name of the input TTree in the ROOT file",
38 argument_parser.add_argument(
41 default=argparse.SUPPRESS,
42 help=
"Database identifier or name of weight file to be generated",
45 argument_parser.add_argument(
50 help=
"Name of the column containing the truth information."
53 argument_parser.add_argument(
57 help=
"Name of the column containing the variables to be used."
60 argument_parser.add_argument(
62 "--variable_excludes",
65 help=
"Variables to be excluded"
68 argument_parser.add_argument(
72 help=
"MVA Method [FastBDT], not implemented: [NeuroBayes|TMVA|XGBoost|Theano|Tensorflow|FANN|SKLearn]"
75 argument_parser.add_argument(
79 help=
"Evaluate the method after the training is finished"
82 argument_parser.add_argument(
86 help=
"Fill nan and inf values with actual numbers in evaluation"
89 arguments = argument_parser.parse_args()
91 records_file_path = arguments.records_file_path
92 treename = arguments.treename
93 feature_names = arguments.variables
95 excludes = arguments.variable_excludes
98 elif "truth" not in excludes:
99 excludes.append(
"truth")
101 print(
'excludes: ', excludes)
104 if feature_names
is None:
105 with root_utils.root_open(records_file_path)
as records_tfile:
106 input_tree = records_tfile.Get(treename)
107 feature_names = [leave.GetName()
for leave
in input_tree.GetListOfLeaves()]
109 truth = arguments.truth
110 method = arguments.method
112 identifier = vars(arguments).get(
"identifier", method +
".weights.xml")
115 truth_free_variable_names = [name
for name
117 if name
not in excludes]
120 if "weight" in truth_free_variable_names:
121 truth_free_variable_names.remove(
"weight")
122 weight_variable =
"weight"
124 elif "__weight__" in truth_free_variable_names:
125 truth_free_variable_names.remove(
"__weight__")
126 weight_variable =
"__weight__"
133 "--datafiles", records_file_path,
134 "--treename", treename,
135 "--identifier", identifier,
136 "--target_variable", truth,
138 "--variables", *truth_free_variable_names,
139 "--weight_variable", weight_variable,
145 if arguments.evaluate:
146 evaluation_pdf = identifier.rsplit(
".", 1)[0] +
".pdf"
148 "basf2_mva_evaluate.py",
149 "--identifier", identifier,
150 "-d", records_file_path,
151 "--treename", treename,
154 if arguments.fillnan:
160 if __name__ ==
"__main__":
161 logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=
'%(levelname)s:%(message)s')