Belle II Software  release-05-02-19
contrib_keras.py
1 import os
2 import tempfile
3 import h5py
4 import numpy as np
5 
6 import tensorflow as tf
7 import tensorflow.contrib.keras as keras
8 
9 from keras.layers import Input, Dense, Concatenate
10 from keras.models import Model, load_model
11 from keras.optimizers import adam
12 from keras.losses import binary_crossentropy
13 
14 
class State(object):
    """
    Tensorflow state: holds the keras model plus any extra named objects that
    end_fit should serialise alongside it.
    """

    def __init__(self, model=None, custom_objects=None, **kwargs):
        """ Constructor of the state object """
        #: keras model
        self.model = model
        #: used by keras to load custom objects like custom layers
        self.custom_objects = custom_objects
        #: list of keys to save
        self.collection_keys = []

        # Each extra keyword becomes an attribute of the state, and its name is
        # recorded so end_fit knows which attributes to store.
        for extra_name in kwargs:
            extra_value = kwargs[extra_name]
            self.collection_keys.append(extra_name)
            setattr(self, extra_name, extra_value)
33 
34 
def feature_importance(state):
    """
    Return a list containing the feature importances.

    This model does not compute importances, so the list is always empty;
    ``state`` is ignored.
    """
    no_importances = list()
    return no_importances
40 
41 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow model.

    Builds a minimal keras model: a single dense unit with a sigmoid activation
    on top of ``number_of_features`` inputs. It is compiled with the Adam
    optimizer and binary cross-entropy loss. The model summary is printed.

    :param number_of_features: width of the input layer
    :param number_of_spectators: unused by this default model
    :param number_of_events: unused by this default model
    :param training_fraction: unused by this default model
    :param parameters: unused by this default model
    :return: State wrapping the compiled keras model
    """
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy (without from_logits) expects outputs in [0, 1]. A
    # linear output would be clipped inside the loss, and its gradient
    # vanishes outside that range.
    net = Dense(units=1, activation='sigmoid')(inputs)

    state = State(Model(inputs, net))

    state.model.compile(optimizer=adam(), loss=binary_crossentropy, metrics=['accuracy'])

    state.model.summary()

    return state
56 
57 
def load(obj):
    """
    Load Tensorflow estimator into state.

    Inverse of end_fit. ``obj`` is expected to be laid out as
    ``[model_bytes, custom_objects, collection_keys, value_0, value_1, ...]``.

    :param obj: serialised state as produced by end_fit
    :return: State with the keras model, custom_objects and every extra
        attribute restored. Extra attribute names are re-registered in
        ``collection_keys`` so that a later end_fit stores them again.
    """
    custom_objects = obj[1]
    with tempfile.TemporaryDirectory() as path:
        file_name = os.path.join(path, 'weights.h5')
        # Close (and thereby flush) the file before keras reads it back.
        with open(file_name, 'w+b') as file:
            file.write(bytes(obj[0]))
        state = State(load_model(file_name, custom_objects=custom_objects),
                      custom_objects=custom_objects)

    for key, value in zip(obj[2], obj[3:]):
        setattr(state, key, value)
        state.collection_keys.append(key)

    return state
71 
72 
def apply(state, X):
    """
    Apply estimator to passed data.

    Runs ``state.model.predict`` on ``X`` and flattens the result to one
    dimension. The returned numpy array is float32, aligned, writeable,
    C-contiguous and owns its data.
    """
    raw_prediction = state.model.predict(X)
    flat_prediction = raw_prediction.flatten()
    array_flags = ['A', 'W', 'C', 'O']
    return np.require(flat_prediction, dtype=np.float32, requirements=array_flags)
79 
80 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Start-of-fit hook.

    Nothing needs preparing here: the ``*test`` arguments are ignored and the
    very same state object is handed back.
    """
    unchanged_state = state
    return unchanged_state
86 
87 
def partial_fit(state, X, S, y, w, epoch):
    """
    Pass received data to tensorflow session.

    Trains ``state.model`` on ``X``/``y`` with a batch size of 100 for 10
    epochs. The per-event weights ``w`` are forwarded to keras as
    ``sample_weight``. The spectators ``S`` and the ``epoch`` counter are
    unused.

    :return: always False. Presumably this tells the caller that no further
        partial_fit calls are wanted; confirm against the MVA framework.
    """
    # Forward the event weights; previously they were silently ignored, which
    # trained on an unweighted sample.
    state.model.fit(X, y, batch_size=100, epochs=10, sample_weight=w)
    return False
94 
95 
def end_fit(state):
    """
    Serialise the trained keras model together with the extra state objects.

    The model is written to a temporary HDF5 file via ``state.model.save`` and
    read back as raw bytes.

    :param state: State holding the trained model
    :return: list ``[model_bytes, custom_objects, collection_keys, value_0,
        value_1, ...]``, where ``value_i`` is the attribute of ``state`` named
        ``collection_keys[i]``. This is the layout load reads back.
    """
    with tempfile.TemporaryDirectory() as path:
        file_name = os.path.join(path, 'weights.h5')
        state.model.save(file_name)
        with open(file_name, 'rb') as file:
            data = file.read()

    obj_to_save = [data, state.custom_objects, state.collection_keys]
    obj_to_save.extend(getattr(state, key) for key in state.collection_keys)
    return obj_to_save
basf2_mva_python_interface.contrib_keras.State.collection_keys
collection_keys
list of keys to save
Definition: contrib_keras.py:27
basf2_mva_python_interface.contrib_keras.State.__init__
def __init__(self, model=None, custom_objects=None, **kwargs)
Definition: contrib_keras.py:20
basf2_mva_python_interface.contrib_keras.State.custom_objects
custom_objects
used by keras to load custom objects like custom layers
Definition: contrib_keras.py:25
basf2_mva_python_interface.contrib_keras.State
Definition: contrib_keras.py:15
basf2_mva_python_interface.contrib_keras.State.model
model
keras model
Definition: contrib_keras.py:23