13 from tensorflow.keras.layers 
import Dense, Input
 
   14 from tensorflow.keras.models 
import Model, load_model
 
   15 from tensorflow.keras.losses 
import binary_crossentropy
 
   16 import tensorflow 
as tf
 
   17 from basf2 
import B2WARNING
 
   # Review notes: fragment of the State class from a source listing; the
   # `class` header and several constructor lines are not present here.
   # Purpose: container for the tensorflow.keras model (see `state.model`
   # usages elsewhere in this file) plus any extra attributes passed to the
   # constructor as keyword arguments.
   22     Tensorflow.keras state 
   26         """ Constructor of the state object """ 
   # Every keyword argument becomes an attribute of the state object.
   34         for key, value 
in kwargs.items():

   36             setattr(self, key, value)
   # NOTE(review): the serialization code in this file reads
   # `state.collection_keys` to decide which extra attributes to persist; the
   # line that would record kwargs keys there (listing line 35) is absent from
   # this extract — confirm that kwargs keys are actually appended to it.
 
   def feature_importance(state):
       """
       Return a list containing the feature importances

       No feature-importance computation is implemented for this model, so an
       empty list is returned. The visible original had no body at all and
       therefore returned nothing, contradicting this contract.

       Args:
           state: training state object (not inspected here).

       Returns:
           list: always empty.
       """
       return []
   def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
       """
       Return default tensorflow.keras model

       Builds a network consisting of a single ``Dense`` unit on top of an
       input of ``number_of_features`` values, wraps it in a ``State`` and
       compiles it with the ``"adam"`` optimizer, ``binary_crossentropy`` loss
       and the ``accuracy`` metric.

       Only ``number_of_features`` is used; the remaining parameters are
       accepted but unused here (presumably required by the calling
       interface).

       Returns:
           State: state object whose ``model`` attribute is the compiled model.
       """
       # Named `inputs` so the builtin `input` is not shadowed.
       inputs = Input(shape=(number_of_features,))
       net = Dense(units=1)(inputs)

       state = State(Model(inputs, net))

       # NOTE(review): the output unit has no activation (linear), while
       # binary_crossentropy defaults to from_logits=False and expects
       # probabilities; consider activation="sigmoid" or from_logits=True.
       # Left unchanged to preserve the existing training behaviour.
       state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

       # The visible original never returned, so callers received None
       # despite the docstring promising the model state.
       return state
 
   # Review notes: body of the model-loading function; its `def` line (name and
   # parameters) is not present in this extract. It rebuilds a State from `obj`,
   # whose items are used as: obj[1] = raw file contents (one per entry of
   # `file_names`), obj[2] = extra attribute names, obj[3] = attribute values.
   # NOTE(review): `file_names` is bound on a line not shown here — presumably
   # obj[0]; confirm.
   65     Load Tensorflow.keras model into state 
   67     with tempfile.TemporaryDirectory() 
as temp_path:

   69         temp_path = pathlib.Path(temp_path)

   # Recreate every stored file (relative path + raw bytes) under the fresh
   # temporary directory, creating parent directories as needed.
   72         for file_index, file_name 
in enumerate(file_names):

   73             path = temp_path.joinpath(pathlib.Path(file_name))

   74             path.parents[0].mkdir(parents=
True, exist_ok=
True)

   76             with open(path, 
'w+b') 
as file:

   77                 file.write(bytes(obj[1][file_index]))

   # The model is loaded inside the `with` block because the temporary
   # directory is deleted when the block exits.
   # NOTE(review): `temp_path` is already a pathlib.Path at this point, so the
   # extra pathlib.Path(...) wrap is redundant (harmless).
   79         state = 
State(load_model(pathlib.Path(temp_path) / 
'my_model'))

   # Restore the extra attributes that were stored alongside the model.
   81         for index, key 
in enumerate(obj[2]):

   82             setattr(state, key, obj[3][index])
   # NOTE(review): no `return state` is visible in this extract — confirm the
   # function returns the rebuilt state.
 
   # Review notes: body of the prediction function; its `def` line is not
   # present in this extract. It evaluates state.model on X and returns the
   # output as a NumPy array.
   88     Apply estimator to passed data. 
   # np.atleast_2d turns a single 1-D sample into a batch of one; the model is
   # called directly in inference mode (training=False), and the resulting
   # tensor is converted back to NumPy.
   93     r = state.model(tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32), training=
False).numpy()

   # Return float32 data that is aligned ('A'), writeable ('W'), C-contiguous
   # ('C') and owns its memory ('O'); np.require copies only if `r` does not
   # already satisfy these.
   # NOTE(review): presumably these flags are required by whatever consumes the
   # returned buffer — confirm.
   96     return np.require(r, dtype=np.float32, requirements=[
'A', 
'W', 
'C', 
'O'])
 
   def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
       """
       Returns just the state object

       The validation data (``Xtest``, ``Stest``, ``ytest``, ``wtest``) and
       ``nBatches`` are accepted but not used here. The visible original had
       no return statement and so yielded None, contradicting this docstring.

       Returns:
           The ``state`` argument, unchanged.
       """
       return state
  # Review notes: training step. As far as visible here, it may emit two
  # warnings about unsupported settings and then trains state.model on (X, y)
  # with a fixed batch size of 100 for 10 epochs, i.e. keras handles batching
  # and epochs itself rather than the caller.
  106 def partial_fit(state, X, S, y, w, epoch, batch):

  108     Pass received data to tensorflow.keras session 
  # NOTE(review): the conditions guarding the two warnings (listing lines 110
  # and 114) are not present in this extract; judging from the messages they
  # presumably test `epoch` and `batch` — confirm.
  111         B2WARNING(
"The keras training interface has been called with specific_options.m_nIterations > 1." 
  112                   " In the default implementation this should not be done as keras handles the number of epochs internally.")

  115         B2WARNING(
"The keras training interface has been called with specific_options.m_mini_batch_size > 1." 
  116                   " In the default implementation this should not be done as keras handles the number of batches internally.")

  # NOTE(review): the spectators `S` and weights `w` are not passed to fit()
  # in the visible lines, so per-event weights appear to be ignored — confirm.
  # The function's return value is not visible in this extract.
  118     state.model.fit(X, y, batch_size=100, epochs=10)
 
  # Review notes: body of the save/serialization function; its `def` line is
  # not present in this extract. It saves state.model into a temporary
  # directory, reads back every file produced as raw bytes, collects the
  # extra attributes named in state.collection_keys, and returns
  # [file_names, files, collection_keys, collections_to_store].
  124     Store tensorflow.keras session in a graph 
  127     with tempfile.TemporaryDirectory() 
as temp_path:

  129         temp_path = pathlib.Path(temp_path)

  # NOTE(review): saving to a path without an extension yields the TensorFlow
  # SavedModel directory format under tf.keras 2.x; Keras 3 requires a
  # `.keras` (or legacy `.h5`) suffix here — confirm the installed version.
  130         state.model.save(temp_path.joinpath(
'my_model'))

  # Paths are kept relative to the temporary directory because that
  # directory is deleted when the `with` block exits; rglob('*') walks all
  # subdirectories and the is_file() filter drops directory entries.
  137         file_names = [f.relative_to(temp_path) 
for f 
in temp_path.rglob(
'*') 
if f.is_file()]

  # NOTE(review): `files` is appended to here, but its initialisation is not
  # visible in this extract — presumably an empty list; confirm.
  139         for file_name 
in file_names:

  140             with open(temp_path.joinpath(file_name), 
'rb') 
as file:

  141                 files.append(file.read())

  143         collection_keys = state.collection_keys

  144         collections_to_store = []

  # Gather the value of each extra attribute, in the same order as its key.
  145         for key 
in state.collection_keys:

  146             collections_to_store.append(getattr(state, key))

  # The list positions matter: the loading code in this file reads obj[1],
  # obj[2] and obj[3] as file contents, attribute names and attribute values.
  149     return [file_names, files, collection_keys, collections_to_store]
 
collection_keys
list of keys to save
def __init__(self, model=None, **kwargs)