Belle II Software development
simple_deep.py
1#!/usr/bin/env python3
2
3
10
11# This example shows the implementation of a simple MLP in keras.
12
13import basf2_mva
14import basf2_mva_util
15import time
16
18
19
20from keras.layers import Input, Dense, Dropout, BatchNormalization
21from keras.models import Model
22from keras.optimizers import Adam
23from keras.losses import binary_crossentropy
24from keras.activations import sigmoid, tanh
25from keras.callbacks import Callback
26
27
28old_time = time.time()
29
30
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Construct and compile the feed-forward keras network, wrapped in a State.

    Layout: eight dense tanh layers, a batch normalization, seven more dense
    tanh layers, a dropout with rate 0.4 and a single sigmoid output unit.
    Every hidden layer has number_of_features units.

    Only number_of_features is used here; number_of_spectators,
    number_of_events, training_fraction and parameters are accepted but
    ignored.

    Returns the State object holding the compiled model.
    """
    def stack_dense_layers(tensor, depth):
        # Chain `depth` dense tanh layers, each as wide as the input.
        for _ in range(depth):
            tensor = Dense(units=number_of_features, activation=tanh)(tensor)
        return tensor

    features = Input(shape=(number_of_features,))

    # First hidden layer plus seven further ones in front of the normalization.
    hidden = stack_dense_layers(features, 8)
    hidden = BatchNormalization()(hidden)
    hidden = stack_dense_layers(hidden, 7)
    hidden = Dropout(rate=0.4)(hidden)

    prediction = Dense(units=1, activation=sigmoid)(hidden)

    state = State(Model(features, prediction))

    network = state.model
    network.compile(
        optimizer=Adam(learning_rate=0.01),
        loss=binary_crossentropy,
        metrics=['accuracy'],
    )
    # Print the layer overview of the compiled network.
    network.summary()

    return state
54
55
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Attach the test features and test targets to the state and return it.

    Xtest and ytest are stored as state.Xtest and state.ytest so that they
    can be read again later from the state object. Stest, wtest and nBatches
    are accepted but not used.
    """
    state.Xtest, state.ytest = Xtest, ytest
    return state
64
65
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the keras model on the received data and report metrics every epoch.

    The model in state.model is trained on (X, y) for 10 epochs with a batch
    size of 500. After each epoch, loss and accuracy are printed for the test
    data stored on the state and for the first 10000 training events.
    S, w, epoch and batch are accepted but not used.

    Always returns False.
    """
    class EpochReport(Callback):
        """Keras callback printing test and training metrics after each epoch."""

        def on_epoch_end(self, epoch, logs=None):
            test_loss, test_acc = state.model.evaluate(
                state.Xtest, state.ytest, verbose=0, batch_size=1000)
            # Only a slice of the training data is evaluated.
            train_loss, train_acc = state.model.evaluate(
                X[:10000], y[:10000], verbose=0, batch_size=1000)
            print(f'\nTesting loss: {test_loss}, acc: {test_acc}')
            print(f'Training loss: {train_loss}, acc: {train_acc}')

    reporter = EpochReport()
    state.model.fit(X, y, batch_size=500, epochs=10, callbacks=[reporter])
    return False
80
81
if __name__ == "__main__":
    # Script entry point: train the keras method defined in this file through
    # basf2_mva.teacher, apply it to the test file, and print the training
    # time, the inference time and the AUC.
    from basf2 import conditions, find_file
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]
    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    # Options shared by all MVA methods: input files, tree, features, target.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_identifier = "deep_keras"
    general_options.m_treename = "tree"
    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvM(0, 1)', 'daughterInvM(0, 2)', 'daughterInvM(1, 2)']
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    # Options of the python-based method; the steering file is this example.
    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/simple_deep.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    # Wall-clock time of the training.
    training_start = time.time()
    basf2_mva.teacher(general_options, specific_options)
    training_stop = time.time()
    training_time = training_stop - training_start

    # Wall-clock time of applying the trained method to the test file.
    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_stop = time.time()
    inference_time = inference_stop - inference_start

    # `auc` was printed without ever being assigned (NameError). Compute it
    # from the expert output first.
    # NOTE(review): p and t are presumably the predictions and the true
    # targets returned by apply_expert - confirm against basf2_mva_util.
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Tensorflow.keras", training_time, inference_time, auc)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)