Belle II Software development
grid_search.py
1#!/usr/bin/env python3
2
3
10
11import basf2
12import basf2_mva
13import basf2_mva_util
14import multiprocessing
15import itertools
16
if __name__ == "__main__":
    # Example: train a FastBDT with default settings, then scan a small grid
    # of (nTrees, nLevels) hyperparameters in parallel worker processes and
    # print the AUC obtained on the test sample for every grid point.

    train_file = basf2.find_file("mva/train_D0toKpipi.root", "examples")
    test_file = basf2.find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    test_data = basf2_mva.vector(test_file)

    # Train model with default parameters
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_identifier = "test.xml"
    general_options.m_variables = basf2_mva.vector('p', 'pz', 'daughter(0, kaonID)', 'chiProb', 'M')
    general_options.m_target_variable = "isSignal"

    fastbdt_options = basf2_mva.FastBDTOptions()
    basf2_mva.teacher(general_options, fastbdt_options)

    # Load the model and train it again searching for the best hyperparameters
    def grid_search(hyperparameters):
        """
        Retrain the baseline model for one grid point and evaluate it.

        @param hyperparameters pair ``(nTrees, depth)`` that is written to
               ``m_nTrees`` and ``m_nLevels`` of a fresh FastBDTOptions object
        @return tuple ``(hyperparameters, auc)`` so the caller can match every
                score to the grid point that produced it
        """
        nTrees, depth = hyperparameters
        method = basf2_mva_util.Method(general_options.m_identifier)
        options = basf2_mva.FastBDTOptions()
        options.m_nTrees = nTrees
        options.m_nLevels = depth
        m = method.train_teacher(training_data, general_options.m_treename, specific_options=options)
        # presumably p holds the classifier response and t the true target of
        # the test sample -- confirm against basf2_mva_util.Method.apply_expert
        p, t = m.apply_expert(test_data, general_options.m_treename)
        # The result must be returned together with its grid point: the loop
        # below unpacks every entry of the map() result as (hyperparameters, auc).
        return hyperparameters, basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)

    # NOTE(review): grid_search is defined under the __main__ guard, so worker
    # processes only know it when they are forked from this process; with the
    # "spawn" start method the workers would not find it -- confirm the
    # platforms this example has to support.
    # maxtasksperchild=1 gives every grid point a fresh worker process, and the
    # context manager shuts the pool down once the (blocking) map() is done.
    with multiprocessing.Pool(None, maxtasksperchild=1) as pool:
        results = pool.map(grid_search, itertools.product([10, 50, 100, 500, 1000], [2, 4, 6]))
    for hyperparameters, auc in results:
        print("Hyperparameters", hyperparameters, "AUC", auc)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)