Belle II Software  release-06-02-00
adversary_network.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # This example shows how to remove bias on one or several spectator variables.
12 # Relevant paper: https://arxiv.org/abs/1611.01046
13 # use basf2_mva_evaluation.py with train.root and test.root at the end to see the impact on the spectator variables.
14 
15 import basf2_mva
17 import tensorflow.keras as keras
18 
19 from keras.layers import Dense, Input
20 from keras.models import Model
21 from keras.optimizers import Adam
22 from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
23 from keras.activations import sigmoid, tanh, softmax
24 from keras.callbacks import EarlyStopping
25 from keras.utils import plot_model
26 
27 import numpy as np
28 from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
29 
30 from sklearn.metrics import roc_auc_score
31 
32 import warnings
33 warnings.filterwarnings('ignore', category=UserWarning)
34 
35 
37  """
38  Class to create batches for training the Adversary Network.
39  Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
40  """
41 
    def __init__(self, X, Y, Z):
        """
        Initialize the batch generator and store the full training sample.

        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """

        #: Input features
        self.XX = X

        #: Labels
        self.YY = Y

        #: Spectators/quantities to be decorrelated from the classifier output
        self.ZZ = Z

        #: Number of events in the sample
        self.lenlen = len(Y)

        #: Index array with indices from 0 to len; reshuffled on every wrap-around
        self.index_arrayindex_array = np.arange(self.lenlen)

        #: Pointer to the current start of the batch
        self.pointerpointer = 0
61 
62  def next_batch(self, batch_size):
63  """
64  Getting next batch of training data
65  """
66  if self.pointerpointer + batch_size >= self.lenlen:
67  np.random.shuffle(self.index_arrayindex_array)
68  self.pointerpointer = 0
69 
70  batch_index = self.index_arrayindex_array[self.pointerpointer:self.pointerpointer + batch_size]
71  self.pointerpointer += batch_size
72 
73  return self.XX[batch_index], self.YY[batch_index], self.ZZ[batch_index]
74 
75 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build three keras models:

    1. Network without adversary, used to apply the training to data.
    2. Frozen MLP with unfrozen adversarial network, to train the adversarial
       part of the network.
    3. Unfrozen MLP with frozen adversarial network, to train the MLP part of
       the network against the combined losses of the adversarial networks.

    :param number_of_features: number of input features
    :param number_of_spectators: number of spectator variables to decorrelate against
    :param number_of_events: part of the framework interface (unused here)
    :param training_fraction: part of the framework interface (unused here)
    :param parameters: dict with keys 'learning_rate', 'lambda',
        'number_bins' and 'adversary_steps'
    :return: framework State object carrying the compiled models
    """

    def adversary_loss(signal):
        """
        Loss for adversaries outputs
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        # Prefactor selects signal (y[:, 0] == 1) or background (y[:, 0] == 0) events.
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # y[:, 0] is isSignal, y[:, 1:] is the spectator bin number;
            # the prefactor restricts the loss to the requested class.
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Define inputs for input_feature and spectator.
    # Named feature_input to avoid shadowing the builtin input().
    feature_input = Input(shape=(number_of_features,))

    # build first model which will produce the desired discriminator
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(feature_input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying Data. Loss function will not be used for training, if adversary is used.
    # `learning_rate` replaces the deprecated `lr` keyword of Adam.
    apply_model = Model(feature_input, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    # Only use the adversary if lambda > 0 and there is something to decorrelate.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # build second model on top of the first one which will try to predict spectators
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                # Adversary layers start frozen; model2 below flips the
                # trainable flags so only the adversary is trained there.
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains first part of the net.
        # Negative loss weights reward the classifier for fooling the adversaries.
        model1 = Model(feature_input, [output] + adversaries)
        model1.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which trains second, adversary part of the net
        model2 = Model(feature_input, adversaries)
        # freeze everything except adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        state.K = parameters['adversary_steps']

        # draw model as a picture
        plot_model(model1, to_file='model.png', show_shapes=True)

    return state
148 
149 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Store the validation sample on the state for monitoring the training.

    :param state: framework State object from get_model
    :param Xtest: validation features
    :param Stest: validation spectators
    :param ytest: validation labels
    :param wtest: validation weights (not stored)
    :return: the updated state
    """
    # Keep the validation data around so partial_fit can monitor training.
    for attribute, value in (('Xtest', Xtest), ('Stest', Stest), ('ytest', ytest)):
        setattr(state, attribute, value)

    return state
159 
160 
def partial_fit(state, X, S, y, w, epoch):
    """
    Fit the model.

    For every training batch of the MLP the adversarial network is trained
    state.K times (via the Adversary callback below).

    :param state: framework State object from get_model/begin_fit
    :param X: training features
    :param S: spectator values
    :param y: training labels
    :param w: event weights (unused here)
    :param epoch: current epoch number (unused here)
    :return: False, signalling that all data was fitted in a single call
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # One target array per spectator column; duplicated for the
        # signal and background adversaries (hence the trailing `* 2`).
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get bin numbers of S with equal frequency binning
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        # apply() maps values into [0, 1]; scaling by number_bins turns
        # them into bin indices for the adversary targets.
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build target for adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build Batch Generator for adversary Callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # One validation AUC entry per epoch.
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print('\nTest AUC: {}\n'.format(val_auc))
            self.val_aucs.append(val_auc)
            return

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # Train the adversary on state.K batches of size 400 for each
            # classifier batch.
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # Plain classifier training without adversary.
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        # Adversarial training: classifier and adversary targets together,
        # adversary retrained after every batch via the Adversary callback.
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
218 
219 
if __name__ == "__main__":
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Full feature set for the default and baseline trainings.
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
                 'M']

    # Reduced feature set with the variables most correlated to the
    # spectators dropped, for comparison with the adversarial approach.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    # Common options: data, target, and the spectators to decorrelate against.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = basf2_mva.vector("train.root")
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvariantMass(0, 1)', 'daughterInvariantMass(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras"

    # Python/keras-specific options; this file itself is the steering file.
    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "contrib_keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    """
    Config for Adversary Networks:
    lambda: Trade off between classifier performance and noncorrelation between classifier output and spectators.
    Increase to reduce correlations(and also classifier performance)
    adversary_steps: How many batches the discriminator is trained after one batch of training the classifier.
    Less steps make the training faster but also unstable. Increase the parameter if something isn't working.
    number_bins: Number of Bins which are used to quantify the spectators. 10 should be sufficient.
    """
    # Training 1: adversarial training (lambda > 0).
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Training 2: baseline without adversary (lambda == 0), same features.
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Training 3: no adversary, but with correlated features dropped.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
pointer
Pointer to the current start of the batch.
Z
Spectators/Quantity to be uncorrelated to.
index_array
Index array containing indices from 0 to len.