14 from tensorflow.keras.layers
import Dense, Input
15 from tensorflow.keras.models
import Model, load_model
16 from tensorflow.keras.losses
import binary_crossentropy
21 Tensorflow.keras state
25 """ Constructor of the state object """
33 for key, value
in kwargs.items():
35 setattr(self, key, value)
def feature_importance(state):
    """
    Return a list containing the feature importances

    :param state: state object of the trained method (not inspected here).
    :return: a list of feature importances.

    NOTE(review): the original body of this function is not visible in the
    listing this file was recovered from. An empty list is returned as the
    minimal value that satisfies the documented "list" contract -- confirm
    against the upstream implementation.
    """
    return []
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow.keras model

    Builds a network consisting of one Dense layer with a single unit on top
    of an input of width ``number_of_features``, compiles it with the "adam"
    optimizer, the binary_crossentropy loss and the 'accuracy' metric, and
    wraps the compiled model in a State object.

    :param number_of_features: width of the input layer.
    :param number_of_spectators: accepted for interface compatibility, unused here.
    :param number_of_events: accepted for interface compatibility, unused here.
    :param training_fraction: accepted for interface compatibility, unused here.
    :param parameters: accepted for interface compatibility, unused here.
    :return: State object holding the compiled model in ``state.model``.
    """
    # Named `inputs` rather than `input` so the builtin is not shadowed.
    inputs = Input(shape=(number_of_features,))
    # NOTE(review): the Dense layer is created without an activation while the
    # loss is the plain binary_crossentropy function -- confirm whether a
    # sigmoid output (or a from_logits loss) was intended.
    outputs = Dense(units=1)(inputs)

    state = State(Model(inputs, outputs))
    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    # The docstring promises the model is returned; without this the caller
    # would receive None.
    return state
64 Load Tensorflow.keras model into state
66 with tempfile.TemporaryDirectory()
as temp_path:
68 temp_path = pathlib.Path(temp_path)
71 for file_index, file_name
in enumerate(file_names):
72 path = temp_path.joinpath(pathlib.Path(file_name))
73 path.parents[0].mkdir(parents=
True, exist_ok=
True)
75 with open(path,
'w+b')
as file:
76 file.write(bytes(obj[1][file_index]))
78 state =
State(load_model(pathlib.Path(temp_path) /
'my_model'))
80 for index, key
in enumerate(obj[2]):
81 setattr(state, key, obj[3][index])
88 Apply estimator to passed data.
90 r = state.model.predict(X).flatten()
91 return np.require(r, dtype=np.float32, requirements=[
'A',
'W',
'C',
'O'])
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Returns just the state object

    :param state: state object to hand back unchanged.
    :param Xtest: accepted for interface compatibility, unused here.
    :param Stest: accepted for interface compatibility, unused here.
    :param ytest: accepted for interface compatibility, unused here.
    :param wtest: accepted for interface compatibility, unused here.
    :return: the very same ``state`` object that was passed in.
    """
    return state
def partial_fit(state, X, S, y, w, epoch):
    """
    Pass received data to tensorflow.keras session

    Calls ``state.model.fit`` on ``X`` and ``y`` with a batch size of 100 for
    10 epochs.

    :param state: state object whose ``model`` attribute is fitted.
    :param X: training inputs forwarded to ``fit``.
    :param S: accepted for interface compatibility, unused here.
    :param y: training targets forwarded to ``fit``.
    :param w: accepted for interface compatibility, unused here (not passed
        on as sample weights).
    :param epoch: accepted for interface compatibility, unused here.

    NOTE(review): no return statement is visible in the listing this block
    was recovered from, so none is added here -- confirm against the upstream
    file whether the caller expects an explicit return value.
    """
    state.model.fit(X, y, batch_size=100, epochs=10)
111 Store tensorflow.keras session in a graph
114 with tempfile.TemporaryDirectory()
as temp_path:
116 temp_path = pathlib.Path(temp_path)
117 state.model.save(temp_path.joinpath(
'my_model'))
124 file_names = [f.relative_to(temp_path)
for f
in temp_path.rglob(
'*')
if f.is_file()]
126 for file_name
in file_names:
127 with open(temp_path.joinpath(file_name),
'rb')
as file:
128 files.append(file.read())
130 collection_keys = state.collection_keys
131 collections_to_store = []
132 for key
in state.collection_keys:
133 collections_to_store.append(getattr(state, key))
136 return [file_names, files, collection_keys, collections_to_store]
collection_keys
list of keys to save
def __init__(self, model=None, **kwargs)