Belle II Software  release-05-01-25
grid_search.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Thomas Keck 2017
5 
6 import basf2_mva
7 import basf2_mva_util
8 import multiprocessing
9 import itertools
10 
# Settings shared by the main script and the worker processes.  They are
# plain, picklable module-level values so that workers started with any
# multiprocessing start method (fork, spawn, forkserver) can see them.
IDENTIFIER = "test.xml"
TREENAME = "tree"
TRAINING_FILES = ("train.root",)
TEST_FILES = ("test.root",)

# Hyperparameter grid: every (nTrees, depth) combination is evaluated.
N_TREES_GRID = [10, 50, 100, 500, 1000]
DEPTH_GRID = [2, 4, 6]


def grid_search(hyperparameters, identifier=IDENTIFIER, treename=TREENAME,
                training_files=TRAINING_FILES, test_files=TEST_FILES):
    """Retrain the stored model with one hyperparameter pair and score it.

    The function lives at module level (not inside the ``__main__`` guard)
    so that ``multiprocessing`` can pickle it by reference and worker
    processes can import it under every start method, not only ``fork``.

    :param hyperparameters: ``(nTrees, depth)`` pair, written to
        ``FastBDTOptions.m_nTrees`` and ``FastBDTOptions.m_nLevels``.
    :param identifier: weightfile of the previously trained model, loaded
        with ``basf2_mva_util.Method``.
    :param treename: name of the tree in the training and test files.
    :param training_files: files the model is retrained on.
    :param test_files: files the retrained model is applied to.
    :return: ``(hyperparameters, auc)`` where *auc* is the value returned by
        ``basf2_mva_util.calculate_roc_auc`` on the test files.
    """
    n_trees, depth = hyperparameters
    method = basf2_mva_util.Method(identifier)
    options = basf2_mva.FastBDTOptions()
    options.m_nTrees = n_trees
    options.m_nLevels = depth
    # The basf2_mva vectors are built here, inside the worker, so that only
    # plain strings ever have to cross the process boundary.
    trained = method.train_teacher(basf2_mva.vector(*training_files), treename,
                                   specific_options=options)
    # NOTE(review): p / t are presumably the classifier outputs and the truth
    # values expected by calculate_roc_auc(p, t) — confirm in basf2_mva_util.
    p, t = trained.apply_expert(basf2_mva.vector(*test_files), treename)
    return hyperparameters, basf2_mva_util.calculate_roc_auc(p, t)


if __name__ == "__main__":
    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector(*TRAINING_FILES)
    general_options.m_treename = TREENAME
    general_options.m_identifier = IDENTIFIER
    general_options.m_variables = basf2_mva.vector(
        'p', 'pz', 'daughter(0, kaonID)', 'chiProb', 'M')
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    # Load the model and train it again searching for the best hyperparameters.
    # Pool(None) starts os.cpu_count() workers; maxtasksperchild=1 replaces a
    # worker process after every single training job.  The with-block tears
    # the pool down once map() has returned all results.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        results = pool.map(grid_search,
                           itertools.product(N_TREES_GRID, DEPTH_GRID))

    for hyperparameters, auc in results:
        print("Hyperparameters", hyperparameters, "AUC", auc)
grid_search
Definition: grid_search.py:1
basf2_mva_util.calculate_roc_auc
def calculate_roc_auc(p, t)
Definition: basf2_mva_util.py:39
basf2_mva_util.Method
Definition: basf2_mva_util.py:81