11 from pathlib
import Path
def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=[200, 8, 3, 0.1],  # only read, never mutated in here
):
    """
    Custom wrapper for basf2 mva teacher.  Adapted from code in ``trackfindingcdc_teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
        Must contain at least one file, since the variable names are read from the first one.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
        that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
        Should either end in ".xml" for local weightfiles or in ".root", when
        the weightfile needs later to be uploaded as a payload to the conditions
        database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
        In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options, default: [200, 8, 3, 0.1] [nTrees, nCuts, nLevels, shrinkage]
    :raises ValueError: If ``weightfile_identifier`` ends in neither ".xml" nor ".root".
    """
    if exclude_variables is None:
        exclude_variables = []

    # Validate the identifier before any file is opened, so a typo fails fast.
    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    # Extract the names of all variables from the first records file only.
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leave.GetName() for leave in input_tree.GetListOfLeaves()]

    # Training variables: everything except truth-like columns, the target
    # itself and whatever the caller explicitly excluded.
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]

    # A weight column is not a training feature: take the first matching name
    # ("weight" has priority over "__weight__") out of the feature list and
    # hand it to the teacher as weight variable instead.
    # Without the fallback below, ``weight_variable`` would be unbound whenever
    # the tree holds neither column.
    # NOTE(review): "" is assumed to mean "no per-event weights" for
    # ``m_weight_variable`` -- confirm against the basf2 mva GeneralOptions.
    weight_variable = ""
    for weight_candidate in ("weight", "__weight__"):
        if weight_candidate in truth_free_variable_names:
            truth_free_variable_names.remove(weight_candidate)
            weight_variable = weight_candidate
            break

    # Set the general options for the MVA training.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable

    # Map the positional option list onto the FastBDT options:
    # [nTrees, nCuts, nLevels, shrinkage]
    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    # Train the MVA; the weightfile is created under ``weightfile_identifier``.
    basf2_mva.teacher(general_options, fastbdt_options)
def create_fbdt_option_string(fast_bdt_option):
    """
    Returns a readable string created by the ``fast_bdt_option`` array.

    :param fast_bdt_option: List containing the FastBDT options that should be converted to a human readable string,
        indexed as [nTrees, nCuts, nLevels, shrinkage]
    :return: String of the form ``_nTrees<a>_nCuts<b>_nLevels<c>_shrin<d>``, where ``<d>`` is
        the shrinkage multiplied by 100 and rounded to an integer.
    """
    # The shrinkage is written as an integer percentage (e.g. 0.1 -> 10).
    shrinkage_percent = int(round(100 * fast_bdt_option[3], 0))
    return (
        f"_nTrees{fast_bdt_option[0]}_nCuts{fast_bdt_option[1]}"
        f"_nLevels{fast_bdt_option[2]}_shrin{shrinkage_percent}"
    )