15from tensorflow.keras.models
import load_model
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Just load keras model into state.

    No model is built here: an already trained and saved keras model is
    loaded from disk and wrapped in a ``State``.

    Parameters:
        number_of_features, number_of_spectators, number_of_events,
        training_fraction: not used by this function.
        parameters: mapping that must contain a ``'file_path'`` entry pointing
            at the saved keras model (the example below supplies it through
            ``m_config``).

    Returns:
        ``State`` wrapping the model returned by ``load_model``.

    Raises:
        KeyError: if ``parameters`` has no ``'file_path'`` entry.
    """
    return State(load_model(parameters['file_path']))
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    This should be empty, because our model is already fitted.

    All arguments are accepted to satisfy the training-hook signature and are
    ignored; no training step is performed.

    Returns:
        False.
    """
    # NOTE(review): the caller presumably treats the return value as
    # "continue training?", so False ends training immediately — confirm
    # against the basf2 python MVA interface.
    return False
if __name__ == "__main__":
    # Example: train a small keras model, save it to disk, convert the saved
    # model into a basf2 MVA weightfile, and apply the resulting expert.
    import json
    import os
    import tempfile

    import numpy as np
    import pandas
    import uproot

    from tensorflow.keras.activations import sigmoid, tanh
    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.losses import binary_crossentropy
    from tensorflow.keras.models import Model
    from tensorflow.keras.optimizers import Adam

    from basf2 import conditions
    import basf2_mva
    # NOTE(review): assumed to provide Method (used at the end) — confirm.
    import basf2_mva_util

    # NOTE(review): selects a local testing payload database; presumably only
    # suitable for examples — confirm before reusing elsewhere.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Toy dataset: columns 0-9 are standard-normal features; the last column is
    # turned into a 0/1 target (1 where the drawn value exceeds 0.5).
    variables = ['x' + str(i) for i in range(10)]

    data = np.random.normal(size=[1000, 11])
    data[:, -1] = np.int32(data[:, -1] > 0.5)

    # Dense network 10 -> 100 (tanh) -> 1 (sigmoid). Named `inputs` so the
    # builtin input() is not shadowed.
    inputs = Input(shape=(10,))
    net = Dense(units=100, activation=tanh)(inputs)
    output = Dense(units=1, activation=sigmoid)(net)

    model = Model(inputs, output)
    # `learning_rate` is the supported keyword; the `lr` alias is deprecated
    # and rejected by newer Keras optimizers.
    model.compile(optimizer=Adam(learning_rate=0.01), loss=binary_crossentropy, metrics=['accuracy'])

    model.fit(data[:, :-1], data[:, -1], batch_size=10, epochs=10)

    with tempfile.TemporaryDirectory() as path:
        data_file = os.path.join(path, 'data.root')
        model_path = os.path.join(path, 'example_existing_model')

        # One float64 column per feature, plus the target under 'isSignal'.
        dic = {name: data[:, i] for i, name in enumerate(variables)}
        dic['isSignal'] = data[:, -1]
        df = pandas.DataFrame(dic, dtype=np.float64)

        with uproot.recreate(data_file) as outfile:
            # Stored under the tree name passed as m_treename below.
            # NOTE(review): assumes uproot >= 4, where assigning a DataFrame
            # writes a TTree — confirm against the installed version.
            outfile['tree'] = df

        model.save(model_path)

        # Teacher configuration: input file, tree, feature columns and target.
        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(data_file)
        general_options.m_treename = "tree"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"
        general_options.m_identifier = 'converted_keras'

        # Keras python backend. Presumably the teacher loads this steering file
        # and passes m_config to get_model as `parameters` — confirm.
        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "keras"
        specific_options.m_steering_file = 'mva/examples/keras/import_existing_keras_model.py'
        specific_options.m_config = json.dumps({'file_path': model_path})
        basf2_mva.teacher(general_options, specific_options)

        # NOTE(review): assumes basf2_mva_util.Method(identifier) wraps the
        # weightfile written by the teacher — confirm.
        method = basf2_mva_util.Method(general_options.m_identifier)
        p, t = method.apply_expert(general_options.m_datafiles, general_options.m_treename)
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)