# Belle II Software  release-08-01-10
# adversary_network.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # This example shows how to remove bias on one or several spectator variables.
12 # Relevant paper: https://arxiv.org/abs/1611.01046
13 # use basf2_mva_evaluation.py with train.root and test.root at the end to see the impact on the spectator variables.
14 
15 import basf2_mva
16 from basf2_mva_python_interface.keras import State
17 
18 from tensorflow.keras.layers import Dense, Input
19 from tensorflow.keras.models import Model
20 from tensorflow.keras.optimizers import Adam
21 from tensorflow.keras.losses import binary_crossentropy, sparse_categorical_crossentropy
22 from tensorflow.keras.activations import sigmoid, tanh, softmax
23 from tensorflow.keras.callbacks import EarlyStopping, Callback
24 from tensorflow.keras.utils import plot_model
25 
26 import numpy as np
27 from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
28 
29 from sklearn.metrics import roc_auc_score
30 
31 import warnings
32 warnings.filterwarnings('ignore', category=UserWarning)
33 
34 
36  """
37  Class to create batches for training the Adversary Network.
38  Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
39  """
40 
    def __init__(self, X, Y, Z):
        """
        Init the class
        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """

        # Input features
        self.XX = X

        # Truth labels
        self.YY = Y

        # Spectator values the classifier output should be decorrelated from
        self.ZZ = Z

        # Total number of samples in the dataset
        self.lenlen = len(Y)

        # Index array; shuffled on wrap-around so batches are drawn in random order
        self.index_arrayindex_array = np.arange(self.lenlen)

        # Pointer to the current start of the batch
        self.pointerpointer = 0
60 
61  def next_batch(self, batch_size):
62  """
63  Getting next batch of training data
64  """
65  if self.pointerpointer + batch_size >= self.lenlen:
66  np.random.shuffle(self.index_arrayindex_array)
67  self.pointerpointer = 0
68 
69  batch_index = self.index_arrayindex_array[self.pointerpointer:self.pointerpointer + batch_size]
70  self.pointerpointer += batch_size
71 
72  return self.XX[batch_index], self.YY[batch_index], self.ZZ[batch_index]
73 
74 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:

    1. Network without adversary, used for apply data.
    2. Freezed MLP with unfreezed Adverserial Network to train adverserial part of network.
    3. Unfreezed MLP with freezed adverserial to train MLP part of the network,
       combined with losses of the adverserial networks.

    :param number_of_features: number of input features per event
    :param number_of_spectators: number of spectator variables to decorrelate against
    :param number_of_events: total number of events (unused here, required by the mva interface)
    :param training_fraction: fraction of data used for training (unused here, required by the mva interface)
    :param parameters: dict with keys 'learning_rate', 'lambda', 'number_bins' and 'adversary_steps'
    :return: basf2_mva State holding the apply model and, if enabled, the adversarial models
    """

    def adversary_loss(signal):
        """
        Loss for adversaries outputs
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        # The signal adversary must only learn from signal events (and vice
        # versa), so the loss is masked with the truth label in y[:, 0].
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Define inputs for input_feature and spectator
    # (renamed from 'input' to avoid shadowing the builtin)
    input_layer = Input(shape=(number_of_features,))

    # build first model which will produce the desired discriminator
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input_layer)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying Data. Loss function will not be used for training, if adversary is used.
    # NOTE: the 'lr' keyword is deprecated in tf.keras; use 'learning_rate'.
    apply_model = Model(input_layer, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    # Only run adversarial training if lambda > 0 and there is a spectator to decorrelate against.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # build second model on top of the first one which will try to predict spectators
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        # One adversary head per (signal/background, spectator) combination.
        for mode in ['signal', 'background']:
            for _ in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains first part of the net
        # Negative loss weights reward the classifier for fooling the adversaries.
        model1 = Model(input_layer, [output] + adversaries)
        model1.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which train second, adversary part of the net
        model2 = Model(input_layer, adversaries)
        # freeze everything except adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=Adam(learning_rate=parameters['learning_rate']), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        # Number of adversary batches trained after each classifier batch.
        state.K = parameters['adversary_steps']

        # draw model as a picture
        plot_model(model1, to_file='model.png', show_shapes=True)

    return state
147 
148 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save Validation Data for monitoring Training

    Stores the validation features, spectators and labels on the state object
    so the callbacks in ``partial_fit`` can monitor performance on them.
    """
    state.Xtest, state.Stest, state.ytest = Xtest, Stest, ytest
    return state
158 
159 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of MLP. Adverserial Network will be trained K times.

    :param state: State object produced by get_model/begin_fit
    :param X: training input features
    :param S: training spectator values
    :param y: training truth labels
    :param w: event weights (unused here)
    :param epoch: current epoch number (unused here; keras handles epochs internally)
    :param batch: current batch number (unused here)
    :return: False, signalling the framework that training is done after this call
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # One target array per spectator column, duplicated because there is a
        # signal and a background adversary for every spectator.
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get bin numbers of S with equal frequency binning
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        # NOTE: this mutates state.Stest in place with the binned values.
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build target for adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build Batch Generator for adversary Callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # Collected validation AUCs, one entry per epoch.
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
            return

    class Adversary(Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # After every classifier batch, train the adversary on K batches of 400 events.
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # Plain training without decorrelation; stop early on the validation loss.
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        # Adversarial training: Adversary() retrains the adversary after each classifier batch.
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
217 
218 
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Full feature set; several of these variables are correlated with the
    # invariant-mass spectators the adversary decorrelates against.
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))',
                 'M']

    # Reduced feature set used for the "feature drop" comparison training below.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    training_data = basf2_mva.vector(train_file)

    # Common options shared by all three trainings below.
    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    """
    Config for Adversary Networks:
    lambda: Trade off between classifier performance and noncorrelation between classifier output and spectators.
    Increase to reduce correlations(and also classifier performance)
    adversary_steps: How many batches the discriminator is trained after one batch of training the classifier.
    Less steps make the training faster but also unstable. Increase the parameter if something isn't working.
    number_bins: Number of Bins which are used to quantify the spectators. 10 should be sufficient.
    """
    # 1) Adversarial training (lambda > 0 enables the adversary).
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # 2) Baseline without adversary (lambda = 0) for comparison.
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # 3) Baseline with the spectator-correlated features dropped instead.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
# Attribute documentation displaced from batch_generator by extraction:
#   pointer: Pointer to the current start of the batch.
#   Z: Spectators/Quantity to be uncorrelated to.
#   index_array: Index array containing indices from 0 to len.