Belle II Software  light-2205-abys
keras.py
1 
8 
9 import pathlib
10 import tempfile
11 import numpy as np
12 
13 
14 from tensorflow.keras.layers import Dense, Input
15 from tensorflow.keras.models import Model, load_model
16 from tensorflow.keras.losses import binary_crossentropy
17 
18 
class State(object):
    """
    Tensorflow.keras state

    Container handed between the fitting/applying functions of this module.
    It carries the keras model plus any number of extra named values passed
    as keyword arguments.
    """

    def __init__(self, model=None, **kwargs):
        """
        Constructor of the state object

        model:  the tensorflow.keras model (may be None)
        kwargs: extra values to attach to the state; each one becomes an
                attribute and its name is recorded in collection_keys
        """
        # The other functions in this file read `state.model`, so the
        # attribute has to carry exactly that name.
        self.model = model

        # list of keys to save
        self.collection_keys = []

        # other possible things to save into a tensorflow collection
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
36 
37 
def feature_importance(state):
    """
    Return a list containing the feature importances.

    This implementation does not compute any importances: the passed state is
    ignored and a new empty list is returned on every call.
    """
    importances = []
    return importances
43 
44 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow.keras model wrapped in a State

    The network is a single Dense layer with one unit fed by an input of
    width number_of_features. Only number_of_features is used here; the
    remaining arguments are accepted but not read.
    """
    feature_input = Input(shape=(number_of_features,))
    output = Dense(units=1)(feature_input)
    model = Model(feature_input, output)

    model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    # Print the layer overview of the freshly built network.
    model.summary()

    return State(model)
60 
61 
def load(obj):
    """
    Load Tensorflow.keras model into state

    obj is read by index:
      obj[0] - relative file names of the saved model
      obj[1] - content of each file, in the same order as obj[0]
      obj[2] - names of the extra attributes to restore on the state
      obj[3] - values of those attributes, in the same order as obj[2]
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        root = pathlib.Path(temp_dir)

        # Recreate the saved directory tree inside the temporary directory.
        for position, file_name in enumerate(obj[0]):
            target = root.joinpath(pathlib.Path(file_name))
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_bytes(bytes(obj[1][position]))

        # The model has to be loaded while the temporary files still exist.
        state = State(load_model(root / 'my_model'))

    # Re-attach the extra values that were stored next to the model.
    for position, key in enumerate(obj[2]):
        setattr(state, key, obj[3][position])

    return state
84 
85 
def apply(state, X):
    """
    Apply estimator to passed data.

    Runs state.model.predict on X, flattens the result to one dimension and
    returns it as a float32 numpy array that is aligned, writeable,
    C-contiguous and owns its data.
    """
    prediction = state.model.predict(X)
    flat = prediction.flatten()
    return np.require(flat, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
92 
93 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Hand back the state object unchanged.

    The test-sample arguments (Xtest, Stest, ytest, wtest) are accepted but
    not used by this implementation.
    """
    unchanged_state = state
    return unchanged_state
99 
100 
def partial_fit(state, X, S, y, w, epoch):
    """
    Pass received data to the tensorflow.keras model

    Calls state.model.fit on (X, y) with a batch size of 100 for 10 epochs.
    S, w and epoch are accepted but not used. Always returns False.
    NOTE(review): the meaning of the return value is defined by the calling
    framework, which is not visible here - confirm against the caller.
    """
    fit_options = {'batch_size': 100, 'epochs': 10}
    state.model.fit(X, y, **fit_options)
    return False
107 
108 
def end_fit(state):
    """
    Store the tensorflow.keras model as a list of files and attributes

    The model is saved into a temporary directory, and every file written
    there is read back into memory. Returns a four-element list:
      [relative file paths, file contents as bytes,
       state.collection_keys, the state attribute value for each key]
    """
    file_names = []
    files = []

    with tempfile.TemporaryDirectory() as temp_dir:
        root = pathlib.Path(temp_dir)
        state.model.save(root / 'my_model')

        # this creates:
        # path/my_model/saved_model.pb
        # path/my_model/keras_metadata.pb (sometimes)
        # path/my_model/variables/*
        # path/my_model/assets/*
        for entry in root.rglob('*'):
            if not entry.is_file():
                continue
            # Names are kept relative to the temporary directory.
            file_names.append(entry.relative_to(root))
            files.append(entry.read_bytes())

    collection_keys = state.collection_keys
    collections_to_store = [getattr(state, key) for key in collection_keys]

    # Only unbinds the local name; the caller's reference is unaffected.
    del state
    return [file_names, files, collection_keys, collections_to_store]
# Cross-references from the documentation index:
#   State.collection_keys - list of keys to save
#       Definition: keras.py:30
#   State.__init__(self, model=None, **kwargs)
#       Definition: keras.py:24