Belle II Software  release-08-01-10
training_flipping_mvas.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # Use training in a basf2 path
12 # Run basics/*.py before
13 
14 import basf2_mva
15 import basf2_mva_util
16 import numpy as np
17 from basf2 import conditions, B2FATAL
18 import argparse
19 
20 
def get_argument_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the flipping MVA training.

    The parser is returned unparsed; the caller is expected to call
    ``parse_args()`` on it.

    Options defined on the parser:
        -train: data file containing the ROOT TTree used during training
            (default: empty string).
        -data: data file containing the ROOT TTree with independent test
            data (default: empty string).
        -tree: name of the tree in the data file (default: empty string).
        -mva: index of the MVA to be trained (default: 1).

    Returns:
        The configured ``argparse.ArgumentParser`` instance.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-train', default='', type=str,
                        help='Data file containing ROOT TTree used during training. Default: \'\'.')
    parser.add_argument('-data', default='', type=str,
                        help='Data file containing ROOT TTree with independent test data. Default: \'\'.')
    parser.add_argument('-tree', default='', type=str,
                        help='Treename in data file. Default: \'\'.')
    parser.add_argument('-mva', default=1, type=int,
                        help='index of mva to be trained. Default: 1')

    return parser
34 
35 
def get_variables(index=1):
    """Return the names of the training variables for one flipping MVA.

    Args:
        index: index of the MVA. ``1`` selects the variable set of the
            first MVA (the ``seed_*`` quantities, hit counts and arm
            times); ``2`` selects the variable set of the second MVA.

    Returns:
        A new list of variable name strings. Any index other than 1 or 2
        yields an empty list.
    """
    if index == 1:
        var = [
            "seed_pz_estimate",
            "seed_pz_variance",
            "seed_z_estimate",
            "seed_tan_lambda_estimate",
            "seed_pt_estimate",
            "seed_x_estimate",
            "seed_y_estimate",
            "seed_py_variance",
            "seed_d0_estimate",
            "seed_omega_variance",
            "svd_layer6_clsTime",
            "seed_tan_lambda_variance",
            "seed_z_variance",
            "n_svd_hits",
            "n_cdc_hits",
            "svd_layer3_positionSigma",
            "first_cdc_layer",
            "last_cdc_layer",
            "InOutArmTimeDifference",
            "InOutArmTimeDifferenceError",
            "inGoingArmTime",
            "inGoingArmTimeError",
            "outGoingArmTime",
            "outGoingArmTimeError",
        ]
    elif index == 2:
        var = [
            "flipped_pz_estimate",
            "tan_lambda_estimate",
            "d0_variance",
            "z_estimate",
            "px_variance",
            "p_value",
            "pt_estimate",
            "y_estimate",
            "d0_estimate",
            "x_estimate",
            "pz_variance",
            "omega_estimate",
            "px_estimate",
            "flipped_z_estimate",
            "py_estimate",
            "outGoingArmTime",
            "quality_flip_indicator",
            "inGoingArmTime",
        ]
    else:
        # Unknown index: callers get an empty variable list.
        var = []

    return var
87 
88 
if __name__ == "__main__":
    # Train one of the two flipping MVAs with FastBDT, apply it to the
    # independent test sample and print the resulting figure of merit and the
    # rescaled variable importances.

    parser = get_argument_parser()
    args = parser.parse_args()

    # Validate the command line before doing any work.
    if args.train == '' or args.data == '':
        B2FATAL("Missing train or test samples. Terminating here.")

    if args.tree == '':
        B2FATAL("Missing Treename. Terminating here.")

    if args.mva not in [1, 2]:
        B2FATAL("MVA number must be either 1 or 2. Terminating here.")

    conditions.testing_payloads = ['localdb/database.txt']

    print(args.train)
    training_data = basf2_mva.vector(args.train)
    test_data = basf2_mva.vector(args.data)

    # get the variables
    variables = get_variables(args.mva)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    # The weight file is stored under its official name,
    # `TRKTrackFlipAndRefit_MVA1_weightfile` for the first MVA and
    # `TRKTrackFlipAndRefit_MVA2_weightfile` for the second one.
    # Note that the evaluation scripts take `Weightfile` as a default name.
    general_options.m_identifier = f"TRKTrackFlipAndRefit_MVA{args.mva}_weightfile"
    general_options.m_treename = args.tree
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "ismatched_WC"
    general_options.m_max_events = 0

    fastbdt_options = basf2_mva.FastBDTOptions()
    if args.mva == 1:
        # configurations for MVA1
        fastbdt_options.m_nTrees = 150
        fastbdt_options.m_nCuts = 18
        fastbdt_options.m_nLevels = 4
        fastbdt_options.m_shrinkage = 0.2
    elif args.mva == 2:
        # configurations for MVA2
        fastbdt_options.m_nTrees = 400
        fastbdt_options.m_nCuts = 25
        fastbdt_options.m_nLevels = 2
        fastbdt_options.m_shrinkage = 0.6
    # settings shared by both MVAs
    fastbdt_options.m_randRatio = 0.5
    fastbdt_options.m_purityTransformation = False
    fastbdt_options.m_sPlot = False

    basf2_mva.teacher(general_options, fastbdt_options)

    # Apply the freshly trained method to the independent test sample.
    m = basf2_mva_util.Method(general_options.m_identifier)
    p, t = m.apply_expert(test_data, general_options.m_treename)
    # Figure of merit computed from the expert response `p` and the target `t`.
    res = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    print(res)
    print("Variable importances returned by method")

    # Rescale the importances linearly so that the least important variable
    # maps to 0 and the most important one to 100.
    imp = np.array([m.importances.get(v, 0.0) for v in m.variables])
    lowest = np.min(imp)
    width = np.max(imp) - lowest

    for var, value in zip(m.variables, imp):
        # If all importances are identical the rescaling is undefined;
        # report 0.0 instead of dividing by zero.
        scaled = (value - lowest) / width * 100 if width > 0 else 0.0
        print(var, scaled)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)