import pathlib
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model, load_model

from basf2 import B2WARNING
class State:
    """
    Tensorflow.keras state

    Holds the keras model together with any extra named objects that should be
    stored alongside it. The names of those extra objects are listed in
    ``collection_keys``.
    """

    def __init__(self, model=None, **kwargs):
        """ Constructor of the state object

        :param model: the keras model (may be None).
        :param kwargs: extra objects to keep with the model. Each one is set as
            an attribute of the same name, and its name is appended to
            ``collection_keys`` in keyword order.
        """
        #: keras model
        self.model = model

        #: list of keys to save
        self.collection_keys = []

        # Record every extra object by name so it can be collected again later.
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
def feature_importance(state):
    """
    Return a list containing the feature importances

    :param state: the training state (not inspected).
    :return: an empty list; no importance estimate is computed here.
    """
    # NOTE(review): nothing is computed for keras models; confirm callers treat
    # an empty list as "no importances available".
    return []
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow.keras model

    The default network is a single sigmoid unit on top of the input features,
    compiled with the Adam optimizer, binary cross-entropy loss and an accuracy
    metric.

    :param number_of_features: width of the input layer.
    :param number_of_spectators: unused by the default model.
    :param number_of_events: unused by the default model.
    :param training_fraction: unused by the default model.
    :param parameters: unused by the default model.
    :return: a ``State`` wrapping the compiled model.
    """
    # Named ``inputs`` so the builtin ``input`` is not shadowed.
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy defaults to from_logits=False, i.e. it expects
    # probabilities. A linear unit can leave [0, 1], where the loss clips the
    # output and the gradient vanishes, so squash it with a sigmoid.
    net = Dense(units=1, activation='sigmoid')(inputs)

    state = State(Model(inputs, net))
    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])
    return state
def load(obj):
    """
    Load Tensorflow.keras model into state

    :param obj: sequence laid out as
        ``[file_names, files, collection_keys, collections_to_store]``.
        ``file_names`` are relative paths and ``files`` the matching raw
        contents. ``collections_to_store`` holds the value for each name in
        ``collection_keys``.
    :return: a ``State`` wrapping the model reloaded from ``my_model``, with
        each stored collection passed to the constructor as a keyword argument.
    :raises ValueError: if the parallel lists differ in length, or a stored
        file name is absolute or contains ``..``.
    """
    file_names, files, collection_keys, collections = obj[0], obj[1], obj[2], obj[3]
    if len(file_names) != len(files):
        raise ValueError(f"got {len(file_names)} model file names but {len(files)} file contents")
    if len(collection_keys) != len(collections):
        raise ValueError(f"got {len(collection_keys)} collection keys but {len(collections)} collections")

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)

        # Recreate the saved model's directory tree file by file.
        for file_name, content in zip(file_names, files):
            relative = pathlib.Path(file_name)
            # Only relative paths inside the temp dir are valid; refuse anything
            # else so a crafted weightfile cannot write elsewhere on disk.
            if relative.is_absolute() or '..' in relative.parts:
                raise ValueError(f"refusing to restore unsafe model file path: {file_name!r}")
            path = temp_path / relative
            path.parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'w+b') as file:
                file.write(bytes(content))

        # The files only exist inside this block, so load the model here.
        model = load_model(temp_path / 'my_model')

    # Hand the collections to the constructor's **kwargs instead of setting
    # them afterwards, so State treats them exactly as at construction time.
    return State(model, **dict(zip(collection_keys, collections)))
def apply(state, X):
    """
    Apply estimator to passed data.

    :param state: object whose ``model`` attribute is a callable keras model.
    :param X: feature values. They are promoted to at least 2-D, so a single
        event becomes a batch of one.
    :return: the model output as a float32 numpy array that is aligned,
        writeable, C-contiguous and owns its data.
    """
    # Call the model directly in inference mode (training=False) on a float32
    # tensor rather than going through predict().
    inputs = tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32)
    r = state.model(inputs, training=False).numpy()
    # 'A', 'W', 'C', 'O' = aligned, writeable, C-contiguous, owns data.
    # NOTE(review): presumably required by the caller that consumes this buffer; confirm.
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Returns just the state object

    The test data arguments and ``nBatches`` are accepted for interface
    compatibility and are not used.

    :return: ``state``, unchanged.
    """
    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to tensorflow.keras session

    Trains ``state.model`` on ``X``/``y`` in a single ``fit`` call with batch
    size 100 and 10 epochs. ``S`` and ``w`` are accepted for interface
    compatibility but are not used here.

    :param epoch: outer iteration counter; a warning is emitted when it is > 0.
    :param batch: mini-batch counter; a warning is emitted when it is > 0.
    :return: False
    """
    # NOTE(review): epoch/batch are presumably 0-based, so any value > 0 means
    # the caller asked for more than one iteration / mini-batch.
    if epoch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_nIterations > 1."
                  " In the default implementation this should not be done as keras handles the number of epochs internally.")

    if batch > 0:
        B2WARNING("The keras training interface has been called with specific_options.m_mini_batch_size > 1."
                  " In the default implementation this should not be done as keras handles the number of batches internally.")

    state.model.fit(X, y, batch_size=100, epochs=10)
    # NOTE(review): presumably False tells the caller not to request another
    # iteration, since keras already ran all epochs; confirm against the interface.
    return False
def end_fit(state):
    """
    Store tensorflow.keras session in a graph

    Saves ``state.model`` under ``my_model`` in a temporary directory and reads
    every resulting file back into memory before the directory is removed.

    :return: ``[file_names, files, collection_keys, collections_to_store]``.
        ``file_names`` are paths relative to the temporary directory and
        ``files`` the matching raw bytes. ``collection_keys`` is
        ``state.collection_keys``, and ``collections_to_store`` holds the value
        of each of those attributes, in the same order.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)
        state.model.save(temp_path.joinpath('my_model'))

        # Relative names let the tree be rebuilt under a different directory.
        file_names = [f.relative_to(temp_path) for f in temp_path.rglob('*') if f.is_file()]
        # Read everything now: the directory disappears when this block exits.
        files = [temp_path.joinpath(name).read_bytes() for name in file_names]

    collection_keys = state.collection_keys
    collections_to_store = [getattr(state, key) for key in collection_keys]
    return [file_names, files, collection_keys, collections_to_store]
# NOTE: orphaned extraction residue of State.__init__, kept as comments:
#   collection_keys -- list of keys to save
#   def __init__(self, model=None, **kwargs)