Belle II Software development
keras.py
1
8
9import pathlib
10import tempfile
11import numpy as np
12
13from keras.layers import Dense, Input
14from keras.models import Model, load_model
15from keras.losses import binary_crossentropy
16import tensorflow as tf
17from basf2 import B2WARNING
18
19
20class State:
21 """
22 keras state
23 """
24
25 def __init__(self, model=None, **kwargs):
26 """ Constructor of the state object """
27
28 self.model = model
29
30
32
33 # other possible things to save into a tensorflow collection
34 for key, value in kwargs.items():
35 self.collection_keys.append(key)
36 setattr(self, key, value)
37
38
39def feature_importance(state):
40 """
41 Return a list containing the feature importances
42 """
43 return []
44
45
46def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
47 """
48 Return dummy default keras model
49 """
50
51 input = Input(shape=(number_of_features,))
52 net = Dense(units=1)(input)
53
54 state = State(Model(input, net))
55
56 state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])
57
58 state.model.summary()
59
60 return state
61
62
63def load(obj):
64 """
65 Load keras model into state
66 """
67 with tempfile.TemporaryDirectory() as temp_path:
68
69 temp_path = pathlib.Path(temp_path)
70
71 filename = obj[0]
72 path = temp_path.joinpath(pathlib.Path(filename))
73
74 with open(path, 'w+b') as file:
75 file.write(bytes(obj[1]))
76
77 state = State(load_model(path))
78
79 for index, key in enumerate(obj[2]):
80 setattr(state, key, obj[3][index])
81 return state
82
83
84def apply(state, X):
85 """
86 Apply estimator to passed data.
87 """
88 # convert array input to tensor to avoid creating a new graph for each input
89 # calling the model directly is faster than using the predict method in most of our applications
90 # as we do a loop over events.
91 r = state.model(tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32), training=False).numpy()
92 if r.shape[1] == 1:
93 r = r[:, 0] # cannot use squeeze because we might have output of shape [1,X classes]
94 return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
95
96
97def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
98 """
99 Returns just the state object
100 """
101 return state
102
103
104def partial_fit(state, X, S, y, w, epoch, batch):
105 """
106 Pass received data to keras model and fit it
107 """
108 if epoch > 0:
109 B2WARNING("The keras training interface has been called with specific_options.m_nIterations > 1."
110 " In the default implementation this should not be done as keras handles the number of epochs internally.")
111
112 if batch > 0:
113 B2WARNING("The keras training interface has been called with specific_options.m_mini_batch_size > 1."
114 " In the default implementation this should not be done as keras handles the number of batches internally.")
115
116 state.model.fit(X, y, batch_size=100, epochs=10)
117 return False
118
119
120def end_fit(state):
121 """
122 Store trained keras model
123 """
124
125 with tempfile.TemporaryDirectory() as temp_path:
126
127 temp_path = pathlib.Path(temp_path)
128 filename = 'my_model.keras'
129 filepath = temp_path.joinpath(filename)
130 state.model.save(filepath)
131
132 with open(filepath, 'rb') as file:
133 filecontent = file.read()
134
135 collection_keys = state.collection_keys
136 collections_to_store = []
137 for key in state.collection_keys:
138 collections_to_store.append(getattr(state, key))
139
140 del state
141 return [filename, filecontent, collection_keys, collections_to_store]
collection_keys
list of keys to save
Definition: keras.py:31
def __init__(self, model=None, **kwargs)
Definition: keras.py:25