22 import tensorflow
as tf
23 import tensorflow.contrib.keras
as keras
25 from keras.layers
import Input, Dense, Concatenate, Dropout, Lambda, GlobalAveragePooling1D, Reshape
26 from keras.models
import Model, load_model
27 from keras.optimizers
import adam
28 from keras.losses
import binary_crossentropy, sparse_categorical_crossentropy
29 from keras.activations
import sigmoid, tanh, softmax
30 from keras.callbacks
import Callback, EarlyStopping
31 from sklearn.metrics
import roc_auc_score
36 from basf2_mva_extensions.keras_relational
import Relations, EnhancedRelations
37 from basf2_mva_extensions.preprocessing
import fast_equal_frequency_binning
40 warnings.filterwarnings(
'ignore', category=UserWarning)
def slice(input, begin, end):
    """
    Simple function for slicing feature in tensors.

    Selects the columns ``[begin, end)`` along the second axis, i.e. a
    contiguous range of features from a (batch, features) tensor.

    NOTE(review): ``slice`` and ``input`` shadow Python builtins; the names
    are kept because ``get_model`` references this function via
    ``Lambda(slice, arguments={'begin': ..., 'end': ...})``, which passes
    ``begin``/``end`` by keyword.

    :param input: 2D tensor (batch, features)
    :param begin: first feature index (inclusive)
    :param end: last feature index (exclusive)
    :return: tensor with only the selected feature columns
    """
    return input[:, begin:end]
52 Class to create batches for training the Adversary Network.
53 See mva/examples/keras/adversary_network.py for details.
59 :param X: Input Features
61 :param Z: Spectaters/Qunatity to be uncorrelated to
79 Getting next batch of training data
88 return self.
X[batch_index], self.
Y[batch_index], self.
Z[batch_index]
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the keras model for training.

    :param number_of_features: number of input features per event
    :param number_of_spectators: number of spectator variables the classifier
        output should be decorrelated from
    :param number_of_events: part of the basf2 mva interface (unused here)
    :param training_fraction: part of the basf2 mva interface (unused here)
    :param parameters: optional dict overriding the defaults in ``param`` below
    :return: basf2 mva ``State`` holding the compiled model(s)
    """

    def adversary_loss(signal):
        """
        Loss for adversaries outputs.

        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        # y[:, 0] is the isSignal flag, so (y[:, 0] - back_constant) keeps only
        # the events of the class this adversary is responsible for.
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Hyperparameter defaults; entries in `parameters` override them.
    param = {'use_relation_layers': False, 'lambda': 0, 'number_bins': 10, 'adversary_steps': 5}
    if isinstance(parameters, dict):
        param.update(parameters)

    input = Input((number_of_features,))

    if param['use_relation_layers']:
        # Hard-coded feature layout: 0-339 track features (20 tracks x 17),
        # 340-559 cluster features (20 clusters x 11), 560-589 high level.
        # TODO(review): confirm these offsets match the training variables.
        low_level_input = Lambda(slice, arguments={'begin': 0, 'end': 560})(input)
        high_level_input = Lambda(slice, arguments={'begin': 560, 'end': 590})(input)
        relations_tracks = Lambda(slice, arguments={'begin': 0, 'end': 340})(low_level_input)
        relations_tracks = Reshape((20, 17))(relations_tracks)
        relations_clusters = Lambda(slice, arguments={'begin': 340, 'end': 560})(low_level_input)
        relations_clusters = Reshape((20, 11))(relations_clusters)

        relations1 = EnhancedRelations(number_features=20,
                                       hidden_feature_shape=[80, 80, 80])([relations_tracks, high_level_input])
        relations2 = EnhancedRelations(number_features=20,
                                       hidden_feature_shape=[80, 80, 80])([relations_clusters, high_level_input])

        relations_output1 = GlobalAveragePooling1D()(relations1)
        relations_output2 = GlobalAveragePooling1D()(relations2)

        net = Concatenate()([relations_output1, relations_output2])

        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)
        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)
    else:
        # Plain MLP when relation layers are disabled.
        net = Dense(units=50, activation=tanh)(input)
        net = Dense(units=50, activation=tanh)(net)
        net = Dense(units=50, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # Model used for inference; its loss is irrelevant when the adversary is used.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=adam(), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model,
                  use_adv=param['lambda'] > 0 and number_of_spectators > 0,
                  preprocessor_state=None,
                  custom_objects={'EnhancedRelations': EnhancedRelations})

    # Build the adversary networks only when they will actually be used.
    # Without this guard Model(input, []) would fail for number_of_spectators == 0.
    if state.use_adv:
        adversaries, adversary_losses_model = [], []
        # One adversary per spectator for the signal and for the background distribution.
        for mode in ['signal', 'background']:
            for _ in range(number_of_spectators):
                adversary1 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=param['number_bins'], activation=softmax,
                                         trainable=False)(adversary2))
                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Combined model: classifier loss minus lambda-weighted adversary losses.
        model1 = Model(input, [output] + adversaries)
        # BUG FIX: previously read parameters['lambda'] / parameters['adversary_steps']
        # directly, which raises KeyError when the caller relies on the defaults
        # merged into `param` above.
        model1.compile(optimizer=adam(),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-param['lambda']] * len(adversary_losses_model))

        # Adversary-only model: invert trainability so only the adversaries learn.
        model2 = Model(input, adversaries)
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=adam(), loss=adversary_losses_model, metrics=['accuracy'])

        state.forward_model, state.adv_model = model1, model2
        state.K = param['adversary_steps']
        state.number_bins = param['number_bins']

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Save Validation Data for monitoring Training.

    Stores the validation arrays on the state; partial_fit later reads
    state.Xtest, state.Stest and state.ytest.

    :param state: basf2 mva State
    :param Xtest: validation features
    :param Stest: validation spectator variables
    :param ytest: validation labels
    :param wtest: validation weights (not used by this example)
    :return: the updated state
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest
    return state
def partial_fit(state, X, S, y, w, epoch):
    """
    For every training step of MLP. Adversarial Network (if used) will be trained K times.

    :param state: basf2 mva State returned by get_model
    :param X: training features
    :param S: spectator variables
    :param y: labels (isSignal)
    :param w: event weights (not used by this example)
    :param epoch: basf2 mva training step (keras handles epochs internally here)
    """
    # Equal-frequency binning flattens every feature distribution; the fitted
    # state is saved so apply() can reproduce the exact same transformation.
    preprocessor = fast_equal_frequency_binning()
    preprocessor.fit(X, number_of_bins=500)
    X = preprocessor.apply(X)
    state.Xtest = preprocessor.apply(state.Xtest)
    state.preprocessor_state = preprocessor.export_state()

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information
        for the adversary. Doubled because there is a signal and a background
        adversary per spectator.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Bin each spectator into state.number_bins classes; scaling by
        # number_bins turns the [0, 1] bin fraction into an integer class label.
        S_preprocessor = fast_equal_frequency_binning()
        S_preprocessor.fit(S, number_of_bins=state.number_bins)
        S = S_preprocessor.apply(S) * state.number_bins
        state.Stest = S_preprocessor.apply(state.Stest) * state.number_bins

        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Batch source for the Adversary callback below.
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print('\nTest AUC: {}\n'.format(val_auc))
            self.val_aucs.append(val_auc)

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train Adversary. After every batch of the forward model the
        adversary model is fitted on K batches.
        """

        def on_batch_end(self, batch, logs=None):
            v_X, v_y, v_S = state.batch_gen.next_batch(500 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=500)

    if not state.use_adv:
        state.model.fit(X, y, batch_size=500, epochs=100000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'),
                                   AUC_Callback()])
    else:
        state.forward_model.fit(X, [y] + target_array, batch_size=500, epochs=100000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    # NOTE(review): returning False tells the basf2 mva interface that training
    # is finished after this single call — confirm against the interface docs.
    return False
def apply(state, X):
    """
    Apply estimator to passed data.
    Has to be overwritten, because also the expert has to apply preprocessing.

    :param state: basf2 mva State with the trained model and preprocessor state
    :param X: features to evaluate
    :return: flat float32 numpy array of network outputs
    """
    # Re-create the exact feature binning that was fitted during training.
    preprocessor = fast_equal_frequency_binning(state.preprocessor_state)
    binned_features = preprocessor.apply(X)
    predictions = state.model.predict(binned_features).flatten()
    # Hand back an aligned, writeable, C-contiguous, own-data float32 array.
    return np.require(predictions, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])