13from basf2
import B2WARNING
15from keras.layers
import Dense, Input
16from keras.models
import Model, load_model
17from keras.losses
import binary_crossentropy
18import tensorflow
as tf
# Run TensorFlow with one thread inside each individual op (intra-op) and one
# thread for executing independent ops concurrently (inter-op).
_tf_threading = tf.config.threading
_tf_threading.set_intra_op_parallelism_threads(1)
_tf_threading.set_inter_op_parallelism_threads(1)
del _tf_threading
29 """ Constructor of the state object """
# Attach every extra keyword argument to the state object as an attribute.
# NOTE(review): no visible line here adds ``key`` to ``self.collection_keys``;
# the model-storing code only persists attributes named in
# ``state.collection_keys`` — confirm the keys are recorded somewhere.
37 for key, value
in kwargs.items():
39 setattr(self, key, value)
def feature_importance(state):
    """
    Return a list containing the feature importances

    Nothing in this default implementation computes feature importances, so
    an empty list is returned. Previously the function had no body and
    returned ``None``, contradicting this contract.

    :param state: state object (not used)
    :return: an empty list
    """
    return []
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return dummy default keras model

    Builds a single ``Dense(units=1)`` layer on top of an input of width
    ``number_of_features``, wraps the resulting keras ``Model`` in a
    ``State`` and compiles it with the Adam optimizer, binary cross-entropy
    loss and an accuracy metric.

    :param number_of_features: width of the model's input layer
    :param number_of_spectators: not used by this default model
    :param number_of_events: not used by this default model
    :param training_fraction: not used by this default model
    :param parameters: not used by this default model
    :return: the ``State`` holding the compiled model
    """
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = Input(shape=(number_of_features,))
    # NOTE(review): Dense uses keras' default linear activation while the loss
    # is binary cross-entropy on probabilities — confirm whether a sigmoid
    # output was intended.
    net = Dense(units=1)(inputs)

    state = State(Model(inputs, net))

    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    # The fitting and application functions all operate on this state object,
    # so it must be handed back to the caller (it was previously dropped).
    return state
68 Load keras model into state
# Rebuild the keras model from the serialized data in ``obj``. The bytes are
# written to a file inside a temporary directory, the model is loaded from
# that file, and the extra attributes stored alongside it are restored.
# Layout of ``obj`` as used below: obj[1] is the model file content,
# obj[2] the attribute names and obj[3] their values; obj[0] is presumably
# the file name (this matches the [filename, filecontent, collection_keys,
# collections_to_store] list built when the trained model is stored).
70 with tempfile.TemporaryDirectory()
as temp_path:
72 temp_path = pathlib.Path(temp_path)
# NOTE(review): ``filename`` is not assigned in the visible lines — presumably
# ``filename = obj[0]``; confirm before relying on this path.
75 path = temp_path.joinpath(pathlib.Path(filename))
# 'w+b' opens the file for binary read/write, creating or truncating it.
77 with open(path,
'w+b')
as file:
78 file.write(bytes(obj[1]))
80 state =
State(load_model(path))
# Restore each stored attribute: obj[2] holds the names and obj[3] the values
# at the same positions.
82 for index, key
in enumerate(obj[2]):
83 setattr(state, key, obj[3][index])
# NOTE(review): no ``return state`` is visible — confirm the loaded state is
# returned to the caller.
89 Apply estimator to passed data.
# Convert X to an at-least-2-D float32 tensor, run the model in inference
# mode (training=False) and convert the result back to a NumPy array.
# NOTE(review): a single-output model (e.g. Dense(units=1)) yields shape
# (n, 1) here; confirm the caller accepts a 2-D result rather than (n,).
94 r = state.model(tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32), training=
False).numpy()
# Return a float32 array that is aligned ('A'), writeable ('W'),
# C-contiguous ('C') and owns its data ('O'); np.require copies only if
# these requirements are not already met.
97 return np.require(r, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Returns just the state object

    The validation data (``Xtest``, ``Stest``, ``ytest``, ``wtest``) and
    ``nBatches`` are not used; fitting is done in ``partial_fit``.
    Previously the function had no body and returned ``None``.

    :return: the unchanged ``state``
    """
    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to keras model and fit it

    The fit uses a fixed ``batch_size=100`` and ``epochs=10``, so keras
    iterates over epochs and mini-batches on its own. A warning is emitted
    only when the caller nevertheless drives further iterations
    (``epoch > 0``) or mini-batches (``batch > 0``).

    :param state: object whose ``model`` attribute is fitted
    :param X: training inputs passed to ``model.fit``
    :param S: not used
    :param y: training targets passed to ``model.fit``
    :param w: not used (NOTE(review): presumably per-event weights — confirm
        whether they should be forwarded as ``sample_weight``)
    :param epoch: iteration counter from the caller; values > 0 trigger a warning
    :param batch: mini-batch counter from the caller; values > 0 trigger a warning
    """
    # Warn only when the caller is actually iterating. The warnings describe
    # m_nIterations > 1 / m_mini_batch_size > 1, which the first call
    # (epoch 0, batch 0) cannot detect.
    if epoch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_nIterations > 1."
                  " In the default implementation this should not be done as keras handles the number of epochs internally.")

    if batch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_mini_batch_size > 1."
                  " In the default implementation this should not be done as keras handles the number of batches internally.")

    state.model.fit(X, y, batch_size=100, epochs=10)
    # NOTE(review): returns None implicitly; confirm what the calling
    # framework expects from partial_fit.
125 Store trained keras model
# Serialize the model: save it as 'my_model.keras' inside a temporary
# directory and read the file back as raw bytes. The returned data then
# holds the model itself, not a path to a file that is deleted together
# with the temporary directory.
128 with tempfile.TemporaryDirectory()
as temp_path:
130 temp_path = pathlib.Path(temp_path)
131 filename =
'my_model.keras'
132 filepath = temp_path.joinpath(filename)
133 state.model.save(filepath)
135 with open(filepath,
'rb')
as file:
136 filecontent = file.read()
# Gather the values of the extra attributes named in state.collection_keys,
# in the same order as the keys.
138 collection_keys = state.collection_keys
139 collections_to_store = []
140 for key
in state.collection_keys:
141 collections_to_store.append(getattr(state, key))
# Returned layout: [file name, model file bytes, attribute names,
# attribute values].
144 return [filename, filecontent, collection_keys, collections_to_store]
collection_keys: list of keys to save
def __init__(self, model=None, **kwargs)