Belle II Software  release-06-02-00
contrib_keras.py
1 
8 
9 import os
10 import tempfile
11 import numpy as np
12 
13 
14 from keras.layers import Dense, Input
15 from keras.models import Model, load_model
16 from keras.losses import binary_crossentropy
17 
18 
class State(object):
    """
    Keras state shared between the MVA hook functions of this module.

    Holds the Keras model, the ``custom_objects`` mapping needed to reload it,
    and any additional named values passed as keyword arguments. The names of
    those extra values are recorded in ``collection_keys`` so that they can be
    serialized together with the model and restored afterwards.
    """

    def __init__(self, model=None, custom_objects=None, **kwargs):
        """
        Constructor of the state object.

        :param model: Keras model held by this state (may be None)
        :param custom_objects: mapping used by keras to load custom objects like custom layers
        :param kwargs: extra values; each is stored as an attribute of the same
            name and its name is appended to ``collection_keys``
        """
        #: Keras model
        self.model = model
        #: used by keras to load custom objects like custom layers
        self.custom_objects = custom_objects
        #: names of the extra attributes set from ``kwargs``, in insertion order
        self.collection_keys = []

        # other possible things to save alongside the model
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
37 
38 
def feature_importance(state):
    """
    Report the feature importances of the given state.

    No importance ranking is computed for this model, so a new empty list is
    returned on every call; ``state`` is not inspected.
    """
    importances = []
    return importances
44 
45 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return the default Keras model wrapped in a State.

    The network maps ``number_of_features`` inputs to a single sigmoid output
    unit. It is compiled with the adam optimizer, binary cross-entropy loss and
    an accuracy metric, and its summary is printed.

    Only ``number_of_features`` is used; the remaining arguments are accepted
    to satisfy the hook interface.

    :return: State holding the compiled model
    """
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy (default from_logits=False) expects probabilities,
    # so the single output unit is squashed into (0, 1) with a sigmoid.
    outputs = Dense(units=1, activation='sigmoid')(inputs)

    state = State(Model(inputs, outputs))

    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    state.model.summary()

    return state
60 
61 
def load(obj):
    """
    Load a Keras model state from the list produced by ``end_fit``.

    :param obj: list laid out as
        ``[model_bytes, custom_objects, collection_keys, value_0, value_1, ...]``
        where ``value_i`` belongs to ``collection_keys[i]``
    :return: State with the reloaded model, its ``custom_objects``, and every
        extra attribute restored (``collection_keys`` is repopulated too)
    """
    model_bytes, custom_objects, keys = obj[0], obj[1], obj[2]

    with tempfile.TemporaryDirectory() as path:
        filename = os.path.join(path, 'weights.h5')
        with open(filename, 'w+b') as handle:
            handle.write(bytes(model_bytes))
        # Load only after the file handle is closed so all bytes are on disk.
        model = load_model(filename, custom_objects=custom_objects)

    # Indexing (rather than zip) keeps the original strictness: a missing
    # value for a key still raises IndexError.
    extras = {key: obj[index + 3] for index, key in enumerate(keys)}

    # Passing the extras as keyword arguments restores them as attributes AND
    # rebuilds collection_keys, so the loaded state can be saved again.
    return State(model, custom_objects=custom_objects, **extras)
75 
76 
def apply(state, X):
    """
    Evaluate the stored model on ``X``.

    :param state: object whose ``model`` provides ``predict``
    :param X: input data forwarded unchanged to ``model.predict``
    :return: flattened predictions as an aligned, writeable, C-contiguous
        float32 array that owns its data
    """
    raw = state.model.predict(X)
    flat = raw.flatten()
    return np.require(flat, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
83 
84 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Start-of-training hook with nothing to prepare.

    The test features, spectators, targets and weights are ignored; the very
    same state object is handed back.
    """
    unchanged = state
    return unchanged
90 
91 
def partial_fit(state, X, S, y, w, epoch):
    """
    Train the Keras model on the received data with a single ``fit`` call.

    Batch size and number of epochs default to 100 and 10. They can be
    overridden by giving the state ``batch_size`` / ``epochs`` attributes
    (e.g. as extra keyword arguments to ``State``). The spectators ``S``,
    weights ``w`` and the ``epoch`` counter are not used.

    :return: always False (presumably tells the caller not to feed further
        batches -- confirm against the MVA framework)
    """
    batch_size = getattr(state, 'batch_size', 100)
    epochs = getattr(state, 'epochs', 10)
    state.model.fit(X, y, batch_size=batch_size, epochs=epochs)
    return False
98 
99 
def end_fit(state):
    """
    Serialize the trained state into a list.

    The Keras model is saved to a temporary HDF5 file whose raw bytes are read
    back; the temporary directory is removed afterwards.

    :param state: object with ``model``, ``custom_objects``, ``collection_keys``
        and one attribute per name in ``collection_keys``
    :return: ``[model_bytes, custom_objects, collection_keys, value_0, ...]``,
        with ``value_i`` taken from the attribute ``collection_keys[i]``
    """
    with tempfile.TemporaryDirectory() as path:
        filename = os.path.join(path, 'weights.h5')
        state.model.save(filename)
        with open(filename, 'rb') as handle:
            data = handle.read()

    header = [data, state.custom_objects, state.collection_keys]
    return header + [getattr(state, key) for key in state.collection_keys]
custom_objects
used by keras to load custom objects like custom layers
def __init__(self, model=None, custom_objects=None, **kwargs)