Belle II Software development
keras.py
9import pathlib
10import tempfile
11import numpy as np
12
13from basf2 import B2WARNING
14
15from keras.layers import Dense, Input
16from keras.models import Model, load_model
17from keras.losses import binary_crossentropy
18import tensorflow as tf
# Run TensorFlow single-threaded: one thread inside each op and one thread across ops.
_threading_config = tf.config.threading
_threading_config.set_intra_op_parallelism_threads(1)
_threading_config.set_inter_op_parallelism_threads(1)
del _threading_config
21
22
class State:
    """
    keras state

    Holds the keras model plus any extra named values passed to the constructor.
    The names of those extra values are recorded in ``collection_keys``.
    """

    def __init__(self, model=None, **kwargs):
        """ Constructor of the state object

        :param model: the keras model held by this state (may be None)
        :param kwargs: additional named values; each one is set as an attribute
            of the state and its name is appended to ``collection_keys``
        """

        #: keras model
        self.model = model

        #: list of keys to save
        self.collection_keys = []

        # other possible things to save into a tensorflow collection
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
40
41
def feature_importance(state):
    """
    Return the feature importances.

    This default keras backend does not compute any, so a fresh empty list is
    returned on every call; ``state`` is not inspected.
    """
    no_importances = list()
    return no_importances
47
48
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return dummy default keras model

    Builds a single dense layer that maps the input features to one sigmoid
    output. It is compiled with the adam optimizer, binary cross-entropy loss and
    an accuracy metric, and its summary is printed.

    :param number_of_features: width of the input layer
    :param number_of_spectators: not used by this default model
    :param number_of_events: not used by this default model
    :param training_fraction: not used by this default model
    :param parameters: not used by this default model
    :return: State wrapping the compiled model
    """

    # named 'inputs' (not 'input') to avoid shadowing the builtin
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy (default from_logits=False) and the 'accuracy' metric both
    # expect probabilities, so the single output unit uses a sigmoid. With a purely
    # linear output the loss clips the raw scores and training degrades.
    outputs = Dense(units=1, activation='sigmoid')(inputs)

    state = State(Model(inputs, outputs))

    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    state.model.summary()

    return state
64
65
def load(obj):
    """
    Load keras model into state

    :param obj: sequence ``[filename, filecontent, collection_keys, collection_values]``,
        i.e. the layout returned by ``end_fit``
    :return: State holding the restored model; each stored collection value is set
        as an attribute named by the corresponding key
    """
    with tempfile.TemporaryDirectory() as temp_dir:

        # Keep only the final path component, so the model file is always written
        # inside the temporary directory even if the stored name is absolute or
        # contains directories. The suffix is preserved (presumably load_model
        # relies on it to detect the file format).
        filename = pathlib.Path(obj[0]).name
        path = pathlib.Path(temp_dir).joinpath(filename)

        with open(path, 'wb') as file:
            file.write(bytes(obj[1]))

        # load while the temporary file still exists
        state = State(load_model(path))

    for key, value in zip(obj[2], obj[3]):
        setattr(state, key, value)
    return state
85
86
def apply(state, X):
    """
    Evaluate the keras model of ``state`` on ``X``.

    ``X`` is promoted to at least 2-D and cast to float32. A single-column output
    is flattened to one value per row. The result is returned as an aligned,
    writeable, C-contiguous float32 array that owns its data.
    """
    # Passing a tensor instead of a raw array avoids building a new graph for each
    # input. Calling the model directly is usually faster than Model.predict for
    # our event-by-event loop.
    batch = np.atleast_2d(X)
    tensor = tf.convert_to_tensor(batch, dtype=tf.float32)
    output = state.model(tensor, training=False).numpy()

    if output.shape[1] == 1:
        # squeeze() would also drop the event axis for a [1, n_classes] output,
        # so only the single column is selected here.
        output = output[:, 0]

    return np.require(output, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
98
99
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Hand back ``state`` untouched.

    The validation data (Xtest, Stest, ytest, wtest) and nBatches are not used by
    this default implementation.
    """
    unchanged = state
    return unchanged
105
106
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Train the keras model on the received data in a single ``fit`` call.

    Keras handles epochs (10) and mini-batches (size 100) itself, so a warning is
    emitted when this is called with ``epoch > 0`` or ``batch > 0``. Spectators
    ``S`` and weights ``w`` are not used by this default implementation.

    :return: False. NOTE(review): presumably signals that no further partial_fit
        calls are wanted; confirm against the basf2 MVA python interface.
    """
    checks = (
        (epoch > 0,
         "The keras training interface has been called with specific_options.m_nIterations > 1."
         " In the default implementation this should not be done as keras handles the number of epochs internally."),
        (batch > 0,
         "The keras training interface has been called with specific_options.m_mini_batch_size > 1."
         " In the default implementation this should not be done as keras handles the number of batches internally."),
    )
    for triggered, message in checks:
        if triggered:
            B2WARNING(message)

    state.model.fit(X, y, batch_size=100, epochs=10)
    return False
121
122
def end_fit(state):
    """
    Store trained keras model

    Serialises the model to a temporary ``.keras`` file and returns its bytes
    together with the extra values registered on the state.

    :param state: state holding the trained model in ``state.model``
    :return: ``[filename, filecontent, collection_keys, collections_to_store]``, where
        ``collections_to_store[i]`` is the value of attribute ``collection_keys[i]``
    """
    filename = 'my_model.keras'

    with tempfile.TemporaryDirectory() as temp_dir:
        filepath = pathlib.Path(temp_dir).joinpath(filename)
        state.model.save(filepath)
        # read the bytes back before the temporary directory is removed
        with open(filepath, 'rb') as file:
            filecontent = file.read()

    # tolerate states that never registered any extra values
    collection_keys = getattr(state, 'collection_keys', [])
    collections_to_store = [getattr(state, key) for key in collection_keys]

    return [filename, filecontent, collection_keys, collections_to_store]
collection_keys — list of keys to save (Definition: keras.py:34)
def __init__(self, model=None, **kwargs) — constructor of the state object (Definition: keras.py:28)