Belle II Software development
training_flipping_mvas.py
1#!/usr/bin/env python3
2
3
10
11# Use training in a basf2 path
12# Run basics/*.py before
13
14import basf2_mva
15import basf2_mva_util
16import numpy as np
17from basf2 import conditions, B2FATAL
18import argparse
19
20
def get_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the flipping-MVA training script.

    The parser is returned unparsed; the caller invokes ``parse_args``.

    Options:
      ``-train``  training data file name (str, default ``''``)
      ``-data``   independent test data file name (str, default ``''``)
      ``-tree``   tree name inside the data files (str, default ``''``)
      ``-mva``    which MVA to train (int, default ``1``)

    :return: the configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser()

    # All string-valued options share the same default and type.
    string_options = (
        ("-train", "Data file containing ROOT TTree used during training. Default: ''."),
        ("-data", "Data file containing ROOT TTree with independent test data. Default: ''."),
        ("-tree", "Treename in data file. Default: ''."),
    )
    for flag, description in string_options:
        cli.add_argument(flag, default='', type=str, help=description)

    cli.add_argument("-mva", default=1, type=int,
                     help="index of mva to be trained. Default: 1")
    return cli
34
35
def get_variables(index=1):
    """Return the input-variable names used by one of the flipping MVAs.

    :param index: selects the variable set; ``1`` for the first MVA, ``2``
        for the second MVA. Any other value yields an empty list.
    :return: a newly built list of variable-name strings (callers may
        mutate it freely).
    """
    # variables of the first MVA
    first_mva = (
        "seed_pz_estimate",
        "seed_pz_variance",
        "seed_z_estimate",
        "seed_tan_lambda_estimate",
        "seed_pt_estimate",
        "seed_x_estimate",
        "seed_y_estimate",
        "seed_py_variance",
        "seed_d0_estimate",
        "seed_omega_variance",
        "svd_layer6_clsTime",
        "seed_tan_lambda_variance",
        "seed_z_variance",
        "n_svd_hits",
        "n_cdc_hits",
        "svd_layer3_positionSigma",
        "first_cdc_layer",
        "last_cdc_layer",
        "InOutArmTimeDifference",
        "InOutArmTimeDifferenceError",
        "inGoingArmTime",
        "inGoingArmTimeError",
        "outGoingArmTime",
        "outGoingArmTimeError",
    )
    # training variables of the second MVA
    second_mva = (
        "flipped_pz_estimate",
        "tan_lambda_estimate",
        "d0_variance",
        "z_estimate",
        "px_variance",
        "p_value",
        "pt_estimate",
        "y_estimate",
        "d0_estimate",
        "x_estimate",
        "pz_variance",
        "omega_estimate",
        "px_estimate",
        "flipped_z_estimate",
        "py_estimate",
        "outGoingArmTime",
        "quality_flip_indicator",
        "inGoingArmTime",
    )

    chosen = []
    # Each candidate is compared with ``==`` independently; a later match
    # replaces an earlier one.
    for key, names in ((1, first_mva), (2, second_mva)):
        if index == key:
            chosen = list(names)
    return chosen
87
88
if __name__ == "__main__":

    parser = get_argument_parser()
    args = parser.parse_args()

    # Validate the command line. B2FATAL is expected to abort the job;
    # the `elif` below additionally guards against an unexpected MVA index.
    if args.train == '' or args.data == '':
        B2FATAL("Missing train or test samples. Terminating here.")

    if args.tree == '':
        B2FATAL("Missing Treename. Terminating here.")

    if args.mva not in [1, 2]:
        B2FATAL("MVA number must be either 1 or 2. Terminating here.")

    # Use the local testing-payload database file.
    conditions.testing_payloads = ['localdb/database.txt']

    print(args.train)
    training_data = basf2_mva.vector(args.train)
    test_data = basf2_mva.vector(args.data)

    # get the variables for the selected MVA
    variables = get_variables(args.mva)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    # Store the weight file under its official name,
    # `TRKTrackFlipAndRefit_MVA1_weightfile` or
    # `TRKTrackFlipAndRefit_MVA2_weightfile`.
    # NOTE(review): an earlier note said the evaluation scripts default to the
    # name `Weightfile`; make sure they are pointed at this identifier.
    general_options.m_identifier = f"TRKTrackFlipAndRefit_MVA{args.mva}_weightfile"
    general_options.m_treename = args.tree
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "ismatched_WC"
    general_options.m_max_events = 0

    fastbdt_options = basf2_mva.FastBDTOptions()
    if args.mva == 1:
        # configurations for MVA1
        fastbdt_options.m_nTrees = 150
        fastbdt_options.m_nCuts = 18
        fastbdt_options.m_nLevels = 4
        fastbdt_options.m_shrinkage = 0.2
    elif args.mva == 2:
        # configurations for MVA2
        fastbdt_options.m_nTrees = 400
        fastbdt_options.m_nCuts = 25
        fastbdt_options.m_nLevels = 2
        fastbdt_options.m_shrinkage = 0.6
    # settings shared by both MVAs
    fastbdt_options.m_randRatio = 0.5
    fastbdt_options.m_purityTransformation = False
    fastbdt_options.m_sPlot = False

    basf2_mva.teacher(general_options, fastbdt_options)

    # Apply the freshly trained weight file to the independent test sample.
    m = basf2_mva_util.Method(general_options.m_identifier)
    p, t = m.apply_expert(test_data, general_options.m_treename)
    # `res` was previously printed without ever being computed (NameError).
    res = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    print(res)
    print("Variable importances returned by method")

    # Rescale the importances linearly so the smallest maps to 0 and the
    # largest to 100.
    imp = np.array([m.importances.get(v, 0.0) for v in m.variables])
    width = np.max(imp) - np.min(imp)
    if width > 0:
        scaled = (imp - np.min(imp)) / width * 100
    else:
        # All importances are equal: avoid 0/0 (NaN) and report 0 everywhere.
        scaled = np.zeros(len(imp))

    for var, value in zip(m.variables, scaled):
        print(var, value)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)