Belle II Software  release-06-01-15
simple_deep.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # This example shows the implementation of a simple deep MLP (fully connected layers with batch normalization and dropout) in Keras.
12 
13 import basf2_mva
14 import basf2_mva_util
15 import time
16 
18 
19 
20 from keras.layers import Input, Dense, Dropout
21 from keras.layers.normalization import BatchNormalization
22 from keras.models import Model
23 from keras.optimizers import Adam
24 from keras.losses import binary_crossentropy
25 from keras.activations import sigmoid, tanh
26 from keras.callbacks import Callback
27 
28 
# Wall-clock time (seconds since the epoch) recorded when this module is
# first imported.
old_time = time.time()
30 
31 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the feed forward keras model and wrap it in a State object.

    Architecture (every hidden layer has ``number_of_features`` tanh units):
      * one Dense layer applied to the input,
      * 7 Dense layers, each followed by BatchNormalization,
      * 7 Dense layers, each followed by Dropout(rate=0.4),
      * a single sigmoid output unit.

    The model is compiled with the Adam optimizer (learning rate 0.01),
    binary cross-entropy loss and the 'accuracy' metric, and its summary is
    printed to stdout.

    Parameters:
        number_of_features: number of input features; also used as the width
            of every hidden layer.
        number_of_spectators, number_of_events, training_fraction, parameters:
            accepted for interface compatibility, not used by this model.

    Returns:
        State wrapping the compiled keras Model.
    """
    # Called 'inputs' rather than 'input' so the builtin is not shadowed.
    inputs = Input(shape=(number_of_features,))

    net = Dense(units=number_of_features, activation=tanh)(inputs)
    for _ in range(7):
        net = Dense(units=number_of_features, activation=tanh)(net)
        net = BatchNormalization()(net)
    for _ in range(7):
        net = Dense(units=number_of_features, activation=tanh)(net)
        net = Dropout(rate=0.4)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    state = State(Model(inputs, output))

    # 'learning_rate' is the documented keyword; 'lr' is a deprecated alias
    # that recent Keras releases warn about or reject outright.
    state.model.compile(optimizer=Adam(learning_rate=0.01), loss=binary_crossentropy, metrics=['accuracy'])

    state.model.summary()

    return state
55 
56 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Store the validation sample on the state and return the same state.

    Only the validation features (Xtest) and targets (ytest) are kept, as
    ``state.Xtest`` and ``state.ytest``; the spectators (Stest) and weights
    (wtest) are ignored.
    """
    # Tuple assignment binds left to right: state.Xtest first, then state.ytest.
    state.Xtest, state.ytest = Xtest, ytest
    return state
65 
66 
def partial_fit(state, X, S, y, w, epoch):
    """
    Train the keras model held by ``state`` on the received data.

    Runs ``fit`` for 10 epochs with batch size 500. After every epoch a
    callback evaluates the model (batch size 1000, silent) on the validation
    sample stored in ``state.Xtest``/``state.ytest`` and on the first 10000
    training events, then prints both loss/accuracy pairs.

    S, w and epoch are not used.

    Always returns False (presumably signalling the framework that no further
    partial_fit calls are wanted -- confirm against the basf2 MVA python
    interface).
    """
    class TestCallback(Callback):
        """Print validation and training loss/accuracy after every epoch."""

        def on_epoch_end(self, epoch, logs=None):
            # Both evaluations run before anything is printed, as before.
            test_loss, test_acc = state.model.evaluate(
                state.Xtest, state.ytest, verbose=0, batch_size=1000)
            train_loss, train_acc = state.model.evaluate(
                X[:10000], y[:10000], verbose=0, batch_size=1000)
            print('\nTesting loss: {}, acc: {}'.format(test_loss, test_acc))
            print('Training loss: {}, acc: {}'.format(train_loss, train_acc))

    epoch_monitor = TestCallback()
    state.model.fit(X, y, epochs=10, batch_size=500, callbacks=[epoch_monitor])
    return False
81 
82 
if __name__ == "__main__":
    # Steering: train the keras model defined above on train.root, then apply
    # the resulting expert to test.root and print training time, inference
    # time and the ROC AUC.
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector("train.root")
    general_options.m_identifier = "deep_keras"
    general_options.m_treename = "tree"
    variables = ['M', 'p', 'pt', 'pz',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)',
                 'chiProb', 'dr', 'dz',
                 'daughter(0, dr)', 'daughter(1, dr)',
                 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)',
                 'daughterInvariantMass(0, 1)', 'daughterInvariantMass(0, 2)', 'daughterInvariantMass(1, 2)']
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_target_variable = "isSignal"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "contrib_keras"
    specific_options.m_steering_file = 'mva/examples/keras/simple_deep.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    training_start = time.time()
    basf2_mva.teacher(general_options, specific_options)
    training_stop = time.time()
    training_time = training_stop - training_start
    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    # test.root is listed 10 times, presumably to make the inference timing
    # less noisy -- TODO confirm intent.
    test_data = ["test.root"] * 10
    p, t = method.apply_expert(basf2_mva.vector(*test_data), general_options.m_treename)
    inference_stop = time.time()
    inference_time = inference_stop - inference_start
    # 'auc' was printed below but never assigned, which raised a NameError at
    # the very end of the run. p/t are presumably the expert outputs and the
    # true targets -- confirm against basf2_mva_util.Method.apply_expert.
    auc = basf2_mva_util.calculate_roc_auc(p, t)
    print("Tensorflow", training_time, inference_time, auc)
def calculate_roc_auc(p, t)