Belle II Software  release-06-01-15
variable_importance.py
#!/usr/bin/env python3

# We want to find out which variables/features are the most important.
# There are three approaches:
# 1. Use the variable importance estimate provided by the method itself,
#    e.g. FastBDT and TMVA BDTs can calculate the variable importance from the information gain of each applied cut.
# 2. Estimate the importance in a method-agnostic way by training N times (where N is the number of variables),
#    each time with another variable removed from the training. The loss in ROC AUC score is used as the importance.
#    This underestimates the importance of variables whose information is highly correlated with other variables in the training.
# 3. Estimate the importance in a method-agnostic way by training roughly N*N/2 times:
#    the first step is approach 2; afterwards the most important variable according to approach 2 is removed
#    and approach 2 is run again on the remaining variables.
#    This takes correlations between variables into account, but needs considerably more time.
#
# Approaches 2 and 3 can be parallelised (using Python's multiprocessing module, see below).
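#
# Written as a formula, approach 2 assigns each variable x_i the importance
#   importance(x_i) = AUC(all variables) - AUC(all variables except x_i),
# and approach 3 repeats this measurement on the shrinking variable set after each removal.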

import basf2_mva
import basf2_mva_util
import multiprocessing
import copy


if __name__ == "__main__":
    training_data = basf2_mva.vector("train.root")
    test_data = basf2_mva.vector("test.root")

    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvariantMass(0, 1)', 'daughterInvariantMass(0, 2)', 'daughterInvariantMass(1, 2)']

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

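    # Helper used by approaches 2 and 3: retrain the same configuration on a
    # reduced variable set and return the ROC AUC achieved on the test data.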
    def roc_for_variable_set(variables):
        method = basf2_mva_util.Method(general_options.m_identifier)
        options = copy.copy(general_options)
        options.m_variables = basf2_mva.vector(*variables)
        m = method.train_teacher(training_data, general_options.m_treename, general_options=options)
        p, t = m.apply_expert(test_data, general_options.m_treename)
        return basf2_mva_util.calculate_roc_auc(p, t)

    method = basf2_mva_util.Method(general_options.m_identifier)
    p, t = method.apply_expert(test_data, general_options.m_treename)
    global_auc = basf2_mva_util.calculate_roc_auc(p, t)

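    # The trained method also stores per-variable importances (for FastBDT these
    # are derived from the information gain of the cuts, see above); variables
    # without an entry default to an importance of 0.0.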
    # Approach 1: Read out the importance calculated by the method itself
    print("Variable importances returned by the method")
    for variable in method.variables:
        print(variable, method.importances.get(variable, 0.0))

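    # The N retrainings are distributed over a worker pool; maxtasksperchild=1
    # lets every training run in a fresh process. Note that this rebinds the
    # name p, which previously held the expert output.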
    # Approach 2: Calculate the importance using the loss in AUC if a variable is removed
    p = multiprocessing.Pool(None, maxtasksperchild=1)
    results = p.map(roc_for_variable_set, [[v for v in method.variables if v != variable] for variable in method.variables])
    sorted_variables_with_results = list(sorted(zip(method.variables, results), key=lambda x: x[1]))
    print("Variable importances calculated using loss if variable is removed")
    for variable, auc in sorted_variables_with_results:
        print(variable, global_auc - auc)

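    # In each iteration the variable whose removal costs the most AUC (the first
    # entry after sorting in ascending AUC order) is dropped, and approach 2 is
    # repeated on the remaining variables until only one is left.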
    # Approach 3: Calculate the importance using the loss in AUC if a variable is removed recursively.
    removed_variables_with_results = sorted_variables_with_results[:1]
    remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
    while len(remaining_variables) > 1:
        results = p.map(roc_for_variable_set,
                        [[v for v in remaining_variables if v != variable] for variable in remaining_variables])
        sorted_variables_with_results = list(sorted(zip(remaining_variables, results), key=lambda x: x[1]))
        removed_variables_with_results += sorted_variables_with_results[:1]
        remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
    removed_variables_with_results += sorted_variables_with_results[1:]

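    # The reported importance of each variable is the AUC that was lost at the
    # point it was removed, i.e. the difference to the AUC of the previous step.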
    print("Variable importances calculated using loss if variables are recursively removed")
    last_auc = global_auc
    for variable, auc in removed_variables_with_results:
        print(variable, last_auc - auc)
        last_auc = auc