Belle II Software  release-05-01-25
variable_importance.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2017
5 
6 # We want to find out which variables/features are the most important
7 # There are three approaches
8 # 1. You can use the variable importance estimate outputted by the method itself,
9 # e.g. FastBDT and TMVA BDTs support calculating the variable importance using the information gain of each applied cut.
10 # 2. One can estimate the importance in a method-agnostic way by training N times (where N is the number of variables),
11 # each time with a different variable removed from the training. The loss in ROC AUC score is used to estimate the importance.
12 # This will underestimate the importance of variables whose information is highly correlated to other variables in the training.
13 # 3. One can estimate the importance in a method-agnostic way by training N*N / 2 times (where N is the number of variables).
14 # The first step is approach 2; afterwards the most important variable given by approach 2 is removed and approach 2 is run
15 # again on the remaining variables.
16 # This takes the correlations between variables into account, but takes some time.
17 #
18 # Approaches 2 and 3 can be run in parallel (using Python's multiprocessing module, see below)
19 
20 import basf2_mva
21 import basf2_mva_util
22 import multiprocessing
23 import copy
24 
25 
if __name__ == "__main__":
    # Script entry point: train a FastBDT on train.root, then rank the input
    # variables by importance using the three approaches described in the
    # header comment. All ROC AUC scores are evaluated on test.root.
    training_data = basf2_mva.vector("train.root")
    test_data = basf2_mva.vector("test.root")

    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvariantMass(0, 1)', 'daughterInvariantMass(0, 2)', 'daughterInvariantMass(1, 2)']

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    def roc_for_variable_set(variables):
        """
        Retrain the method on the given variables and return its ROC AUC score.

        The already trained method is loaded from general_options.m_identifier,
        retrained on the training data with a copy of the general options whose
        variable list is replaced by `variables`, and applied to the test data.

        Parameters:
            variables: list of variable names to use for the retraining.

        Returns:
            The value of basf2_mva_util.calculate_roc_auc for the retrained
            method on the test data.
        """
        method = basf2_mva_util.Method(general_options.m_identifier)
        # Copy the options so that the shared general_options keep the full variable list.
        options = copy.copy(general_options)
        options.m_variables = basf2_mva.vector(*variables)
        m = method.train_teacher(training_data, general_options.m_treename, general_options=options)
        # presumably p holds the classifier response and t the target values - confirm in basf2_mva_util
        p, t = m.apply_expert(test_data, general_options.m_treename)
        # Without this return the function yields None and neither the sorting
        # by AUC nor the AUC differences below can be computed.
        return basf2_mva_util.calculate_roc_auc(p, t)

    # Reference score: ROC AUC of the method trained on all variables
    method = basf2_mva_util.Method(general_options.m_identifier)
    p, t = method.apply_expert(test_data, general_options.m_treename)
    global_auc = basf2_mva_util.calculate_roc_auc(p, t)

    # Approach 1: Read out the importance calculated by the method itself
    print("Variable importances returned my method")
    for variable in method.variables:
        print(variable, method.importances.get(variable, 0.0))

    # NOTE(review): roc_for_variable_set only exists under the __main__ guard, so the
    # workers presumably rely on the "fork" start method to see it - confirm before
    # running on a platform that uses "spawn".
    # The context manager shuts the worker processes down once all results are collected.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        # Approach 2: Calculate the importance using the loss in AUC if a variable is removed
        results = pool.map(roc_for_variable_set,
                           [[v for v in method.variables if v != variable] for variable in method.variables])
        # Ascending AUC: the variable whose removal leaves the lowest AUC comes first.
        sorted_variables_with_results = sorted(zip(method.variables, results), key=lambda x: x[1])
        print("Variable importances calculated using loss if variable is removed")
        for variable, auc in sorted_variables_with_results:
            print(variable, global_auc - auc)

        # Approach 3: Calculate the importance using the loss in AUC if a variable is removed recursively.
        removed_variables_with_results = sorted_variables_with_results[:1]
        remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
        while len(remaining_variables) > 1:
            results = pool.map(roc_for_variable_set,
                               [[v for v in remaining_variables if v != variable] for variable in remaining_variables])
            sorted_variables_with_results = sorted(zip(remaining_variables, results), key=lambda x: x[1])
            # Permanently drop the variable whose removal costs the most AUC in this round.
            removed_variables_with_results += sorted_variables_with_results[:1]
            remaining_variables = [v for v, r in sorted_variables_with_results[1:]]
        # Whatever is left after the last round is appended with its score from that round.
        removed_variables_with_results += sorted_variables_with_results[1:]

    print("Variable importances calculated using loss if variables are recursively removed")
    # Each variable is charged with the AUC drop relative to the previous removal step.
    last_auc = global_auc
    for variable, auc in removed_variables_with_results:
        print(variable, last_auc - auc)
        last_auc = auc
basf2_mva_util.calculate_roc_auc
def calculate_roc_auc(p, t)
Definition: basf2_mva_util.py:39
basf2_mva_util.Method
Definition: basf2_mva_util.py:81