Belle II Software  release-08-01-10
variable_importance.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # We want to find out which variables/features are the most important
12 # There are three approaches
13 # 1. You can use the variable importance estimate outputted by the method itself,
14 # e.g. FastBDT and TMVA BDTs support calculating the variable importance using the information gain of each applied cut.
15 # 2. One can estimate the importance in a method-agnostic way, by training N times (where N is the number of variables),
16 # each time leaving out a different variable. The loss in ROC AUC score is used to estimate the importance.
17 # This will underestimate the importance of variables whose information is highly correlated with other variables in the training.
18 # 3. One can estimate the importance in a method-agnostic way, by training roughly N*N / 2 times (where N is the number of variables):
19 # the first step is approach 2; afterwards the most important variable found by approach 2 is removed and approach 2 is run
20 # again on the remaining variables, repeating until a single variable is left.
21 # This takes the correlations between variables into account, but takes considerably longer.
22 #
23 # Approaches 2 and 3 can be run in parallel (using Python's multiprocessing module, see below)
24 
25 from basf2 import find_file
26 import basf2_mva
27 import basf2_mva_util
28 import multiprocessing
29 import copy
30 
31 
32 if __name__ == "__main__":
33 
34  train_file = find_file("mva/train_D0toKpipi.root", "examples")
35  test_file = find_file("mva/test_D0toKpipi.root", "examples")
36 
37  training_data = basf2_mva.vector(train_file)
38  testing_data = basf2_mva.vector(test_file)
39 
40  variables = ['M', 'p', 'pt', 'pz',
41  'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
42  'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
43  'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
44  'chiProb', 'dr', 'dz',
45  'daughter(0, dr)', 'daughter(1, dr)',
46  'daughter(0, dz)', 'daughter(1, dz)',
47  'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
48  'daughter(0, kaonID)', 'daughter(0, pionID)',
49  'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)']
50 
51  # Train model with default parameters
52  general_options = basf2_mva.GeneralOptions()
53  general_options.m_datafiles = training_data
54  general_options.m_treename = "tree"
55  general_options.m_identifier = "test.xml"
56  general_options.m_variables = basf2_mva.vector(*variables)
57  general_options.m_target_variable = "isSignal"
58 
59  fastbdt_options = basf2_mva.FastBDTOptions()
60  basf2_mva.teacher(general_options, fastbdt_options)
61 
62  def roc_for_variable_set(variables):
63  method = basf2_mva_util.Method(general_options.m_identifier)
64  options = copy.copy(general_options)
65  options.m_variables = basf2_mva.vector(*variables)
66  m = method.train_teacher(training_data, general_options.m_treename, general_options=options)
67  p, t = m.apply_expert(testing_data, general_options.m_treename)
69 
70  method = basf2_mva_util.Method(general_options.m_identifier)
71  p, t = method.apply_expert(testing_data, general_options.m_treename)
73 
74  # Approach 1: Read out the importance calculated by the method itself
75  print("Variable importances returned my method")
76  for variable in method.variables:
77  print(variable, method.importances.get(variable, 0.0))
78 
79  # Approach 2: Calculate the importance using the loss in AUC if a variable is removed
80  p = multiprocessing.Pool(None, maxtasksperchild=1)
81  results = p.map(roc_for_variable_set, [[v for v in method.variables if v != variable] for variable in method.variables])
82  sorted_variables_with_results = list(sorted(zip(method.variables, results), key=lambda x: x[1]))
83  print("Variable importances calculated using loss if variable is removed")
84  for variable, auc in sorted_variables_with_results:
85  print(variable, global_auc - auc)
86 
87  # Approach 3: Calculate the importance using the loss in AUC if a variable is removed recursively.
88  removed_variables_with_results = sorted_variables_with_results[:1]
89  remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
90  while len(remaining_variables) > 1:
91  results = p.map(roc_for_variable_set,
92  [[v for v in remaining_variables if v != variable] for variable in remaining_variables])
93  sorted_variables_with_results = list(sorted(zip(remaining_variables, results), key=lambda x: x[1]))
94  removed_variables_with_results += sorted_variables_with_results[:1]
95  remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
96  removed_variables_with_results += sorted_variables_with_results[1:]
97 
98  print("Variable importances calculated using loss if variables are recursively removed")
99  last_auc = global_auc
100  for variable, auc in removed_variables_with_results:
101  print(variable, last_auc - auc)
102  last_auc = auc
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)