Belle II Software  release-08-01-10
keras.py
1 
8 
9 import pathlib
10 import tempfile
11 import numpy as np
12 
13 from tensorflow.keras.layers import Dense, Input
14 from tensorflow.keras.models import Model, load_model
15 from tensorflow.keras.losses import binary_crossentropy
16 import tensorflow as tf
17 from basf2 import B2WARNING
18 
19 
class State:
    """
    Tensorflow.keras state

    Holds the keras model under ``model``. Every additional keyword argument
    passed to the constructor is stored as an attribute of the same name, and
    its name is appended to ``collection_keys`` so it can be persisted
    alongside the model.
    """

    def __init__(self, model=None, **kwargs):
        """ Constructor of the state object """

        #: tensorflow.keras model
        self.model = model

        #: list of keys to save
        self.collection_keys = []

        # other possible things to save into a tensorflow collection
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
37 
38 
def feature_importance(state):
    """
    Report the feature importances of the trained model.

    This default implementation does not compute any importances; it hands
    back a freshly created empty list irrespective of ``state``.
    """
    return list()
44 
45 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow.keras model

    Builds a network with a single sigmoid output unit applied directly to the
    ``number_of_features`` inputs. It is compiled with the adam optimizer,
    binary cross-entropy loss and an accuracy metric, its summary is printed,
    and it is returned wrapped in a ``State``.

    Only ``number_of_features`` is used; the remaining arguments are accepted
    for interface compatibility and ignored by this default model.
    """
    # named `inputs` rather than `input` to avoid shadowing the builtin
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy (from_logits=False) expects probabilities, so the
    # single output unit needs a sigmoid; a linear unit would produce values
    # outside (0, 1) that the loss clips.
    outputs = Dense(units=1, activation='sigmoid')(inputs)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])
    model.summary()

    return State(model)
61 
62 
def load(obj):
    """
    Load Tensorflow.keras model into state

    ``obj`` follows the layout returned by ``end_fit``:

    - ``obj[0]``: file names, relative to the directory the model was saved into
    - ``obj[1]``: the raw contents of those files, in the same order
    - ``obj[2]``: names of extra state attributes
    - ``obj[3]``: the values of those attributes, in the same order

    The files are written to a temporary directory, and the keras model is
    loaded from its ``my_model`` subdirectory. The extra attributes are then
    restored on the returned ``State``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)

        # recreate the saved-model directory tree on disk
        for file_name, content in zip(obj[0], obj[1]):
            path = temp_path / file_name
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_bytes(bytes(content))

        # must happen before the temporary directory is removed
        model = load_model(temp_path / 'my_model')

    # Passing the extras through the constructor also records their names in
    # collection_keys, so a reloaded state keeps them when saved by end_fit.
    return State(model, **dict(zip(obj[2], obj[3])))
84 
85 
def apply(state, X):
    """
    Apply estimator to passed data.

    ``X`` is promoted to at least 2-D and fed to the model as a float32
    tensor in inference mode. If the output has exactly one column, that
    column is extracted to give a 1-D result. The returned array is float32,
    aligned, writeable, C-contiguous and owns its data.
    """
    # convert array input to tensor to avoid creating a new graph for each input
    # calling the model directly is faster than using the predict method in most of our applications
    # as we do a loop over events.
    inputs = tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32)
    result = state.model(inputs, training=False).numpy()
    # Only drop the column axis when it exists: a model with 1-D output would
    # otherwise raise IndexError on shape[1]. squeeze cannot be used because a
    # single event may produce shape [1, n_classes].
    if result.ndim > 1 and result.shape[1] == 1:
        result = result[:, 0]
    return np.require(result, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
97 
98 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Start the training phase.

    The validation inputs (``Xtest``, ``Stest``, ``ytest``, ``wtest``) and
    ``nBatches`` are not used by this default implementation; the given state
    object is handed back unchanged.
    """
    unchanged_state = state
    return unchanged_state
104 
105 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Train the keras model on the received data in a single call.

    Keras handles epochs and mini-batches itself, so a warning is emitted when
    this hook is reached with ``epoch > 0`` or ``batch > 0``. The spectators
    ``S`` and weights ``w`` are not used by this default implementation.
    Training always uses batch size 100 for 10 epochs, and the function always
    returns False.
    """
    misuse_checks = (
        (epoch, "The keras training interface has been called with specific_options.m_nIterations > 1."
                " In the default implementation this should not be done as keras handles the number of epochs internally."),
        (batch, "The keras training interface has been called with specific_options.m_mini_batch_size > 1."
                " In the default implementation this should not be done as keras handles the number of batches internally."),
    )
    for counter, message in misuse_checks:
        if counter > 0:
            B2WARNING(message)

    fit_options = {'batch_size': 100, 'epochs': 10}
    state.model.fit(X, y, **fit_options)
    return False
120 
121 
def end_fit(state):
    """
    Store tensorflow.keras session in a graph

    Saves ``state.model`` into a temporary ``my_model`` directory and reads
    back every file written there.

    Returns ``[file_names, files, collection_keys, collections_to_store]``:

    - ``file_names``: paths relative to the temporary directory, in sorted order
    - ``files``: the matching raw file contents
    - ``collection_keys``: ``state.collection_keys``
    - ``collections_to_store``: the values of those attributes on ``state``
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)
        state.model.save(temp_path.joinpath('my_model'))

        # this creates:
        # path/my_model/saved_model.pb
        # path/my_model/keras_metadata.pb (sometimes)
        # path/my_model/variables/*
        # path/my_model/assets/*
        # Sorted so the serialized order does not depend on the filesystem's
        # traversal order, which makes the output reproducible.
        file_names = sorted(f.relative_to(temp_path) for f in temp_path.rglob('*') if f.is_file())
        files = [temp_path.joinpath(file_name).read_bytes() for file_name in file_names]

    collection_keys = state.collection_keys
    collections_to_store = [getattr(state, key) for key in collection_keys]

    return [file_names, files, collection_keys, collections_to_store]
collection_keys
list of keys to save
Definition: keras.py:31
def __init__(self, model=None, **kwargs)
Definition: keras.py:25