Belle II Software release-09-00-00
adversary_network.py
1#!/usr/bin/env python3
2
3
10
11# This example shows how to remove bias on one or several spectator variables.
12# Relevant paper: https://arxiv.org/abs/1611.01046
13# use basf2_mva_evaluation.py with train.root and test.root at the end to see the impact on the spectator variables.
14
15import basf2_mva
17
18from keras.layers import Dense, Input
19from keras.models import Model
20from keras.optimizers import Adam
21from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
22from keras.activations import sigmoid, tanh, softmax
23from keras.callbacks import EarlyStopping, Callback
24
25import numpy as np
26from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
27
28from sklearn.metrics import roc_auc_score
29
30import warnings
31warnings.filterwarnings('ignore', category=UserWarning)
32
33
35 """
36 Class to create batches for training the Adversary Network.
37 Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
38 """
39
40 def __init__(self, X, Y, Z):
41 """
42 Init the class
43 :param X: Input Features
44 :param Y: Label Data
45 :param Z: Spectators/Quantity to be uncorrelated to
46 """
47
48 self.X = X
49
50 self.Y = Y
51
52 self.Z = Z
53
54 self.len = len(Y)
55
56 self.index_array = np.arange(self.len)
57
58 self.pointer = 0
59
60 def next_batch(self, batch_size):
61 """
62 Getting next batch of training data
63 """
64 if self.pointer + batch_size >= self.len:
65 np.random.shuffle(self.index_array)
66 self.pointer = 0
67
68 batch_index = self.index_array[self.pointer:self.pointer + batch_size]
69 self.pointer += batch_size
70
71 return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
72
73
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for apply data.
    2. Frozen MLP with unfrozen Adversarial Network to train adversarial part of network.
    3. Unfrozen MLP with frozen adversarial to train MLP part of the network,
       combined with losses of the adversarial networks.

    :param number_of_features: number of input features of the classifier
    :param number_of_spectators: number of spectator variables to decorrelate against
    :param number_of_events: total number of events (unused here)
    :param training_fraction: fraction of events used for training (unused here)
    :param parameters: dict with keys 'learning_rate', 'lambda', 'adversary_steps', 'number_bins'
    :return: State object holding the compiled models
    """

    def adversary_loss(signal):
        """
        Loss for adversaries outputs
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        # y[:, 0] is the isSignal flag: the factor (y[:, 0] - back_constant)
        # selects only signal (or only background) events, so each adversary
        # learns the spectator distribution of exactly one class.
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Define inputs for input_feature and spectator
    input = Input(shape=(number_of_features,))

    # build first model which will produce the desired discriminator
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying Data. Loss function will not be used for training, if adversary is used.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=binary_crossentropy, metrics=['accuracy'])

    # NOTE(review): State appears to come from basf2_mva_python_interface.keras;
    # the corresponding import seems to be missing from the top of this file — confirm.
    # The adversary is only used when lambda > 0 and spectators are present.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # build second model on top of the first one which will try to predict spectators
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        # One adversary per (signal/background) x spectator combination; each one
        # tries to predict the spectator's bin number from the classifier output.
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                # Adversary layers start non-trainable: model1 below trains only the MLP part.
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

    # Model which trains first part of the net
    # The negative loss_weights implement the adversarial min-max game:
    # the classifier is rewarded (scaled by lambda) when the adversaries fail.
    model1 = Model(input, [output] + adversaries)
    model1.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                   loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'] * (len(adversaries) + 1),
                   loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))
    model1.summary()

    # Model which train second, adversary part of the net
    model2 = Model(input, adversaries)
    # freeze everything except adversary layers
    # (flipping every trainable flag works because the MLP layers are trainable
    # and the adversary layers were created with trainable=False)
    for layer in model2.layers:
        layer.trainable = not layer.trainable

    model2.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=adversary_losses_model,
                   metrics=['accuracy'] * len(adversaries))
    model2.summary()

    state.forward_model, state.adv_model = model1, model2
    # K: how many adversary batches are trained per classifier batch (see Adversary callback)
    state.K = parameters['adversary_steps']

    return state
143
144
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Attach the validation sample to the state so that the training in
    partial_fit can monitor its progress on held-out data.
    """
    state.Xtest, state.Stest, state.ytest = Xtest, Stest, ytest
    return state
154
155
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of MLP. Adverserial Network will be trained K times.

    :param state: State object from get_model, extended with validation data by begin_fit
    :param X: input features
    :param S: spectator variables
    :param y: target labels
    :param w: event weights (unused here)
    :param epoch: current epoch (unused; keras drives the epochs itself below)
    :param batch: current batch (unused)
    :return: False, to signal that fitting is complete after a single call
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # One target array per spectator column, duplicated for the
        # signal and the background adversaries (hence the trailing * 2).
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get bin numbers of S with equal frequency binning
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        # apply() maps the spectators into [0, 1]; scaling by number_bins yields bin indices
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build target for adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build Batch Generator for adversary Callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # collected validation AUC values, one entry per epoch
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            # evaluate the plain classifier (state.model) on the validation sample
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
            return

    class Adversary(Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # After each classifier batch, train the adversary on K batches
            # of 400 events each (drawn as one chunk of 400 * K events).
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # Plain classifier training (lambda == 0 or no spectators given).
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        # Adversarial training: classifier and adversaries alternate via the Adversary callback.
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
213
214
215if __name__ == "__main__":
216 from basf2 import conditions, find_file
217 # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
218 conditions.testing_payloads = [
219 'localdb/database.txt'
220 ]
221
222 variables = ['p', 'pt', 'pz', 'phi',
223 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
224 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
225 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
226 'chiProb', 'dr', 'dz', 'dphi',
227 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
228 'daughter(0, dphi)', 'daughter(1, dphi)',
229 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
230 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
231 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
232 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
233 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
234 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
235 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
236 'M']
237
238 variables2 = ['p', 'pt', 'pz', 'phi',
239 'chiProb', 'dr', 'dz', 'dphi',
240 'daughter(2, chiProb)',
241 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
242 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
243 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
244 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
245 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
246
247 train_file = find_file("mva/train_D0toKpipi.root", "examples")
248 training_data = basf2_mva.vector(train_file)
249
250 general_options = basf2_mva.GeneralOptions()
251 general_options.m_datafiles = training_data
252 general_options.m_treename = "tree"
253 general_options.m_variables = basf2_mva.vector(*variables)
254 general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
255 general_options.m_target_variable = "isSignal"
256 general_options.m_identifier = "keras_adversary"
257
258 specific_options = basf2_mva.PythonOptions()
259 specific_options.m_framework = "keras"
260 specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
261 specific_options.m_normalize = True
262 specific_options.m_training_fraction = 0.9
263
264 """
265 Config for Adversary Networks:
266 lambda: Trade off between classifier performance and noncorrelation between classifier output and spectators.
267 Increase to reduce correlations(and also classifier performance)
268 adversary_steps: How many batches the discriminator is trained after one batch of training the classifier.
269 Less steps make the training faster but also unstable. Increase the parameter if something isn't working.
270 number_bins: Number of Bins which are used to quantify the spectators. 10 should be sufficient.
271 """
272 specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
273 basf2_mva.teacher(general_options, specific_options)
274
275 general_options.m_identifier = "keras_baseline"
276 specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
277 basf2_mva.teacher(general_options, specific_options)
278
279 general_options.m_variables = basf2_mva.vector(*variables2)
280 general_options.m_identifier = "keras_feature_drop"
281 basf2_mva.teacher(general_options, specific_options)
pointer
Pointer to the current start of the batch.
Z
Spectators/Quantity to be uncorrelated to.
index_array
Index array containing indices from 0 to len.