Belle II Software  release-08-01-10
grid_search.py
1 #!/usr/bin/env python3
2 
3 
10 
11 import basf2
12 import basf2_mva
13 import basf2_mva_util
14 import multiprocessing
15 import itertools
16 
if __name__ == "__main__":

    train_file = basf2.find_file("mva/train_D0toKpipi.root", "examples")
    test_file = basf2.find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    test_data = basf2_mva.vector(test_file)

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector('p', 'pz', 'daughter(0, kaonID)', 'chiProb', 'M')
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    # Load the model and train it again searching for the best hyperparameters
    def grid_search(hyperparameters):
        """Retrain the FastBDT for one hyperparameter point and score it on the test sample.

        The model stored under ``general_options.m_identifier`` is loaded,
        retrained on ``training_data`` with the given number of trees and
        tree depth, then applied to ``test_data``.

        Parameters
        ----------
        hyperparameters : tuple
            ``(nTrees, depth)`` pair, mapped to ``FastBDTOptions.m_nTrees``
            and ``FastBDTOptions.m_nLevels`` respectively.

        Returns
        -------
        tuple
            ``(hyperparameters, auc)``: the input pair together with the value of
            ``basf2_mva_util.calculate_auc_efficiency_vs_background_retention``
            evaluated on the expert output ``p`` and truth ``t``.
        """
        nTrees, depth = hyperparameters
        method = basf2_mva_util.Method(general_options.m_identifier)
        options = basf2_mva.FastBDTOptions()
        options.m_nTrees = nTrees
        options.m_nLevels = depth
        m = method.train_teacher(training_data, general_options.m_treename, specific_options=options)
        p, t = m.apply_expert(test_data, general_options.m_treename)
        # The result loop below unpacks (hyperparameters, auc) pairs, so the
        # point must be returned alongside its score; without this return
        # every worker yields None and the unpacking raises TypeError.
        return hyperparameters, basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # NOTE(review): grid_search is defined inside the __main__ guard and reads
    # general_options / training_data / test_data as globals, so this relies on
    # the "fork" start method (the Linux default); under "spawn" the workers
    # would not find grid_search.
    # maxtasksperchild=1: each worker process is replaced after completing one task.
    # The context manager ensures the pool's worker processes are cleaned up.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        results = pool.map(grid_search, itertools.product([10, 50, 100, 500, 1000], [2, 4, 6]))
    for hyperparameters, auc in results:
        print("Hyperparameters", hyperparameters, "AUC", auc)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)