Belle II Software release-09-00-00
__init__.py
1#!/usr/bin/env python3
2
3
10
11from pathlib import Path
12import basf2_mva
13import tracking.root_utils as root_utils
14
15
def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=(200, 8, 3, 0.1),
):
    """
    Custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    The list of candidate training variables is taken from the leaves of the TTree
    ``tree_name`` in the first file of ``records_files``. Variables containing the
    substring "truth", the ``target_variable`` itself and everything in
    ``exclude_variables`` are dropped. If a leaf called ``weight`` (or, failing that,
    ``__weight__``) exists, it is used as the event weight instead of as a training
    feature. A FastBDT classifier is then trained on all ``records_files`` via
    ``basf2_mva.teacher``.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
        that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
        Should either end in ".xml" for local weightfiles or in ".root", when
        the weightfile needs later to be uploaded as a payload to the conditions
        database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
        In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: FastBDT hyperparameters as a sequence ``[nTrees, nCuts, nLevels, shrinkage]``,
        default: ``(200, 8, 3, 0.1)``. Only the first four entries are used.
    :raises ValueError: if ``weightfile_identifier`` does not end in ".xml" or ".root",
        if ``records_files`` is empty, or if ``fast_bdt_option`` has fewer than four entries.
    """
    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    # Validate cheaply before any ROOT file is opened; previously these cases
    # surfaced as bare IndexErrors (the option check only after file I/O).
    # len() rather than truthiness so array-like inputs are also accepted.
    if len(records_files) == 0:
        raise ValueError("records_files must contain at least one file")
    if len(fast_bdt_option) < 4:
        raise ValueError(
            "fast_bdt_option must provide [nTrees, nCuts, nLevels, shrinkage], "
            f"but has only {len(fast_bdt_option)} entries"
        )

    # Build the exclusion set once instead of scanning a list per leaf name.
    excluded_variables = set(exclude_variables) if exclude_variables is not None else set()

    # extract names of all variables from one record file
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # get list of variables to use for training without MC truth
    truth_free_variable_names = [
        name
        for name in feature_names
        if "truth" not in name
        and name != target_variable
        and name not in excluded_variables
    ]

    # A column named "weight" takes precedence over "__weight__"; the chosen one
    # becomes the event weight and is removed from the training features.
    weight_variable = ""
    for weight_candidate in ("weight", "__weight__"):
        if weight_candidate in truth_free_variable_names:
            truth_free_variable_names.remove(weight_candidate)
            weight_variable = weight_candidate
            break

    # Set options for MVA training
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable

    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    # Train the MVA method; the weightfile is written under ``weightfile_identifier``.
    basf2_mva.teacher(general_options, fastbdt_options)
86
87
def create_fbdt_option_string(fast_bdt_option):
    """
    Build a compact, human-readable tag describing a set of FastBDT options.

    :param fast_bdt_option: Sequence ``[nTrees, nCuts, nLevels, shrinkage]`` of FastBDT
        hyperparameters; only the first four entries are read.
    :return: A string of the form ``_nTrees<a>_nCuts<b>_nLevels<c>_shrin<d>``, where ``<d>``
        is the shrinkage multiplied by 100 and rounded to an integer.
    """
    n_trees = fast_bdt_option[0]
    n_cuts = fast_bdt_option[1]
    n_levels = fast_bdt_option[2]
    # Express shrinkage as an integer percentage so the tag contains no decimal point.
    shrinkage_percent = int(round(100 * fast_bdt_option[3], 0))
    tag_parts = (
        f"_nTrees{n_trees}",
        f"_nCuts{n_cuts}",
        f"_nLevels{n_levels}",
        f"_shrin{shrinkage_percent}",
    )
    return "".join(tag_parts)