from basf2 import conditions, B2FATAL
import basf2_mva_util
def get_argument_parser() -> argparse.ArgumentParser:
    """Build the command line parser of the flipping MVA training.

    The parser knows four options:

    * ``-train``: data file with the ROOT TTree used for training,
    * ``-data``: data file with the ROOT TTree holding independent test data,
    * ``-tree``: name of the tree inside the data files,
    * ``-mva``: index of the MVA to be trained.

    Returns:
        The configured parser. Parsing is left to the caller, which calls
        ``parse_args()`` on the returned object.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-train', default='', type=str,
        help="Data file containing ROOT TTree used during training. Default: ''.")
    parser.add_argument(
        '-data', default='', type=str,
        help="Data file containing ROOT TTree with independent test data. Default: ''.")
    parser.add_argument(
        '-tree', default='', type=str,
        help="Treename in data file. Default: ''.")
    parser.add_argument(
        '-mva', default=1, type=int,
        help='index of mva to be trained. Default: 1')
    # The caller uses the parser object directly, so it has to be handed back.
    return parser
def get_variables(index=1):
    """Return the names of the input variables of the selected flipping MVA.

    Args:
        index: Number of the MVA, either 1 (default) or 2.

    Returns:
        A new list of variable names for the requested MVA.

    Raises:
        ValueError: If ``index`` is neither 1 nor 2.
    """
    # NOTE(review): the two lists hold exactly the names present in the
    # reviewed section -- confirm against the full feature set of each MVA.
    if index == 1:
        return [
            "seed_pz_estimate",
            "seed_tan_lambda_estimate",
            "seed_omega_variance",
            "seed_tan_lambda_variance",
            "svd_layer3_positionSigma",
            "InOutArmTimeDifference",
            "InOutArmTimeDifferenceError",
            "inGoingArmTimeError",
            "outGoingArmTimeError",
        ]
    if index == 2:
        return [
            "flipped_pz_estimate",
            "tan_lambda_estimate",
            "quality_flip_indicator",
        ]
    # Fail loudly instead of handing an undefined variable set to the training.
    raise ValueError(f"MVA index must be either 1 or 2, got {index!r}.")
if __name__ == "__main__":
    # Train one of the two track-flipping MVAs with FastBDT, then print the
    # variable importances of the trained method rescaled to the range 0..100.
    parser = get_argument_parser()
    args = parser.parse_args()

    # Validate the command line before any work is done.
    if args.train == '' or args.data == '':
        B2FATAL("Missing train or test samples. Terminating here.")

    if args.tree == '':
        B2FATAL("Missing Treename. Terminating here.")

    if args.mva not in [1, 2]:
        B2FATAL("MVA number must be either 1 or 2. Terminating here.")

    conditions.testing_payloads = ['localdb/database.txt']

    training_data = basf2_mva.vector(args.train)
    test_data = basf2_mva.vector(args.data)

    variables = get_variables(args.mva)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    # The identifier is "TRKTrackFlipAndRefit_MVA1_weightfile" for the first
    # MVA and "TRKTrackFlipAndRefit_MVA2_weightfile" for the second MVA.
    general_options.m_identifier = f"TRKTrackFlipAndRefit_MVA{args.mva}_weightfile"
    general_options.m_treename = args.tree
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "ismatched_WC"
    general_options.m_max_events = 0

    fastbdt_options = basf2_mva.FastBDTOptions()
    # Each MVA has its own hyper-parameter set. args.mva was validated above,
    # so everything that is not MVA 1 is MVA 2.
    if args.mva == 1:
        fastbdt_options.m_nTrees = 150
        fastbdt_options.m_nCuts = 18
        fastbdt_options.m_nLevels = 4
        fastbdt_options.m_shrinkage = 0.2
    else:
        fastbdt_options.m_nTrees = 400
        fastbdt_options.m_nCuts = 25
        fastbdt_options.m_nLevels = 2
        fastbdt_options.m_shrinkage = 0.6
    # These settings are identical for both MVAs.
    fastbdt_options.m_randRatio = 0.5
    fastbdt_options.m_purityTransformation = False
    fastbdt_options.m_sPlot = False

    basf2_mva.teacher(general_options, fastbdt_options)

    # NOTE(review): `m` is presumably the trained method loaded through
    # basf2_mva_util.Method -- confirm against the package documentation.
    m = basf2_mva_util.Method(general_options.m_identifier)
    # NOTE(review): p and t are not used further in the reviewed section.
    p, t = m.apply_expert(test_data, general_options.m_treename)

    print("Variable importances returned by method")
    # Variables without an importance entry count as 0.0.
    imp = np.array([m.importances.get(v, 0.0) for v in m.variables])
    if imp.size > 0:
        lowest = np.min(imp)
        width = np.max(imp) - lowest
        for var in m.variables:
            shifted = m.importances.get(var, 0.0) - lowest
            # All importances equal: nothing to rescale, avoid dividing by zero.
            scaled = shifted / width * 100 if width != 0 else 0.0
            print(var, scaled)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)