Belle II Software release-05-01-25
B2A714-DeepContinuumSuppression_MVAModel.py
#!/usr/bin/env python3

import tensorflow as tf
import tensorflow.contrib.keras as keras

from keras.layers import Input, Dense, Concatenate, Dropout, Lambda, GlobalAveragePooling1D, Reshape
from keras.models import Model, load_model
from keras.optimizers import adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras.callbacks import Callback, EarlyStopping
from sklearn.metrics import roc_auc_score
import numpy as np
import os

from basf2_mva_python_interface.contrib_keras import State
from basf2_mva_extensions.keras_relational import Relations, EnhancedRelations
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

import warnings
warnings.filterwarnings('ignore', category=UserWarning)


def slice(input, begin, end):
    """
    Simple function for slicing features in tensors.
    """
    return input[:, begin:end]


class batch_generator:
    """
    Class to create batches for training the Adversary Network.
    See mva/examples/keras/adversary_network.py for details.
    """

    def __init__(self, X, Y, Z):
        """
        Init the class
        :param X: Input features
        :param Y: Label data
        :param Z: Spectators / quantity to be uncorrelated with
        """
        #: Input features
        self.X = X
        #: Label data
        self.Y = Y
        #: Spectators
        self.Z = Z
        #: Number of events
        self.len = len(Y)
        #: Index array, which will be shuffled
        self.index_array = np.arange(self.len)
        np.random.shuffle(self.index_array)
        #: Pointer for the index array
        self.pointer = 0

    def next_batch(self, batch_size):
        """
        Return the next batch of training data.
        """
        if self.pointer + batch_size >= self.len:
            np.random.shuffle(self.index_array)
            self.pointer = 0

        batch_index = self.index_array[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size

        return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]


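# Minimal usage sketch for batch_generator (added for illustration only, not part of the
# training flow; the helper name and the random stand-in arrays are made up). It mirrors
# how the Adversary callback in partial_fit() draws its mini-batches.
def _batch_generator_demo():
    gen = batch_generator(np.random.rand(1000, 590),                # X: input features
                          np.random.randint(0, 2, size=(1000, 1)),  # Y: labels
                          np.random.rand(1000, 1))                  # Z: spectators
    x_batch, y_batch, z_batch = gen.next_batch(500)
    # Shapes: (500, 590), (500, 1), (500, 1). Once all events have been served,
    # next_batch() reshuffles the index array and starts over.
    return x_batch, y_batch, z_batch

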
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the keras model for training.
    """
    def adversary_loss(signal):
        """
        Loss for the adversary outputs.
        :param signal: Whether the signal or the background distribution should be learned.
        :return: Loss function for the adversary (discriminator) part of the network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

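    # Added note: the adversary targets built in partial_fit() pack the training label and
    # the spectator bin into one array per adversary output, with column 0 holding isSignal
    # and column 1 the spectator bin index. adv_loss therefore uses y[:, 0] together with
    # back_constant to zero out events of the class a given adversary should ignore, and
    # y[:, 1:] as the sparse categorical target for the predicted bin probabilities p.
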
    param = {'use_relation_layers': False, 'lambda': 0, 'number_bins': 10, 'adversary_steps': 5}

    if isinstance(parameters, dict):
        param.update(parameters)

    # Restrict training to only one GPU if your machine has multiple GPUs.
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Uncomment if you are using a GPU and don't want to occupy all GPU resources:
    # gpu_options = tf.GPUOptions(allow_growth=True)
    # s = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

    # Build the classifier
    input = Input((number_of_features,))

    # The slicing in the relation layers is only accurate if you are using the variables from
    # choose_input_features(True, True, 1).
    # For an example of relation layers see: mva/examples/keras/relational_network.py
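    # Added note: the slices and Reshape calls below imply the following layout of the
    # flat input vector:
    #   [0, 340)   -> 20 track candidates   x 17 variables each
    #   [340, 560) -> 20 cluster candidates x 11 variables each
    #   [560, 590) -> 30 high-level variables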
    if param['use_relation_layers']:
        low_level_input = Lambda(slice, arguments={'begin': 0, 'end': 560})(input)
        high_level_input = Lambda(slice, arguments={'begin': 560, 'end': 590})(input)
        relations_tracks = Lambda(slice, arguments={'begin': 0, 'end': 340})(low_level_input)
        relations_tracks = Reshape((20, 17))(relations_tracks)
        relations_clusters = Lambda(slice, arguments={'begin': 340, 'end': 560})(low_level_input)
        relations_clusters = Reshape((20, 11))(relations_clusters)

        relations1 = EnhancedRelations(number_features=20,
                                       hidden_feature_shape=[80, 80, 80])([relations_tracks, high_level_input])
        relations2 = EnhancedRelations(number_features=20,
                                       hidden_feature_shape=[80, 80, 80])([relations_clusters, high_level_input])

        relations_output1 = GlobalAveragePooling1D()(relations1)
        relations_output2 = GlobalAveragePooling1D()(relations2)

        net = Concatenate()([relations_output1, relations_output2])

        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)
        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)

    else:
        net = Dense(units=50, activation=tanh)(input)
        net = Dense(units=50, activation=tanh)(net)
        net = Dense(units=50, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # Model for applying to data. Its loss function is not used for training if the adversary is used.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=adam(), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=param['lambda'] > 0 and number_of_spectators > 0, preprocessor_state=None,
                  custom_objects={'EnhancedRelations': EnhancedRelations})

    # The following is only relevant when using adversaries.
    # See mva/examples/keras/adversary_network.py for details.
    if state.use_adv:
        adversaries, adversary_losses_model = [], []
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=param['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains the first (classifier) part of the net
        model1 = Model(input, [output] + adversaries)
        model1.compile(optimizer=adam(),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-param['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which trains the second (adversary) part of the net
        model2 = Model(input, adversaries)
        # Freeze everything except the adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=adam(), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        state.K = param['adversary_steps']
        state.number_bins = param['number_bins']

    return state


def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Save the validation data for monitoring the training.
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state


def partial_fit(state, X, S, y, w, epoch):
    """
    Fit the model.
    For every training step of the MLP, the adversarial network (if used) is trained K times.
    """
    # Apply equal-frequency binning to the input data.
    # See mva/examples/keras/preprocessing.py for details.
    preprocessor = fast_equal_frequency_binning()
    preprocessor.fit(X, number_of_bins=500)
    X = preprocessor.apply(X)
    state.Xtest = preprocessor.apply(state.Xtest)
    # Save the preprocessor state in the State class
    state.preprocessor_state = preprocessor.export_state()
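    # Added note: apply() maps every feature onto its equal-frequency bin, scaled to [0, 1]
    # (which is why S is multiplied by number_bins further below to recover integer bin
    # numbers). The exported state lets apply() at the bottom of this file redo the exact
    # same transformation at inference time.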

    def build_adversary_target(p_y, p_s):
        """
        Concatenate isSignal and the spectator bins, because both are target information for the adversary.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

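    # Added note: with y of shape (N, 1) and S of shape (N, n_spectators), the helper above
    # returns a list of 2 * n_spectators arrays of shape (N, 2), one per adversary output;
    # the same targets serve the 'signal' and 'background' adversaries, since the class
    # selection happens inside adversary_loss.
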
    if state.use_adv:
        # Get bin numbers of S with equal-frequency binning
        S_preprocessor = fast_equal_frequency_binning()
        S_preprocessor.fit(S, number_of_bins=state.number_bins)
        S = S_preprocessor.apply(S) * state.number_bins
        state.Stest = S_preprocessor.apply(state.Stest) * state.number_bins
        # Build the targets for the adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build a batch generator for the Adversary callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print the AUC after every epoch.
        """

        def on_train_begin(self, logs={}):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs={}):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print('\nTest AUC: {}\n'.format(val_auc))
            self.val_aucs.append(val_auc)
            return

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train the adversary.
        """

        def on_batch_end(self, batch, logs={}):
            v_X, v_y, v_S = state.batch_gen.next_batch(500 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=500)

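    # Added note: every classifier batch of 500 events triggers one on_batch_end call, which
    # draws 500 * K fresh events and fits the adversary-only model (where just the adversary
    # layers are trainable) with batch_size=500, i.e. K adversary updates per classifier
    # update, as promised in the docstring of partial_fit.
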
    if not state.use_adv:
        state.model.fit(X, y, batch_size=500, epochs=100000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'), AUC_Callback()])
    else:
        state.forward_model.fit(X, [y] + target_array, batch_size=500, epochs=100000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False


def apply(state, X):
    """
    Apply the estimator to the passed data.
    Has to be overridden here, because the expert also has to apply the preprocessing.
    """
    # The preprocessor state is automatically loaded in the load function
    preprocessor = fast_equal_frequency_binning(state.preprocessor_state)
    # Apply the preprocessor
    X = preprocessor.apply(X)

    r = state.model.predict(X).flatten()
    return np.require(r, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
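

# How this steering file is used (hedged sketch, not executed here): basf2_mva loads it
# for the 'contrib_keras' Python framework and calls get_model / begin_fit / partial_fit
# / apply defined above. The option names below follow the basf2_mva Python interface,
# but the file name, tree name, variable list, target and identifier are placeholders;
# see the companion DeepContinuumSuppression training tutorial for the real options.
#
#   import basf2_mva
#   general_options = basf2_mva.GeneralOptions()
#   general_options.m_datafiles = basf2_mva.vector('train.root')   # placeholder
#   general_options.m_treename = 'tree'                            # placeholder
#   general_options.m_variables = basf2_mva.vector(*my_variables)  # hypothetical list
#   general_options.m_target_variable = 'isSignal'                 # placeholder
#   general_options.m_identifier = 'Deep_CS_weightfile'            # placeholder
#
#   specific_options = basf2_mva.PythonOptions()
#   specific_options.m_framework = 'contrib_keras'
#   specific_options.m_steering_file = 'B2A714-DeepContinuumSuppression_MVAModel.py'
#   specific_options.m_training_fraction = 0.9
#
#   basf2_mva.teacher(general_options, specific_options)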