Belle II Software light-2406-ragdoll
keras.py
1
8
9import pathlib
10import tempfile
11import numpy as np
12
13from tensorflow.keras.layers import Dense, Input
14from tensorflow.keras.models import Model, load_model
15from tensorflow.keras.losses import binary_crossentropy
16import tensorflow as tf
17from basf2 import B2WARNING
18
19
class State:
    """
    Tensorflow.keras state

    Holds the keras model together with any additional named values that
    should be stored next to it in the weightfile.
    """

    def __init__(self, model=None, **kwargs):
        """
        Constructor of the state object

        :param model: the tensorflow.keras model (may be None)
        :param kwargs: additional values to attach to the state; each one is
            set as an attribute and its key is recorded in ``collection_keys``
        """
        ## tensorflow.keras model
        self.model = model

        ## list of keys to save
        # must exist before the loop below appends to it
        self.collection_keys = []

        # other possible things to save into a tensorflow collection
        for key, value in kwargs.items():
            self.collection_keys.append(key)
            setattr(self, key, value)
37
38
def feature_importance(state):
    """
    Report feature importances for the keras model.

    No importance estimate is computed for this model, so a new empty
    list is returned on every call; ``state`` is not inspected.
    """
    importances = list()
    return importances
44
45
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Return default tensorflow.keras model

    The model is a single dense unit with a sigmoid activation, compiled with
    the adam optimizer, binary cross-entropy loss and an accuracy metric.
    Only ``number_of_features`` is used; the remaining arguments are part of
    the interface and are ignored by this default.

    :param number_of_features: width of the input layer
    :return: State wrapping the compiled model
    """
    inputs = Input(shape=(number_of_features,))
    # binary_crossentropy (from_logits=False) expects probabilities, so the
    # output is squashed into [0, 1]. A linear output would be clipped by the
    # loss outside that range and receive no useful gradient there.
    outputs = Dense(units=1, activation='sigmoid')(inputs)

    state = State(Model(inputs, outputs))

    state.model.compile(optimizer="adam", loss=binary_crossentropy, metrics=['accuracy'])

    state.model.summary()

    return state
61
62
def load(obj):
    """
    Load Tensorflow.keras model into state

    :param obj: sequence laid out as produced by ``end_fit``:
        ``[file_names, files, collection_keys, collections_to_store]``
    :return: State holding the loaded model, with every stored collection
        set as an attribute and its key registered in ``collection_keys``
    :raises ValueError: if the stored names and contents do not pair up
    """
    file_names, file_contents = obj[0], obj[1]
    keys, values = obj[2], obj[3]

    # zip() would silently truncate; a mismatch means a corrupt weightfile
    if len(file_names) != len(file_contents):
        raise ValueError("Number of stored file names and file contents differ")
    if len(keys) != len(values):
        raise ValueError("Number of stored collection keys and values differ")

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)

        # recreate the saved directory tree on disk so keras can read it
        for file_name, content in zip(file_names, file_contents):
            path = temp_path / file_name
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_bytes(bytes(content))

        state = State(load_model(temp_path / 'my_model'))

    for key, value in zip(keys, values):
        setattr(state, key, value)
    # keep the restored keys registered so the loaded state can be stored again
    state.collection_keys = list(keys)
    return state
84
85
def apply(state, X):
    """
    Apply estimator to passed data.

    :param state: State holding the keras model
    :param X: input features; passed through ``np.atleast_2d`` so a single
        1-D row is treated as a batch of one
    :return: float32 array meeting numpy.require's 'A', 'W', 'C', 'O'
        requirements (aligned, writeable, C-contiguous, owning its data);
        a single-column 2-D output is reduced to 1-D
    """
    # convert array input to tensor to avoid creating a new graph for each input
    # calling the model directly is faster than using the predict method in most of our applications
    # as we do a loop over events.
    inputs = tf.convert_to_tensor(np.atleast_2d(X), dtype=tf.float32)
    r = state.model(inputs, training=False).numpy()
    # guard ndim: a model with a 1-D output has no shape[1]
    if r.ndim > 1 and r.shape[1] == 1:
        r = r[:, 0]  # cannot use squeeze because we might have output of shape [1,X classes]
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
97
98
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Prepare for training.

    Nothing has to be set up for keras: the test data and the number of
    batches are ignored and the very same state object is handed back.
    """
    del Xtest, Stest, ytest, wtest, nBatches  # unused; required by the interface
    return state
104
105
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Pass received data to tensorflow.keras session

    Keras runs its own epoch and batch loop here (batch_size=100, epochs=10),
    so a warning is emitted whenever the caller is iterating externally
    (epoch > 0 or batch > 0). The spectators ``S`` and weights ``w`` are not
    used.

    :return: always False (presumably signals the caller not to continue
        with further iterations — confirm against the MVA python interface)
    """
    checks = (
        (epoch > 0,
         "The keras training interface has been called with specific_options.m_nIterations > 1."
         " In the default implementation this should not be done as keras handles the number of epochs internally."),
        (batch > 0,
         "The keras training interface has been called with specific_options.m_mini_batch_size > 1."
         " In the default implementation this should not be done as keras handles the number of batches internally."),
    )
    for triggered, message in checks:
        if triggered:
            B2WARNING(message)

    state.model.fit(X, y, batch_size=100, epochs=10)
    return False
120
121
def end_fit(state):
    """
    Serialize the trained tensorflow.keras model and the extra collections.

    The model is saved into a temporary directory, and every file written
    there is read back into memory so that the result can be stored in the
    weightfile and later restored with ``load``.

    :param state: State whose ``model`` is saved and whose ``collection_keys``
        name the additional attributes to store
    :return: ``[file_names, files, collection_keys, collections_to_store]``:
        paths relative to the temporary directory (sorted, so the output does
        not depend on filesystem iteration order), the matching file contents
        as bytes, the stored attribute names and their values
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)
        state.model.save(temp_path / 'my_model')

        # this creates:
        # path/my_model/saved_model.pb
        # path/my_model/keras_metadata.pb (sometimes)
        # path/my_model/variables/*
        # path/my_model/assets/*
        file_names = sorted(
            f.relative_to(temp_path) for f in temp_path.rglob('*') if f.is_file()
        )
        files = [(temp_path / file_name).read_bytes() for file_name in file_names]

    collection_keys = state.collection_keys
    collections_to_store = [getattr(state, key) for key in collection_keys]

    return [file_names, files, collection_keys, collections_to_store]
collection_keys
list of keys to save
Definition: keras.py:31
def __init__(self, model=None, **kwargs)
Definition: keras.py:25