from basf2 import conditions
def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=[200, 8, 3, 0.1],  # nTrees, nCuts, nLevels, shrinkage; only read, never mutated
):
    """
    Custom wrapper for basf2 mva teacher.  Adapted from code in ``trackfindingcdc_teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
           that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
           Must not end in .xml nor in .root since the payload will be later downloaded to a local database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
           In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options, default: [200, 8, 3, 0.1] [nTrees, nCuts, nLevels, shrinkage]
    """
    # Register the local payload list with the basf2 conditions configuration.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    if exclude_variables is None:
        exclude_variables = []

    # Read the names of all recorded variables from the first records file only.
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # Training inputs: everything that is neither truth information, nor the
    # target itself, nor explicitly excluded by the caller.
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]

    # A recorded weight column is used as the event weight and must not be
    # used as a training feature at the same time.
    if "weight" in truth_free_variable_names:
        truth_free_variable_names.remove("weight")
        weight_variable = "weight"
    elif "__weight__" in truth_free_variable_names:
        truth_free_variable_names.remove("__weight__")
        weight_variable = "__weight__"
    else:
        # No weight column was recorded. Without this branch the assignment to
        # ``m_weight_variable`` below would fail with an UnboundLocalError.
        # NOTE(review): an empty name is assumed to mean "unweighted" for
        # basf2_mva -- confirm against the MVA package documentation.
        weight_variable = ""

    # Set options for the MVA training.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable

    # Positional layout of ``fast_bdt_option``: [nTrees, nCuts, nLevels, shrinkage].
    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    # Run the training with the options assembled above.
    basf2_mva.teacher(general_options, fastbdt_options)
def create_fbdt_option_string(fast_bdt_option):
    """
    Returns a readable string created by the ``fast_bdt_option`` array.

    The shrinkage (fourth entry) is encoded as an integer percentage, rounded
    to the nearest whole number, so that the string contains no decimal point.

    :param fast_bdt_option: List containing the FastBDT options that should be converted to a human readable string,
           indexed as [nTrees, nCuts, nLevels, shrinkage]
    :return: String of the form ``_nTrees<N>_nCuts<N>_nLevels<N>_shrin<N>``
    """
    shrinkage_percent = int(round(100 * fast_bdt_option[3], 0))
    return (
        f"_nTrees{fast_bdt_option[0]}_nCuts{fast_bdt_option[1]}"
        f"_nLevels{fast_bdt_option[2]}_shrin{shrinkage_percent}"
    )