# Belle II Software light-2406-ragdoll
# adversary_network.py
#!/usr/bin/env python3


# This example shows how to remove bias on one or several spectator variables.
# Relevant paper: https://arxiv.org/abs/1611.01046
# use basf2_mva_evaluation.py with train.root and test.root at the end to see the impact on the spectator variables.

import basf2_mva

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.activations import sigmoid, tanh, softmax
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras.utils import plot_model

import numpy as np
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings('ignore', category=UserWarning)

34
36 """
37 Class to create batches for training the Adversary Network.
38 Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
39 """
40
41 def __init__(self, X, Y, Z):
42 """
43 Init the class
44 :param X: Input Features
45 :param Y: Label Data
46 :param Z: Spectators/Quantity to be uncorrelated to
47 """
48
49 self.X = X
50
51 self.Y = Y
52
53 self.Z = Z
54
55 self.len = len(Y)
56
57 self.index_array = np.arange(self.len)
58
59 self.pointer = 0
60
61 def next_batch(self, batch_size):
62 """
63 Getting next batch of training data
64 """
65 if self.pointer + batch_size >= self.len:
66 np.random.shuffle(self.index_array)
67 self.pointer = 0
68
69 batch_index = self.index_array[self.pointer:self.pointer + batch_size]
70 self.pointer += batch_size
71
72 return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
73
74
75def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
76 """
77 Building 3 keras models:
78 1. Network without adversary, used for apply data.
79 2. Freezed MLP with unfreezed Adverserial Network to train adverserial part of network.
80 3. Unfreezed MLP with freezed adverserial to train MLP part of the network,
81 combined with losses of the adverserial networks.
82 """
83
84 def adversary_loss(signal):
85 """
86 Loss for adversaries outputs
87 :param signal: If signal or background distribution should be learned.
88 :return: Loss function for the discriminator part of the Network.
89 """
90 back_constant = 0 if signal else 1
91
92 def adv_loss(y, p):
93 return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
94 return adv_loss
95
96 # Define inputs for input_feature and spectator
97 input = Input(shape=(number_of_features,))
98
99 # build first model which will produce the desired discriminator
100 layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
101 layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
102 layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
103 output = Dense(units=1, activation=sigmoid)(layer3)
104
105 # Model for applying Data. Loss function will not be used for training, if adversary is used.
106 apply_model = Model(input, output)
107 apply_model.compile(optimizer=Adam(lr=parameters['learning_rate']), loss=binary_crossentropy, metrics=['accuracy'])
108
109 state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
110 state.number_bins = parameters['number_bins']
111
112 # build second model on top of the first one which will try to predict spectators
113 adversaries, adversary_losses_model = [], []
114 if state.use_adv:
115 for mode in ['signal', 'background']:
116 for i in range(number_of_spectators):
117 adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
118 adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
119 adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))
120
121 adversary_losses_model.append(adversary_loss(mode == 'signal'))
122
123 # Model which trains first part of the net
124 model1 = Model(input, [output] + adversaries)
125 model1.compile(optimizer=Adam(lr=parameters['learning_rate']),
126 loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
127 loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))
128 model1.summary()
129
130 # Model which train second, adversary part of the net
131 model2 = Model(input, adversaries)
132 # freeze everything except adversary layers
133 for layer in model2.layers:
134 layer.trainable = not layer.trainable
135
136 model2.compile(optimizer=Adam(lr=parameters['learning_rate']), loss=adversary_losses_model,
137 metrics=['accuracy'])
138 model2.summary()
139
140 state.forward_model, state.adv_model = model1, model2
141 state.K = parameters['adversary_steps']
142
143 # draw model as a picture
144 plot_model(model1, to_file='model.png', show_shapes=True)
145
146 return state
147
148
149def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
150 """
151 Save Validation Data for monitoring Training
152 """
153 state.Xtest = Xtest
154 state.Stest = Stest
155 state.ytest = ytest
156
157 return state
158
159
160def partial_fit(state, X, S, y, w, epoch, batch):
161 """
162 Fit the model.
163 For every training step of MLP. Adverserial Network will be trained K times.
164 """
165
166 def build_adversary_target(p_y, p_s):
167 """
168 Concat isSignal and spectator bins, because both are target information for the adversary.
169 """
170 return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2
171
172 if state.use_adv:
173 # Get bin numbers of S with equal frequency binning
174 preprocessor = fast_equal_frequency_binning()
175 preprocessor.fit(S, number_of_bins=state.number_bins)
176 S = preprocessor.apply(S) * state.number_bins
177 state.Stest = preprocessor.apply(state.Stest) * state.number_bins
178 # Build target for adversary loss function
179 target_array = build_adversary_target(y, S)
180 target_val_array = build_adversary_target(state.ytest, state.Stest)
181 # Build Batch Generator for adversary Callback
182 state.batch_gen = batch_generator(X, y, S)
183
184 class AUC_Callback(Callback):
185 """
186 Callback to print AUC after every epoch.
187 """
188
189 def on_train_begin(self, logs=None):
190 self.val_aucs = []
191
192 def on_epoch_end(self, epoch, logs=None):
193 val_y_pred = state.model.predict(state.Xtest).flatten()
194 val_auc = roc_auc_score(state.ytest, val_y_pred)
195 print(f'\nTest AUC: {val_auc}\n')
196 self.val_aucs.append(val_auc)
197 return
198
199 class Adversary(Callback):
200 """
201 Callback to train Adversary
202 """
203
204 def on_batch_end(self, batch, logs=None):
205 v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
206 target_adversary = build_adversary_target(v_y, v_S)
207 state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)
208
209 if not state.use_adv:
210 state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
211 callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
212 else:
213 state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
214 callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback(), Adversary()],
215 validation_data=(state.Xtest, [state.ytest] + target_val_array))
216 return False
217
218
219if __name__ == "__main__":
220 from basf2 import conditions, find_file
221 # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
222 conditions.testing_payloads = [
223 'localdb/database.txt'
224 ]
225
226 variables = ['p', 'pt', 'pz', 'phi',
227 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
228 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
229 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
230 'chiProb', 'dr', 'dz', 'dphi',
231 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
232 'daughter(0, dphi)', 'daughter(1, dphi)',
233 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
234 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
235 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
236 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
237 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
238 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
239 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
240 'M']
241
242 variables2 = ['p', 'pt', 'pz', 'phi',
243 'chiProb', 'dr', 'dz', 'dphi',
244 'daughter(2, chiProb)',
245 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
246 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
247 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
248 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
249 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
250
251 train_file = find_file("mva/train_D0toKpipi.root", "examples")
252 training_data = basf2_mva.vector(train_file)
253
254 general_options = basf2_mva.GeneralOptions()
255 general_options.m_datafiles = training_data
256 general_options.m_treename = "tree"
257 general_options.m_variables = basf2_mva.vector(*variables)
258 general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
259 general_options.m_target_variable = "isSignal"
260 general_options.m_identifier = "keras_adversary"
261
262 specific_options = basf2_mva.PythonOptions()
263 specific_options.m_framework = "keras"
264 specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
265 specific_options.m_normalize = True
266 specific_options.m_training_fraction = 0.9
267
268 """
269 Config for Adversary Networks:
270 lambda: Trade off between classifier performance and noncorrelation between classifier output and spectators.
271 Increase to reduce correlations(and also classifier performance)
272 adversary_steps: How many batches the discriminator is trained after one batch of training the classifier.
273 Less steps make the training faster but also unstable. Increase the parameter if something isn't working.
274 number_bins: Number of Bins which are used to quantify the spectators. 10 should be sufficient.
275 """
276 specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
277 basf2_mva.teacher(general_options, specific_options)
278
279 general_options.m_identifier = "keras_baseline"
280 specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
281 basf2_mva.teacher(general_options, specific_options)
282
283 general_options.m_variables = basf2_mva.vector(*variables2)
284 general_options.m_identifier = "keras_feature_drop"
285 basf2_mva.teacher(general_options, specific_options)
# Attribute documentation for batch_generator (recovered from garbled extraction):
# pointer: Pointer to the current start of the batch.
# Z: Spectators/Quantity to be uncorrelated to.
# index_array: Index array containing indices from 0 to len.