Belle II Software  light-2212-foldex
variable_importance.py
#!/usr/bin/env python3

# We want to find out which variables/features are the most important.
# There are three approaches:
# 1. You can use the variable importance estimate output by the method itself,
#    e.g. FastBDT and TMVA BDTs support calculating the variable importance using the information gain of each applied cut.
# 2. One can estimate the importance in a method-agnostic way, by training N times (where N is the number of variables),
#    each time with a different variable removed from the training. The loss in ROC AUC score is used to estimate the importance.
#    This will underestimate the importance of variables whose information is highly correlated with other variables in the training.
# 3. One can estimate the importance in a method-agnostic way, by training roughly N*N / 2 times (where N is the number of variables):
#    the first step is approach 2; afterwards the most important variable found by approach 2 is removed and approach 2 is run
#    again on the remaining variables.
#    This takes the correlations between variables into account, but takes some time.
#
# The trainings needed for approaches 2 and 3 can be run in parallel (using Python's multiprocessing module, see below).
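# As a concrete example of the cost with the 28 variables used below:
# approach 2 needs 28 extra trainings, approach 3 needs 28 + 27 + ... + 2 = 405
# trainings in total (the first 28 are reused from approach 2), i.e. roughly N*N / 2.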

import basf2_mva
import basf2_mva_util
import multiprocessing
import copy


if __name__ == "__main__":

    training_data = basf2_mva.vector("train.root")
    test_data = basf2_mva.vector("test.root")

    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)']

    # Train a FastBDT with default parameters on all variables
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

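    # Helper used by approaches 2 and 3: retrain the classifier on the given
    # (reduced) set of variables and return the ROC AUC obtained on the test data.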
    def roc_for_variable_set(variables):
        method = basf2_mva_util.Method(general_options.m_identifier)
        options = copy.copy(general_options)
        options.m_variables = basf2_mva.vector(*variables)
        m = method.train_teacher(training_data, general_options.m_treename, general_options=options)
        p, t = m.apply_expert(test_data, general_options.m_treename)
        return basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Baseline: apply the trained method to the test data and compute the ROC AUC
    # with all variables; the importances below are measured as losses w.r.t. this value.
    method = basf2_mva_util.Method(general_options.m_identifier)
    p, t = method.apply_expert(test_data, general_options.m_treename)
    global_auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # Approach 1: Read out the importance calculated by the method itself.
    # This only works for methods that provide such an estimate (e.g. FastBDT);
    # variables without an entry default to 0.0.
    print("Variable importances returned by the method")
    for variable in method.variables:
        print(variable, method.importances.get(variable, 0.0))

    # Approach 2: Calculate the importance using the loss in AUC if a variable is removed.
    # For every variable, retrain on all remaining variables in a separate process
    # (maxtasksperchild=1 replaces each worker after one task, so every training runs
    # in a fresh process).
    p = multiprocessing.Pool(None, maxtasksperchild=1)
    results = p.map(roc_for_variable_set, [[v for v in method.variables if v != variable] for variable in method.variables])
    sorted_variables_with_results = list(sorted(zip(method.variables, results), key=lambda x: x[1]))
    print("Variable importances calculated using loss if variable is removed")
    for variable, auc in sorted_variables_with_results:
        print(variable, global_auc - auc)

    # Approach 3: Calculate the importance using the loss in AUC if variables are removed recursively.
    # In each step the currently most important variable (the one whose removal costs the
    # most AUC) is taken out and approach 2 is repeated on the remaining variables.
    removed_variables_with_results = sorted_variables_with_results[:1]
    remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
    while len(remaining_variables) > 1:
        results = p.map(roc_for_variable_set,
                        [[v for v in remaining_variables if v != variable] for variable in remaining_variables])
        sorted_variables_with_results = list(sorted(zip(remaining_variables, results), key=lambda x: x[1]))
        removed_variables_with_results += sorted_variables_with_results[:1]
        remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
    removed_variables_with_results += sorted_variables_with_results[1:]

    print("Variable importances calculated using loss if variables are recursively removed")
    last_auc = global_auc
    for variable, auc in removed_variables_with_results:
        # Printed value: the additional AUC lost when this variable is removed
        # on top of all previously removed, more important ones.
        print(variable, last_auc - auc)
        last_auc = auc
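
    # Optional housekeeping: the worker pool is no longer needed, so it can be shut
    # down explicitly (the processes would otherwise be cleaned up at interpreter exit).
    p.close()
    p.join()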