import copy
import multiprocessing

from basf2 import find_file
import basf2_mva
import basf2_mva_util
def leave_one_out_subsets(names):
    """Return one list per entry of *names*, each omitting that entry.

    The i-th sublist contains every element of *names* that is not equal to
    ``names[i]``, in the original order.
    """
    names = list(names)
    return [[v for v in names if v != excluded] for excluded in names]


def rank_by_auc(names, aucs):
    """Pair each name with its AUC and return the pairs sorted by ascending AUC.

    The sort is stable, so names with equal AUC keep their input order.
    """
    return sorted(zip(names, aucs), key=lambda pair: pair[1])


def recursive_elimination(first_ranking, evaluate_subsets):
    """Repeatedly drop the top-ranked variable and re-rank the remaining ones.

    Args:
        first_ranking: list of ``(name, auc)`` pairs, already sorted by
            ascending AUC (as returned by :func:`rank_by_auc`).
        evaluate_subsets: callable taking a list of variable lists and
            returning one AUC per list, in the same order.

    Returns:
        List of ``(name, auc)`` pairs in the order in which the variables were
        eliminated; the last surviving variable is appended at the end with
        the AUC from the final ranking.
    """
    ranking = list(first_ranking)
    eliminated = ranking[:1]
    remaining = [name for name, _ in ranking[1:]]
    while len(remaining) > 1:
        aucs = evaluate_subsets(leave_one_out_subsets(remaining))
        ranking = rank_by_auc(remaining, aucs)
        eliminated += ranking[:1]
        remaining = [name for name, _ in ranking[1:]]
    # Whatever is left of the last ranking was never eliminated explicitly.
    eliminated += ranking[1:]
    return eliminated


if __name__ == "__main__":
    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    variables = [
        'M', 'p', 'pt', 'pz',
        'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
        'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
        'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
        'chiProb', 'dr', 'dz',
        'daughter(0, dr)', 'daughter(1, dr)',
        'daughter(0, dz)', 'daughter(1, dz)',
        'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
        'daughter(0, kaonID)', 'daughter(0, pionID)',
        'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)',
    ]

    # Train a reference model on the full variable set with default options.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    # NOTE: this function and the globals it reads are created under the
    # __main__ guard, so pool workers only see them when multiprocessing
    # forks; start methods that re-import the main module skip this body.
    def roc_for_variable_set(variables):
        """Retrain on *variables* only and return the AUC on the test sample."""
        reference = basf2_mva_util.Method(general_options.m_identifier)
        # Shallow copy so the shared general_options keeps the full variable list.
        options = copy.copy(general_options)
        options.m_variables = basf2_mva.vector(*variables)
        retrained = reference.train_teacher(training_data, general_options.m_treename, general_options=options)
        p, t = retrained.apply_expert(testing_data, general_options.m_treename)
        return basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Baseline AUC of the reference model trained on all variables.
    method = basf2_mva_util.Method(general_options.m_identifier)
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    global_auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Approach 1: importances as reported by the method itself.
    print("Variable importances returned my method")
    for variable in method.variables:
        print(variable, method.importances.get(variable, 0.0))

    # maxtasksperchild=1 gives every training a fresh worker process; the
    # context manager shuts the pool down once both approaches are finished.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        # Approach 2: AUC loss when a single variable is left out.
        results = pool.map(roc_for_variable_set, leave_one_out_subsets(method.variables))
        sorted_variables_with_results = rank_by_auc(method.variables, results)
        print("Variable importances calculated using loss if variable is removed")
        for variable, auc in sorted_variables_with_results:
            print(variable, global_auc - auc)

        # Approach 3: AUC loss when variables are removed one after another.
        removed_variables_with_results = recursive_elimination(
            sorted_variables_with_results,
            lambda subsets: pool.map(roc_for_variable_set, subsets),
        )

    # Convert the AUC after each removal step into the loss caused by that step.
    print("Variable importances calculated using loss if variables are recursively removed")
    last_auc = global_auc
    for variable, auc in removed_variables_with_results:
        print(variable, last_auc - auc)
        last_auc = auc
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)