Belle II Software  release-05-01-25
adversary_network.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Dennis Weyland 2017
# Justin Tan 2017
# Thomas Keck 2017

# This example shows how to remove bias on one or several spectator variables.
# Relevant paper: https://arxiv.org/abs/1611.01046
# Use basf2_mva_evaluation.py with train.root and test.root at the end to see the impact on the spectator variables.
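#
# Rough sketch of the idea from the paper above (the notation here is illustrative and
# not taken from this file): the classifier f and the adversaries r are trained on the
# combined objective
#
#     L = L_clf(f) - lambda * sum_i L_adv(r_i)
#
# i.e. the classifier is rewarded for being accurate while being penalised whenever an
# adversary can infer a spectator variable from its output. Classifier and adversaries
# are updated in alternation, which is what the three Keras models built below implement.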

import basf2
import basf2_mva
import h5py
import tensorflow as tf
import tensorflow.contrib.keras as keras
import keras

from keras.layers import Input, Dense, Concatenate, Lambda
from keras.models import Model, load_model
from keras.optimizers import adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras import backend as K
from keras.callbacks import Callback, EarlyStopping
from keras.utils import plot_model

import numpy as np
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
# State is the container object that get_model returns and that the basf2 mva python
# interface passes on to begin_fit and partial_fit; the import is needed for the call below.
from basf2_mva_python_interface.contrib_keras import State

from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings('ignore', category=UserWarning)


class batch_generator:
    """
    Class to create batches for training the Adversary Network.
    Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
    """

    def __init__(self, X, Y, Z):
        """
        Init the class
        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """
        #: Input Features
        self.X = X
        #: Label data
        self.Y = Y
        #: Spectators/Quantity to be uncorrelated to
        self.Z = Z
        #: Length of the data
        self.len = len(Y)
        #: Index array containing indices from 0 to len
        self.index_array = np.arange(self.len)
        #: Pointer to the current start of the batch
        self.pointer = 0

    def next_batch(self, batch_size):
        """
        Getting the next batch of training data
        """
        if self.pointer + batch_size >= self.len:
            np.random.shuffle(self.index_array)
            self.pointer = 0

        batch_index = self.index_array[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size

        return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
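
# Minimal usage sketch (purely illustrative, the names below are not used elsewhere in
# this file): the generator cycles through the data and reshuffles once it is exhausted.
#
#   gen = batch_generator(X, y, S)
#   batch_X, batch_y, batch_S = gen.next_batch(400)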


def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for applying the data.
    2. Frozen MLP with unfrozen adversarial network, used to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial network, used to train the MLP part of the network,
       combined with the losses of the adversarial networks.
    """

    def adversary_loss(signal):
        """
        Loss for the adversary outputs
        :param signal: Whether the signal or the background distribution should be learned.
        :return: Loss function for the discriminator part of the network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss
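    # Sketch of what adv_loss computes (illustrative only): the target y handed to this
    # loss is built in partial_fit as [isSignal, spectator_bin]. With back_constant = 0
    # the prefactor (y[:, 0] - 0) is 1 for signal events and 0 for background events, so
    # the cross-entropy on the spectator bin only contributes for signal events; with
    # back_constant = 1 the prefactor is non-zero (value -1) only for background events.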

    # Define inputs for input_feature and spectator
    input = Input(shape=(number_of_features,))

    # build first model which will produce the desired discriminator
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying the data. The loss function will not be used for training if the adversary is used.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=adam(lr=parameters['learning_rate']), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # build second model on top of the first one which will try to predict the spectators
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains the first part of the net
        model1 = Model(input, [output] + adversaries)
        model1.compile(optimizer=adam(lr=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which trains the second, adversarial part of the net
        model2 = Model(input, adversaries)
        # freeze everything except the adversary layers (toggle the trainable flags)
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=adam(lr=parameters['learning_rate']), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        state.K = parameters['adversary_steps']
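
        # Note on the freezing trick: Keras captures the trainable flags when a model is
        # compiled, so model1 (compiled above with the adversary layers frozen) and model2
        # (compiled after the flags were toggled, i.e. with the MLP layers frozen) train
        # complementary parts of the same shared layer instances. apply_model is built from
        # those same layers and therefore always uses the weights trained in partial_fit.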

        # draw model as a picture
        plot_model(model1, to_file='model.png', show_shapes=True)

    return state


def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Save validation data for monitoring the training.
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state


def partial_fit(state, X, S, y, w, epoch):
    """
    Fit the model.
    For every training step of the MLP, the adversarial network is trained K times.
    """

    def build_adversary_target(p_y, p_s):
        """
        Concatenate isSignal and the spectator bins, because both are target information for the adversary.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2
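    # Illustrative shapes for build_adversary_target (the numbers are made up): with a
    # label array p_y of shape (N, 1) and two spectators, i.e. p_s of shape (N, 2), the
    # list comprehension yields one (N, 2) array [isSignal, spectator_bin] per spectator;
    # the trailing "* 2" repeats the list so that there is one target per adversary, for
    # the signal and the background adversaries alike.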

    if state.use_adv:
        # Get the bin numbers of S with equal frequency binning
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build the target for the adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build the batch generator for the adversary callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print the AUC after every epoch.
        """

        def on_train_begin(self, logs={}):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs={}):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print('\nTest AUC: {}\n'.format(val_auc))
            self.val_aucs.append(val_auc)
            return

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train the adversary after every batch of the classifier.
        """

        def on_batch_end(self, batch, logs={}):
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False


if __name__ == "__main__":
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2HDist))', 'daughter(2, daughter(1, minC2HDist))',
                 'M']

    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2HDist))', 'daughter(2, daughter(1, minC2HDist))']

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector("train.root")
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvariantMass(0, 1)', 'daughterInvariantMass(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "contrib_keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    """
    Config for the adversary networks:
    lambda: Trade-off between classifier performance and the decorrelation of the classifier output from the spectators.
            Increase it to reduce the correlations (and also the classifier performance).
    adversary_steps: How many batches the adversary is trained on after one batch of training the classifier.
                     Fewer steps make the training faster but also unstable. Increase the parameter if something isn't working.
    number_bins: Number of bins used to quantise the spectators. 10 should be sufficient.
    """
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
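
    # The three trainings ("keras", "keras_baseline" and "keras_feature_drop") can be
    # compared afterwards with basf2_mva_evaluation.py on train.root and test.root, as
    # mentioned at the top of this example, to see the impact on the spectator variables.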