Belle II Software development
bayesian_optimization.py
1#!/usr/bin/env python3
2
3
10
# A simple example of using Bayesian optimization to tune the hyperparameters of a FastBDT.
# The package used in this example is scikit-optimize (https://github.com/scikit-optimize/scikit-optimize)
# and can be installed with
# pip3 install scikit-optimize
15
16from basf2 import find_file
17import basf2_mva
18import basf2_mva_util
19import skopt
20import matplotlib.pyplot as plt
21
22
def f(x):
    """Return the figure of merit for the optimization (to be minimized).

    This function trains a FastBDT classifier with the given hyperparameters
    on the training sample and calculates the AUC on the independent test
    sample.

    :param x: point proposed by ``skopt.gp_minimize``; ``x[0]`` is used as the
        number of trees and ``x[1]`` as the number of tree levels.
    :return: the negative AUC, so that minimizing this value maximizes the AUC.

    Reads the module-level ``general_options`` and ``test_data`` objects that
    are set up in the ``__main__`` block of this script.
    """
    # NOTE: this is an alias, not a copy -- setting m_identifier below also
    # modifies the module-level general_options object.
    g_options = general_options
    g_options.m_identifier = "test.xml"
    options = basf2_mva.FastBDTOptions()
    options.m_nTrees = int(x[0])
    options.m_nLevels = int(x[1])
    basf2_mva.teacher(g_options, options)
    m = basf2_mva_util.Method(g_options.m_identifier)
    p, t = m.apply_expert(test_data, general_options.m_treename)
    # gp_minimize minimizes its objective, so negate the AUC to maximize it.
    return -basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
38
if __name__ == "__main__":
    # ``import skopt`` is not guaranteed to load the plotting submodule
    # (it depends on matplotlib), so import it explicitly before using
    # skopt.plots below.
    import skopt.plots

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    # training_data/test_data and general_options are module globals that
    # the objective function f() reads.
    training_data = basf2_mva.vector(train_file)
    test_data = basf2_mva.vector(test_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector('p', 'pz', 'daughter(0, kaonID)', 'chiProb', 'M')
    general_options.m_target_variable = "isSignal"

    # Start optimization
    res = skopt.gp_minimize(f,  # the function to minimize
                            [(10, 1000), (2, 6)],  # the bounds on each dimension of x (nTrees, nLevels)
                            x0=[10, 2],  # initial guess
                            n_calls=20)  # number of evaluations of f

    # Give some results
    print(res)
    # Save each plot from the current figure, then close it so the next
    # plot does not accumulate on (or keep alive) the previous figure.
    skopt.plots.plot_convergence(res)
    plt.savefig('convergence.png')
    plt.close('all')
    skopt.plots.plot_evaluations(res)
    plt.savefig('evaluations.png')
    plt.close('all')
    skopt.plots.plot_objective(res)
    plt.savefig('objective.png')
    plt.close('all')

    # Store result of optimization
    skopt.dump(res, 'opt-result.pkl')
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)