13from keras.layers
import Dense, Input
14from keras.models
import Model, load_model
15from keras.losses
import binary_crossentropy
16import tensorflow
as tf
17from basf2
import B2WARNING
26 """ Constructor of the state object """
34 for key, value
in kwargs.items():
36 setattr(self, key, value)
def feature_importance(state):
    """
    Return a list containing the feature importances

    This default implementation computes no importances and returns an
    empty list; ``state`` is accepted for interface compatibility only.
    """
    return []
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return dummy default keras model

    Builds a single-unit Dense network on an Input of width
    ``number_of_features``, wraps it in a ``State`` and compiles it with the
    adam optimizer, binary cross-entropy loss and an accuracy metric.
    The remaining arguments are accepted for interface compatibility but are
    not used here.
    """
    # named ``inputs`` rather than ``input`` to avoid shadowing the builtin
    inputs = Input(shape=(number_of_features,))
    net = Dense(units=1)(inputs)

    state = State(Model(inputs, net))

    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    # hand the compiled model back to the caller; previously it was built and
    # then discarded because the function implicitly returned None
    return state
65 Load keras model into state
67 with tempfile.TemporaryDirectory()
as temp_path:
69 temp_path = pathlib.Path(temp_path)
72 path = temp_path.joinpath(pathlib.Path(filename))
74 with open(path,
'w+b')
as file:
75 file.write(bytes(obj[1]))
77 state =
State(load_model(path))
79 for index, key
in enumerate(obj[2]):
80 setattr(state, key, obj[3][index])
86 Apply estimator to passed data.
91 r = state.model(tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32), training=
False).numpy()
94 return np.require(r, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Returns just the state object

    The validation data (``Xtest``, ``Stest``, ``ytest``, ``wtest``) and
    ``nBatches`` are accepted for interface compatibility but are not used.
    """
    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to keras model and fit it

    Parameters:
        state: object whose ``model`` attribute provides ``fit``
        X: features passed to ``model.fit``
        S: spectators (unused here)
        y: targets passed to ``model.fit``
        w: weights (unused here)
        epoch: presumably the index of the current outer iteration
        batch: presumably the index of the current mini batch

    keras runs its own epoch/batch loop (``batch_size=100, epochs=10``), so a
    warning is emitted only when the caller drives extra iterations or
    batches on top of that, not on every call.
    """
    # NOTE(review): assumes epoch and batch are 0-based indices, so a value
    # > 0 means more than one iteration / mini batch was requested — confirm
    if epoch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_nIterations > 1."
                  " In the default implementation this should not be done as keras handles the number of epochs internally.")

    if batch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_mini_batch_size > 1."
                  " In the default implementation this should not be done as keras handles the number of batches internally.")

    state.model.fit(X, y, batch_size=100, epochs=10)
122 Store trained keras model
125 with tempfile.TemporaryDirectory()
as temp_path:
127 temp_path = pathlib.Path(temp_path)
128 filename =
'my_model.keras'
129 filepath = temp_path.joinpath(filename)
130 state.model.save(filepath)
132 with open(filepath,
'rb')
as file:
133 filecontent = file.read()
135 collection_keys = state.collection_keys
136 collections_to_store = []
137 for key
in state.collection_keys:
138 collections_to_store.append(getattr(state, key))
141 return [filename, filecontent, collection_keys, collections_to_store]
collection_keys
list of keys to save
def __init__(self, model=None, **kwargs)