import warnings

import numpy as np
import tensorflow as tf
import tensorflow.contrib.keras as keras

from keras.layers import Input, Dense, Concatenate, Lambda
from keras.models import Model, load_model
from keras.optimizers import adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras import backend as K
from keras.callbacks import Callback, EarlyStopping
from keras.utils import plot_model

import basf2_mva
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
from sklearn.metrics import roc_auc_score
# Suppress UserWarnings for the whole module (the deprecated
# tensorflow.contrib.keras stack imported above is noisy).
warnings.filterwarnings('ignore', category=UserWarning)
class batch_generator():
    """
    Class to create batches for training the Adversary Network.
    Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
    """

    def __init__(self, X, Y, Z):
        """
        Store the data and prepare a shuffled index array.

        :param X: Input Features
        :param Y: Labels
        :param Z: Spectators/Quantity to be uncorrelated to
        """
        self.X = X
        self.Y = Y
        self.Z = Z
        # number of available samples
        self.len = len(Y)
        # shuffled view onto the data; reshuffled after every full pass
        self.index_array = np.arange(self.len)
        np.random.shuffle(self.index_array)
        # position of the next batch inside index_array
        self.pointer = 0

    def next_batch(self, batch_size):
        """
        Getting next batch of training data.

        :param batch_size: requested number of samples; the final batch of an
            epoch may be shorter, after which the data is reshuffled.
        :return: tuple (X, Y, Z) of the next batch, indexed with the same
            shuffled indices so the three arrays stay aligned.
        """
        if self.pointer + batch_size <= self.len:
            batch_index = self.index_array[self.pointer:self.pointer + batch_size]
            self.pointer += batch_size
        else:
            # epoch exhausted: hand out the remainder and start a new shuffle
            batch_index = self.index_array[self.pointer:]
            np.random.shuffle(self.index_array)
            self.pointer = 0

        return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for apply data.
    2. Frozen MLP with unfrozen adversarial network, to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial network, to train the MLP part of the network,
       combined with the losses of the adversarial networks.

    :param parameters: dict with keys 'learning_rate', 'lambda',
        'number_bins' and 'adversary_steps'.
    :return: framework State object holding the compiled models.
    """
    def adversary_loss(signal):
        """
        Loss for adversaries outputs.
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # y[:, 0] is the truth label, y[:, 1:] the one-hot spectator bin.
            # The factor selects only the events of the wanted class, so each
            # adversary learns the spectator distribution of one class only.
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Define inputs
    input = Input(shape=(number_of_features,))

    # Build the classifier MLP
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying data. Its loss is not used for training if the adversary is active.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=adam(lr=parameters['learning_rate']), loss=binary_crossentropy, metrics=['accuracy'])

    # adversary only makes sense with lambda > 0 and at least one spectator
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    if state.use_adv:
        adversaries, adversary_losses_model = [], []
        # one adversary per (class, spectator) combination; layers start frozen
        # because model1 below trains only the classifier part
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model 1: trains the MLP against the frozen adversaries. The adversary
        # losses enter with weight -lambda, punishing correlation with spectators.
        model1 = Model(input, [output] + adversaries)
        model1.compile(optimizer=adam(lr=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))

        # Model 2: trains only the adversaries — invert every trainable flag
        # so the classifier layers are frozen and the adversary layers are not.
        model2 = Model(input, adversaries)
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=adam(lr=parameters['learning_rate']), loss=adversary_losses_model,
                       metrics=['accuracy'])

        state.forward_model, state.adv_model = model1, model2
        # how many adversary batches are trained per classifier batch
        state.K = parameters['adversary_steps']

        # draw the combined model as a picture
        plot_model(model1, to_file='model.png', show_shapes=True)

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Save validation data on the state for monitoring the training.

    :param state: State object returned by get_model.
    :param Xtest: validation features
    :param Stest: validation spectators
    :param ytest: validation labels
    :param wtest: validation weights (not used here)
    :return: the updated state
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch):
    """
    For every training step of the MLP, the adversarial network will be trained K times.

    :param state: State object prepared by get_model/begin_fit.
    :param X: training features
    :param S: training spectators
    :param y: training labels
    :param w: training weights (not used here)
    :param epoch: current epoch (training runs once; see return value)
    :return: False, so the framework does not call partial_fit again —
        keras runs the full training internally (epochs=1000).
    """
    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        Returns one target array per spectator, duplicated for the signal and background adversaries.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Quantify the spectators into equal-frequency bins; the adversaries
        # learn these bin numbers as categorical targets.
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build targets for the adversary loss functions
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Batch generator used by the Adversary callback below
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs={}):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs={}):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print('\nTest AUC: {}\n'.format(val_auc))
            self.val_aucs.append(val_auc)

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train the adversary after every batch of the classifier.
        """

        def on_batch_end(self, batch, logs={}):
            # train the adversary on K batches per classifier batch
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # plain classifier training without adversary
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        # adversarial training: classifier and adversaries alternate via the callback
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))

    return False
222 if __name__ ==
"__main__":
223 from basf2
import conditions
225 conditions.testing_payloads = [
226 'localdb/database.txt'
229 variables = [
'p',
'pt',
'pz',
'phi',
230 'daughter(0, p)',
'daughter(0, pz)',
'daughter(0, pt)',
'daughter(0, phi)',
231 'daughter(1, p)',
'daughter(1, pz)',
'daughter(1, pt)',
'daughter(1, phi)',
232 'daughter(2, p)',
'daughter(2, pz)',
'daughter(2, pt)',
'daughter(2, phi)',
233 'chiProb',
'dr',
'dz',
'dphi',
234 'daughter(0, dr)',
'daughter(1, dr)',
'daughter(0, dz)',
'daughter(1, dz)',
235 'daughter(0, dphi)',
'daughter(1, dphi)',
236 'daughter(0, chiProb)',
'daughter(1, chiProb)',
'daughter(2, chiProb)',
237 'daughter(0, kaonID)',
'daughter(0, pionID)',
'daughter(1, kaonID)',
'daughter(1, pionID)',
238 'daughterAngle(0, 1)',
'daughterAngle(0, 2)',
'daughterAngle(1, 2)',
239 'daughter(2, daughter(0, E))',
'daughter(2, daughter(1, E))',
240 'daughter(2, daughter(0, clusterTiming))',
'daughter(2, daughter(1, clusterTiming))',
241 'daughter(2, daughter(0, clusterE9E25))',
'daughter(2, daughter(1, clusterE9E25))',
242 'daughter(2, daughter(0, minC2HDist))',
'daughter(2, daughter(1, minC2HDist))',
245 variables2 = [
'p',
'pt',
'pz',
'phi',
246 'chiProb',
'dr',
'dz',
'dphi',
247 'daughter(2, chiProb)',
248 'daughter(0, kaonID)',
'daughter(0, pionID)',
'daughter(1, kaonID)',
'daughter(1, pionID)',
249 'daughter(2, daughter(0, E))',
'daughter(2, daughter(1, E))',
250 'daughter(2, daughter(0, clusterTiming))',
'daughter(2, daughter(1, clusterTiming))',
251 'daughter(2, daughter(0, clusterE9E25))',
'daughter(2, daughter(1, clusterE9E25))',
252 'daughter(2, daughter(0, minC2HDist))',
'daughter(2, daughter(1, minC2HDist))']
254 general_options = basf2_mva.GeneralOptions()
255 general_options.m_datafiles = basf2_mva.vector(
"train.root")
256 general_options.m_treename =
"tree"
257 general_options.m_variables = basf2_mva.vector(*variables)
258 general_options.m_spectators = basf2_mva.vector(
'daughterInvariantMass(0, 1)',
'daughterInvariantMass(0, 2)')
259 general_options.m_target_variable =
"isSignal"
260 general_options.m_identifier =
"keras"
262 specific_options = basf2_mva.PythonOptions()
263 specific_options.m_framework =
"contrib_keras"
264 specific_options.m_steering_file =
'mva/examples/keras/adversary_network.py'
265 specific_options.m_normalize =
True
266 specific_options.m_training_fraction = 0.9
269 Config for Adversary Networks:
270 lambda: Trade off between classifier performance and noncorrelation between classifier output and spectators.
271 Increase to reduce correlations(and also classifier performance)
272 adversary_steps: How many batches the discriminator is trained after one batch of training the classifier.
273 Less steps make the training faster but also unstable. Increase the parameter if something isn't working.
274 number_bins: Number of Bins which are used to quantify the spectators. 10 should be sufficient.
276 specific_options.m_config =
'{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
277 basf2_mva.teacher(general_options, specific_options)
279 general_options.m_identifier =
"keras_baseline"
280 specific_options.m_config =
'{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
281 basf2_mva.teacher(general_options, specific_options)
283 general_options.m_variables = basf2_mva.vector(*variables2)
284 general_options.m_identifier =
"keras_feature_drop"
285 basf2_mva.teacher(general_options, specific_options)