16 from keras.models
import load_model
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Wrap an already trained Keras model, read back from disk, into a ``State``.

    Only ``parameters['file_path']`` is read. The other arguments are unused
    here; presumably they are required by the basf2 MVA python hook
    signature (confirm against the framework interface).

    :param parameters: configuration mapping whose ``'file_path'`` entry names
        a model file readable by ``keras.models.load_model``
    :return: ``State`` object holding the loaded Keras model
    """
    model_path = parameters['file_path']
    pretrained_model = load_model(model_path)
    return State(pretrained_model)
def partial_fit(state, X, S, y, w, epoch):
    """
    Do nothing: the wrapped Keras model is already fitted.

    All arguments are ignored; no training is performed on the given batch.

    :return: ``False``. Presumably the training loop treats a falsy return
        as "stop requesting further epochs"; confirm against the basf2 MVA
        python interface.
    """
    return False
33 if __name__ ==
"__main__":
36 from root_pandas
import to_root
44 from keras.layers
import Input, Dense
45 from keras.models
import Model
46 from keras.optimizers
import Adam
47 from keras.losses
import binary_crossentropy
48 from keras.activations
import sigmoid, tanh
49 from basf2
import conditions
51 conditions.testing_payloads = [
52 'localdb/database.txt'
59 variables = [
'x' + str(i)
for i
in range(10)]
61 data = np.random.normal(size=[1000, 11])
62 data[:, -1] = np.int32(data[:, -1] > 0.5)
65 input = Input(shape=(10,))
67 net = Dense(units=100, activation=tanh)(input)
68 output = Dense(units=1, activation=sigmoid)(net)
70 model = Model(input, output)
71 model.compile(optimizer=Adam(lr=0.01), loss=binary_crossentropy, metrics=[
'accuracy'])
73 model.fit(data[:, :-1], data[:, -1], batch_size=10, epochs=10)
76 with tempfile.TemporaryDirectory()
as path:
79 for i, name
in enumerate(variables):
80 dic.update({name: data[:, i]})
81 dic.update({
'isSignal': data[:, -1]})
83 df = pandas.DataFrame(dic, dtype=np.float32)
84 to_root(df, os.path.join(path,
'data.root'), key=
'tree')
87 model.save(os.path.join(path,
'weights.h5'))
91 general_options = basf2_mva.GeneralOptions()
92 general_options.m_datafiles = basf2_mva.vector(os.path.join(path,
'data.root'))
93 general_options.m_treename =
"tree"
94 general_options.m_variables = basf2_mva.vector(*variables)
95 general_options.m_target_variable =
"isSignal"
97 specific_options = basf2_mva.PythonOptions()
98 specific_options.m_framework =
"contrib_keras"
99 specific_options.m_steering_file =
'mva/examples/keras/keras_to_weightfile.py'
101 general_options.m_identifier =
'converted_keras'
102 specific_options.m_config = json.dumps({
'file_path': os.path.join(path,
'weights.h5')})
103 basf2_mva.teacher(general_options, specific_options)
107 p, t = method.apply_expert(general_options.m_datafiles, general_options.m_treename)
def calculate_roc_auc(p, t)