20from keras.layers 
import Input, Dense, Dropout, BatchNormalization
 
   21from keras.models 
import Model
 
   22from keras.optimizers 
import Adam
 
   23from keras.losses 
import binary_crossentropy
 
   24from keras.activations 
import sigmoid, tanh
 
   25from keras.callbacks 
import Callback
 
   def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
       """
       Build feed forward keras model.

       The network is a tanh input layer, a stack of Dense+BatchNormalization
       blocks and a stack of Dense+Dropout(0.4) blocks, all ``number_of_features``
       wide. It ends in a single sigmoid output unit and is compiled with Adam
       (learning rate 0.01) and binary cross-entropy.

       :param number_of_features: number of input features; also used as hidden-layer width
       :param number_of_spectators: unused here
       :param number_of_events: unused here
       :param training_fraction: unused here
       :param parameters: unused here
       :return: a ``State`` wrapping the compiled keras ``Model``
       """
       # NOTE(review): the listing this was recovered from lost the two loop headers
       # (the layer lines were indented one level deeper than the function body).
       # 7 blocks per stack matches the upstream basf2 example; confirm.
       n_hidden_blocks = 7

       inputs = Input(shape=(number_of_features,))

       net = Dense(units=number_of_features, activation=tanh)(inputs)
       for _ in range(n_hidden_blocks):
           net = Dense(units=number_of_features, activation=tanh)(net)
           net = BatchNormalization()(net)
       for _ in range(n_hidden_blocks):
           net = Dense(units=number_of_features, activation=tanh)(net)
           net = Dropout(rate=0.4)(net)

       output = Dense(units=1, activation=sigmoid)(net)

       # NOTE(review): State is not imported in the visible part of this file;
       # presumably it comes from the basf2 python interface (confirm).
       state = State(Model(inputs, output))
       state.model.compile(optimizer=Adam(learning_rate=0.01), loss=binary_crossentropy, metrics=['accuracy'])

       # The caller needs the compiled model back; without this it received None.
       return state
 
   def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
       """
       Returns just the state object, with the test sample attached to it.

       ``partial_fit`` evaluates the model on ``state.Xtest`` / ``state.ytest``
       after every epoch, so both are stored on ``state`` here.

       :param state: state object returned by ``get_model``
       :param Xtest: test features, stored as ``state.Xtest``
       :param Stest: unused here
       :param ytest: test targets, stored as ``state.ytest``
       :param wtest: unused here
       :param nBatches: unused here
       :return: the same ``state`` object that was passed in
       """
       state.Xtest = Xtest
       state.ytest = ytest
       return state
   def partial_fit(state, X, S, y, w, epoch, batch):
       """
       Pass received data to tensorflow.

       Trains ``state.model`` on the received ``X``/``y`` for 10 epochs with
       batch size 500. After every epoch it prints loss and accuracy on the
       stored test sample and on at most the first 10000 training rows.

       :param state: state object carrying ``model``, ``Xtest`` and ``ytest``
       :param X: training features
       :param S: unused here
       :param y: training targets
       :param w: unused here
       :param epoch: unused here
       :param batch: unused here
       :return: False; all training happens in this single call
       """
       class TestCallback(Callback):
           """Print test and training metrics at the end of each epoch."""

           def on_epoch_end(self, epoch, logs=None):
               loss, acc = state.model.evaluate(state.Xtest, state.ytest, verbose=0, batch_size=1000)
               loss2, acc2 = state.model.evaluate(X[:10000], y[:10000], verbose=0, batch_size=1000)
               print(f'\nTesting loss: {loss}, acc: {acc}')
               print(f'Training loss: {loss2}, acc: {acc2}')

       state.model.fit(X, y, batch_size=500, epochs=10, callbacks=[TestCallback()])
       # Explicit False instead of an implicit None.
       # NOTE(review): assumes the caller reads the return value as "keep feeding
       # batches?"; confirm against the basf2 python interface.
       return False
 
   82if __name__ == 
"__main__":
 
   83    from basf2 
import conditions, find_file
 
   85    conditions.testing_payloads = [
 
   86        'localdb/database.txt' 
   88    train_file = find_file(
"mva/train_D0toKpipi.root", 
"examples")
 
   89    test_file = find_file(
"mva/test_D0toKpipi.root", 
"examples")
 
   91    training_data = basf2_mva.vector(train_file)
 
   92    testing_data = basf2_mva.vector(test_file)
 
   94    general_options = basf2_mva.GeneralOptions()
 
   95    general_options.m_datafiles = training_data
 
   96    general_options.m_identifier = 
"deep_keras" 
   97    general_options.m_treename = 
"tree" 
   98    variables = [
'M', 
'p', 
'pt', 
'pz',
 
   99                 'daughter(0, p)', 
'daughter(0, pz)', 
'daughter(0, pt)',
 
  100                 'daughter(1, p)', 
'daughter(1, pz)', 
'daughter(1, pt)',
 
  101                 'daughter(2, p)', 
'daughter(2, pz)', 
'daughter(2, pt)',
 
  102                 'chiProb', 
'dr', 
'dz',
 
  103                 'daughter(0, dr)', 
'daughter(1, dr)',
 
  104                 'daughter(0, dz)', 
'daughter(1, dz)',
 
  105                 'daughter(0, chiProb)', 
'daughter(1, chiProb)', 
'daughter(2, chiProb)',
 
  106                 'daughter(0, kaonID)', 
'daughter(0, pionID)',
 
  107                 'daughterInvM(0, 1)', 
'daughterInvM(0, 2)', 
'daughterInvM(1, 2)']
 
  108    general_options.m_variables = basf2_mva.vector(*variables)
 
  109    general_options.m_target_variable = 
"isSignal" 
  111    specific_options = basf2_mva.PythonOptions()
 
  112    specific_options.m_framework = 
"keras" 
  113    specific_options.m_steering_file = 
'mva/examples/keras/simple_deep.py' 
  114    specific_options.m_normalize = 
True 
  115    specific_options.m_training_fraction = 0.9
 
  117    training_start = time.time()
 
  118    basf2_mva.teacher(general_options, specific_options)
 
  119    training_stop = time.time()
 
  120    training_time = training_stop - training_start
 
  123    inference_start = time.time()
 
  124    p, t = method.apply_expert(testing_data, general_options.m_treename)
 
  125    inference_stop = time.time()
 
  126    inference_time = inference_stop - inference_start
 
  128    print(
"Tensorflow.keras", training_time, inference_time, auc)
 
calculate_auc_efficiency_vs_background_retention(p, t, w=None)