Belle II Software light-2406-ragdoll
variable_importance.py
1#!/usr/bin/env python3
2
3
10
# We want to find out which variables/features are the most important.
# There are three approaches:
# 1. Use the variable importance estimate output by the method itself,
#    e.g. FastBDT and TMVA BDTs support calculating the variable importance using the information gain of each applied cut.
# 2. Estimate the importance in a method-agnostic way by training N times (where N is the number of variables),
#    each time with a different variable removed from the training. The loss in ROC AUC score is used to estimate the importance.
#    This will underestimate the importance of variables whose information is highly correlated to other variables in the training.
# 3. Estimate the importance in a method-agnostic way by training about N*N / 2 times (where N is the number of variables):
#    the first step is approach 2; afterwards the most important variable given by approach 2 is removed and approach 2 is run
#    again on the remaining variables.
#    This takes the correlations between variables into account, but takes some time.
#
# Approaches 2 and 3 can be run in parallel (using the multiprocessing module of python, see below).
24
25from basf2 import find_file
26import basf2_mva
27import basf2_mva_util
28import multiprocessing
29import copy
30
31
if __name__ == "__main__":

    # Example training and test samples shipped with the software.
    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)']

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    def roc_for_variable_set(variables):
        """
        Retrain the reference method on the given variables and return its AUC.

        The method stored under ``general_options.m_identifier`` is retrained on
        ``training_data`` with a copy of the general options whose variable list
        is replaced by ``variables``; the retrained expert is then applied to
        ``testing_data``.

        @param variables iterable of variable names to train on
        @return the value of calculate_auc_efficiency_vs_background_retention
                for the expert output on the test sample
        """
        method = basf2_mva_util.Method(general_options.m_identifier)
        # Copy the options so the shared general_options object is not modified.
        options = copy.copy(general_options)
        options.m_variables = basf2_mva.vector(*variables)
        m = method.train_teacher(training_data, general_options.m_treename, general_options=options)
        # NOTE(review): p, t are presumably the expert response and the truth target - confirm in basf2_mva_util.
        p, t = m.apply_expert(testing_data, general_options.m_treename)
        # Without this return every worker result is None and the sorting below fails.
        return basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Reference AUC of the model trained on all variables; the losses printed below are relative to it.
    method = basf2_mva_util.Method(general_options.m_identifier)
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    global_auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Approach 1: Read out the importance calculated by the method itself
    print("Variable importances returned my method")
    for variable in method.variables:
        print(variable, method.importances.get(variable, 0.0))

    # The pool gets its own name (it used to overwrite the predictions `p`) and is
    # released by the context manager once all trainings of approach 2 and 3 are done.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        # Approach 2: Calculate the importance using the loss in AUC if a variable is removed
        results = pool.map(roc_for_variable_set,
                           [[v for v in method.variables if v != variable] for variable in method.variables])
        # Ascending AUC: the variable whose removal leaves the lowest AUC comes first.
        sorted_variables_with_results = sorted(zip(method.variables, results), key=lambda x: x[1])
        print("Variable importances calculated using loss if variable is removed")
        for variable, auc in sorted_variables_with_results:
            print(variable, global_auc - auc)

        # Approach 3: Calculate the importance using the loss in AUC if a variable is removed recursively.
        removed_variables_with_results = sorted_variables_with_results[:1]
        remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
        while len(remaining_variables) > 1:
            results = pool.map(roc_for_variable_set,
                               [[v for v in remaining_variables if v != variable] for variable in remaining_variables])
            sorted_variables_with_results = sorted(zip(remaining_variables, results), key=lambda x: x[1])
            # Drop the first-ranked variable of this round and repeat on the rest.
            removed_variables_with_results += sorted_variables_with_results[:1]
            remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
        # Append whatever is left over from the last ranking.
        removed_variables_with_results += sorted_variables_with_results[1:]

    print("Variable importances calculated using loss if variables are recursively removed")
    last_auc = global_auc
    for variable, auc in removed_variables_with_results:
        # Each step is reported relative to the AUC of the previous removal step.
        print(variable, last_auc - auc)
        last_auc = auc
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)