# Belle II Software  release-06-01-15
# grid_search.py
#!/usr/bin/env python3

"""Grid search over FastBDT hyperparameters using basf2_mva.

Trains a baseline classifier, then retrains it for every (nTrees, depth)
combination in parallel and prints the test-set ROC AUC of each variant.
"""

import itertools
import multiprocessing

import basf2_mva
import basf2_mva_util

if __name__ == "__main__":
    training_data = basf2_mva.vector("train.root")
    test_data = basf2_mva.vector("test.root")

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector('p', 'pz', 'daughter(0, kaonID)', 'chiProb', 'M')
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    # Load the model and train it again searching for the best hyperparameters
    def grid_search(hyperparameters):
        """Retrain the stored method with the given hyperparameters.

        :param hyperparameters: tuple (nTrees, depth) for the FastBDT
        :return: tuple (hyperparameters, ROC AUC on the test sample)
        """
        nTrees, depth = hyperparameters
        method = basf2_mva_util.Method(general_options.m_identifier)
        options = basf2_mva.FastBDTOptions()
        options.m_nTrees = nTrees
        options.m_nLevels = depth
        m = method.train_teacher(training_data, general_options.m_treename, specific_options=options)
        # Renamed from p/t: probabilities and true targets on the test sample.
        probabilities, targets = m.apply_expert(test_data, general_options.m_treename)
        return hyperparameters, basf2_mva_util.calculate_roc_auc(probabilities, targets)

    # maxtasksperchild=1 gives every training a fresh worker process so that no
    # state leaks between successive basf2_mva trainings.  The context manager
    # guarantees the pool is terminated and its workers reaped afterwards
    # (the original code leaked the pool).
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        results = pool.map(grid_search, itertools.product([10, 50, 100, 500, 1000], [2, 4, 6]))
    for hyperparameters, auc in results:
        print("Hyperparameters", hyperparameters, "AUC", auc)
def calculate_roc_auc(p, t)