Belle II Software  release-08-01-10
__init__.py
1 #!/usr/bin/env python3
2 
3 
10 
11 from pathlib import Path
12 import basf2_mva
13 import tracking.root_utils as root_utils
14 
15 
def my_basf2_mva_teacher(
    records_files,
    tree_name,
    weightfile_identifier,
    target_variable="truth",
    exclude_variables=None,
    fast_bdt_option=(200, 8, 3, 0.1),
):
    """
    Custom wrapper for basf2 mva teacher. Adapted from code in ``trackfindingcdc_teacher``.

    The names of the training variables are read from the leaves of the tree in
    the first records file. A leaf called ``weight`` or ``__weight__`` (checked
    in that order) is not used as a training variable but is passed to the
    teacher as the weight variable, unless it was excluded explicitly.

    :param records_files: List of files with collected ("recorded") variables to use as training data for the MVA.
        Must contain at least one file.
    :param tree_name: Name of the TTree in the ROOT file from the ``data_collection_task``
        that contains the training data for the MVA teacher.
    :param weightfile_identifier: Name of the weightfile that is created.
        Should either end in ".xml" for local weightfiles or in ".root", when
        the weightfile needs later to be uploaded as a payload to the conditions
        database.
    :param target_variable: Feature/variable to use as truth label in the quality estimator MVA classifier.
    :param exclude_variables: List of collected variables to not use in the training of the QE MVA classifier.
        In addition to variables containing the "truth" substring, which are excluded by default.
    :param fast_bdt_option: Specified FastBDT options in the order [nTrees, nCuts, nLevels, shrinkage],
        default: (200, 8, 3, 0.1).
    :raises ValueError: If ``weightfile_identifier`` ends in neither ".xml" nor ".root",
        or if ``records_files`` is empty.
    """
    if exclude_variables is None:
        exclude_variables = []

    weightfile_extension = Path(weightfile_identifier).suffix
    if weightfile_extension not in {".xml", ".root"}:
        raise ValueError(f"Weightfile Identifier should end in .xml or .root, but ends in {weightfile_extension}")

    # Fail with a clear message instead of a bare IndexError from ``records_files[0]`` below.
    if not records_files:
        raise ValueError("At least one records file is required to train the MVA, but records_files is empty.")

    # extract names of all variables from one record file
    with root_utils.root_open(records_files[0]) as records_tfile:
        input_tree = records_tfile.Get(tree_name)
        feature_names = [leaf.GetName() for leaf in input_tree.GetListOfLeaves()]

    # get list of variables to use for training without MC truth
    truth_free_variable_names = [
        name
        for name in feature_names
        if (
            ("truth" not in name) and
            (name != target_variable) and
            (name not in exclude_variables)
        )
    ]

    # A weight leaf must not be trained on; hand it to the teacher as the weight
    # variable instead. "weight" takes precedence over "__weight__". An empty
    # string is passed when neither is among the remaining variables.
    weight_variable = ""
    for weight_candidate in ("weight", "__weight__"):
        if weight_candidate in truth_free_variable_names:
            truth_free_variable_names.remove(weight_candidate)
            weight_variable = weight_candidate
            break

    # Set options for MVA training
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*records_files)
    general_options.m_treename = tree_name
    general_options.m_weight_variable = weight_variable
    general_options.m_identifier = weightfile_identifier
    general_options.m_variables = basf2_mva.vector(*truth_free_variable_names)
    general_options.m_target_variable = target_variable

    fastbdt_options = basf2_mva.FastBDTOptions()
    fastbdt_options.m_nTrees = fast_bdt_option[0]
    fastbdt_options.m_nCuts = fast_bdt_option[1]
    fastbdt_options.m_nLevels = fast_bdt_option[2]
    fastbdt_options.m_shrinkage = fast_bdt_option[3]

    # Train a MVA method and store the weightfile under ``weightfile_identifier``.
    basf2_mva.teacher(general_options, fastbdt_options)
86 
87 
def create_fbdt_option_string(fast_bdt_option):
    """
    Build a human readable string that encodes the given FastBDT options.

    :param fast_bdt_option: List containing the FastBDT options in the order
        [nTrees, nCuts, nLevels, shrinkage] that should be converted to a string.
    :return: String of the form ``_nTrees<nTrees>_nCuts<nCuts>_nLevels<nLevels>_shrin<percent>``,
        where ``<percent>`` is 100 times the shrinkage, rounded to an integer.
    """
    n_trees = fast_bdt_option[0]
    n_cuts = fast_bdt_option[1]
    n_levels = fast_bdt_option[2]
    # Express the shrinkage as an integer percentage so the string contains no decimal point.
    shrinkage_percent = int(round(100*fast_bdt_option[3], 0))
    return f"_nTrees{n_trees}_nCuts{n_cuts}_nLevels{n_levels}_shrin{shrinkage_percent}"