Belle II Software development
adversary_network.py
#!/usr/bin/env python3

# This example shows how to remove bias on one or several spectator variables.
# Relevant paper: "Learning to Pivot with Adversarial Networks", https://arxiv.org/abs/1611.01046
# Use basf2_mva_evaluate.py with train.root and test.root at the end to see the impact on the spectator variables.

import basf2_mva
from basf2_mva_python_interface.keras import State  # assumed import path for the State class used in get_model
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
import basf2_mva_util

from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras.callbacks import EarlyStopping, Callback

from sklearn.metrics import roc_auc_score

import numpy as np
import time
import warnings
warnings.filterwarnings('ignore', category=UserWarning)


def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Builds 3 keras models:
    1. Network without adversary, used for applying data.
    2. Frozen MLP with unfrozen adversarial network, to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial network, to train the MLP part of the network,
       combined with the losses of the adversarial networks.
    """

    def adversary_loss(signal):
        """
        Loss for the adversary outputs.
        :param signal: Boolean; whether the signal or the background distribution should be learned.
        :return: Loss function for the adversary part of the network.
        """
        back_constant = 0 if signal else 1
        sign = 1 if signal else -1

        def adv_loss(y, p):
            return sign * (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss
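    # How the masking in adv_loss works: for the signal adversary (sign=+1,
    # back_constant=0) the cross entropy is weighted by y[:, 0] (isSignal), so only
    # signal events contribute; for the background adversary (sign=-1, back_constant=1)
    # the weight becomes (1 - y[:, 0]), so only background events contribute.
    # y[:, 1:] holds the spectator bin index used as the sparse categorical target.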

    # Define inputs for input_feature and spectator
    input_layer = Input(shape=(number_of_features,))

    # Build the first model, which will produce the desired discriminator output.
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input_layer)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying data. Its loss function is not used for training if the adversary is used.
    apply_model = Model(input_layer, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']
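    # With lambda <= 0 or without spectator variables, use_adv is False and the script
    # falls back to training a plain classifier; the "keras_baseline" run below uses
    # exactly this path via lambda = 0 in its config.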

    # Build the second model on top of the first one; it will try to predict the spectators.
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)

                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))
                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains the first part of the net.
        forward_model = Model(input_layer, [output] + adversaries)
        forward_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                              loss=[binary_crossentropy] + adversary_losses_model,
                              metrics=['accuracy'] * (len(adversaries) + 1),
                              loss_weights=[1.0] + [float(-parameters['lambda'])] * len(adversary_losses_model))
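        # The negative loss_weights above implement the adversarial trade-off: the
        # classifier is rewarded for *increasing* the adversary losses, i.e. for making
        # the spectator bins unpredictable from its output, with the strength set by lambda.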
        forward_model.summary()

        # Model which trains the second, adversary part of the net.
        adv_model = Model(input_layer, adversaries)
        # Mark which layers we want to be trainable in the adversarial network.
        for layer in adv_model.layers:
            layer.train_in_adversarial_mode = not layer.trainable

        adv_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=adversary_losses_model,
                          metrics=['accuracy'] * len(adversaries))
        adv_model.summary()

        state.forward_model = forward_model
        state.adv_model = adv_model
        state.K = parameters['adversary_steps']
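        # Note: apply_model, forward_model and adv_model are built from the same layer
        # objects, so they share their weights; training any one of them also updates
        # the others.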
    return state


def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save the validation data for monitoring the training.
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state


def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of the MLP, the adversarial network is trained K times.
    """

    def build_adversary_target(p_y, p_s):
        """
        Concatenate isSignal and the spectator bins, because both are target information for the adversary.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2
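    # Each adversary target pairs the isSignal flag (column 0, used by adv_loss to mask
    # events) with one spectator's bin index (column 1). The trailing "* 2" duplicates
    # the list, because there is one signal and one background adversary per spectator.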

    if state.use_adv:
        # Get the bin numbers of S with equal-frequency binning: the preprocessor maps each
        # spectator value to its quantile in [0, 1], and scaling by number_bins turns this
        # into an integer bin index.
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build the targets for the adversary loss function.
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
            return

    if not state.use_adv:
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'), AUC_Callback()])
    else:
        class Adversary(Callback):
            """
            Callback to train the adversary after every batch of the classifier.
            """

            def on_batch_end(self, batch, logs=None):
                # Freeze the layers of the forward network and unfreeze the adversarial layers.
                for layer in state.adv_model.layers:
                    layer.trainable = layer.train_in_adversarial_mode

                state.adv_model.fit(X, target_array, verbose=0, batch_size=400, steps_per_epoch=state.K, epochs=1)

                # Unfreeze the layers of the forward network and freeze the adversarial layers.
                for layer in state.adv_model.layers:
                    layer.trainable = not layer.train_in_adversarial_mode
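
        # Alternating optimization: the Adversary callback above retrains the adversary
        # for K steps (steps_per_epoch=state.K) after every single batch of classifier
        # training triggered by the fit call below.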
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False


if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
                 'M']

    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.8

    print(train_file)

    """
    Config for the adversary networks:
    lambda: Trade-off between classifier performance and the decorrelation of the classifier
            output from the spectators. Increase it to reduce the correlations (and also the
            classifier performance).
    adversary_steps: How many batches the adversary network is trained on after each batch of
            classifier training. Fewer steps make the training faster but also more unstable;
            increase the parameter if something isn't working.
    number_bins: Number of bins used to quantize the spectators. 10 should be sufficient.
    """

    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.01, "lambda": 5.0, "number_bins": 10}'
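    # The JSON string in m_config is forwarded to the Python framework and arrives in
    # get_model() as the `parameters` dictionary.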
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Adversary:", inference_time, auc)

    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Baseline:", inference_time, auc)

    # Now let's drop some of the features most correlated to the spectator variables.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Drop features:", inference_time, auc)

    # Uncomment the following lines to run basf2_mva_evaluate.py and compare all three methods.
    # import os
    # os.system(
    #     f'basf2_mva_evaluate.py -id keras_adversary keras_baseline keras_feature_drop '
    #     f'-train {train_file} -data {test_file} -c -out adversarial_output.pdf -l localdb/database.txt')
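
    # As a quick cross-check of the printed AUC values, one could compare with plain
    # sklearn (a sketch; basf2_mva_util's helper integrates the efficiency versus
    # background-retention curve, so small numerical differences are possible):
    # print("sklearn ROC AUC (last trained method):", roc_auc_score(t, p))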