Belle II Software development
__init__.py
1#!/usr/bin/env python3
2
3
10
11import basf2_mva
12import tracking.root_utils as root_utils
13from basf2 import conditions
14
15
def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=None,
):
    """
    Custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    The names of the training variables are read from the first file in ``records_files`` only.
    If a variable called ``weight`` or ``__weight__`` remains after the exclusions described below,
    it is removed from the list of training variables and used as the weight variable instead.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
           that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
           Must not end in .xml nor in .root since the payload will be later downloaded to a local database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
           In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: specified fast BDT options [nTrees, nCuts, nLevels, shrinkage].
           If ``None`` (the default), ``[200, 8, 3, 0.1]`` is used.
    """
    # Register the local database file as testing payloads before the training is started.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    if exclude_variables is None:
        exclude_variables = []
    # Resolve the default here instead of in the signature to avoid a shared mutable default argument.
    if fast_bdt_option is None:
        fast_bdt_option = [200, 8, 3, 0.1]

    # extract names of all variables from one record file
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # get list of variables to use for training without MC truth
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]

    # A weight column is not a training feature: take it out of the variable list and pass it separately.
    if "weight" in truth_free_variable_names:
        truth_free_variable_names.remove("weight")
        weight_variable = "weight"
    elif "__weight__" in truth_free_variable_names:
        truth_free_variable_names.remove("__weight__")
        weight_variable = "__weight__"
    else:
        weight_variable = ""

    # Set options for MVA training
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable

    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    # Train a MVA method.
    basf2_mva.teacher(general_options, fastbdt_options)
85
86
def create_fbdt_option_string(fast_bdt_option):
    """
    Build a human readable suffix string that encodes the given FastBDT options.

    :param fast_bdt_option: List containing the FastBDT options in the order
           [nTrees, nCuts, nLevels, shrinkage] that should be converted to a string.
    :return: String of the form ``_nTrees<a>_nCuts<b>_nLevels<c>_shrin<d>``, where ``<d>`` is
           the shrinkage multiplied by 100 and rounded to an integer.
    """
    n_trees = fast_bdt_option[0]
    n_cuts = fast_bdt_option[1]
    n_levels = fast_bdt_option[2]
    # The shrinkage is written as a rounded percentage (100 * value).
    shrinkage_percent = int(round(100 * fast_bdt_option[3], 0))

    pieces = [
        f"_nTrees{n_trees}",
        f"_nCuts{n_cuts}",
        f"_nLevels{n_levels}",
        f"_shrin{shrinkage_percent}",
    ]
    return "".join(pieces)