import tensorflow as tf
import tensorflow.contrib.keras as keras

from keras.models import load_model

from basf2_mva_python_interface.contrib_keras import State
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Just load the pre-trained keras model into a State object.

    No network is built or trained here. The model was trained beforehand and is
    only read back from disk. Because of that, the feature/spectator/event counts
    and the training fraction are accepted for interface compatibility but unused.

    :param number_of_features: unused
    :param number_of_spectators: unused
    :param number_of_events: unused
    :param training_fraction: unused
    :param parameters: mapping that must contain the key 'file_path', the path of a
        model file saved with keras ``Model.save`` (presumably the decoded
        ``m_config`` JSON of the python MVA options -- confirm with the caller)
    :return: State wrapping the loaded keras model
    """
    return State(load_model(parameters['file_path']))
def partial_fit(state, X, S, y, w, epoch):
    """
    This should be empty, because our model is already fitted.

    All arguments are ignored; the loaded keras model is left untouched.

    :param state: State returned by get_model (unused)
    :param X: feature batch (unused)
    :param S: spectator batch (unused)
    :param y: target batch (unused)
    :param w: weight batch (unused)
    :param epoch: current epoch index (unused)
    :return: True
    """
    # NOTE(review): basf2 python-interface hooks conventionally return True to
    # mean "continue the training loop"; nothing is trained here, so continuing
    # is a harmless no-op. Confirm against the installed basf2 version.
    return True
if __name__ == "__main__":
    # Script-only dependencies are imported here rather than at module level.
    # This file is also referenced as its own steering file below, and importing
    # it that way should only need get_model/partial_fit.
    import json
    import os
    import tempfile

    import numpy as np
    import pandas
    from root_pandas import to_root

    from keras.layers import Input, Dense
    from keras.models import Model
    from keras.optimizers import adam
    from keras.losses import binary_crossentropy
    from keras.activations import sigmoid, tanh

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions

    # Testing payloads from a local database -- for examples/tests only.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Random toy dataset: one column per feature plus a last column that is
    # turned into a 0/1 target label. The feature count drives the variable
    # names, the data shape and the network input, so they cannot drift apart.
    number_of_features = 10
    variables = ['x' + str(i) for i in range(number_of_features)]

    data = np.random.normal(size=[1000, number_of_features + 1])
    data[:, -1] = np.int32(data[:, -1] > 0.5)

    # Build and train a small dense network outside of basf2.
    inputs = Input(shape=(number_of_features,))
    net = Dense(units=100, activation=tanh)(inputs)
    output = Dense(units=1, activation=sigmoid)(net)

    model = Model(inputs, output)
    # NOTE: `adam(lr=...)` is the older Keras API spelling; newer Keras versions
    # use `Adam(learning_rate=...)`.
    model.compile(optimizer=adam(lr=0.01), loss=binary_crossentropy, metrics=['accuracy'])
    model.fit(data[:, :-1], data[:, -1], batch_size=10, epochs=10)

    # Everything written to disk lives in a temporary directory that is removed
    # when the block exits.
    with tempfile.TemporaryDirectory() as path:
        data_file = os.path.join(path, 'data.root')
        weight_file = os.path.join(path, 'weights.h5')

        # Store the toy data as a ROOT tree that basf2_mva can read.
        columns = {name: data[:, i] for i, name in enumerate(variables)}
        columns['isSignal'] = data[:, -1]
        df = pandas.DataFrame(columns, dtype=np.float32)
        to_root(df, data_file, tree_key='tree')

        # Save the trained keras model; get_model loads it back via 'file_path'.
        model.save(weight_file)

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(data_file)
        general_options.m_treename = "tree"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"
        general_options.m_identifier = 'converted_keras'

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "contrib_keras"
        specific_options.m_steering_file = 'mva/examples/keras/keras_to_weightfile.py'
        specific_options.m_config = json.dumps({'file_path': weight_file})

        # "Training" just wraps the pre-trained model into a weightfile, because
        # partial_fit in this steering file does nothing.
        basf2_mva.teacher(general_options, specific_options)

        # Apply the converted weightfile as a smoke check of the conversion.
        method = basf2_mva_util.Method(general_options.m_identifier)
        p, t = method.apply_expert(general_options.m_datafiles, general_options.m_treename)