#!/usr/bin/env python3

# Belle II Software  release-06-00-14
# B2A714-DeepContinuumSuppression_MVAModel.py
import warnings

import numpy as np
import tensorflow.contrib.keras as keras
from keras.activations import sigmoid, softmax, tanh
from keras.callbacks import EarlyStopping
from keras.layers import Concatenate, Dense, Dropout, GlobalAveragePooling1D, Input, Lambda, Reshape
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.models import Model
from keras.optimizers import adam
from sklearn.metrics import roc_auc_score

from basf2_mva_extensions.keras_relational import EnhancedRelations
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

warnings.filterwarnings('ignore', category=UserWarning)
47 
48 
def slice(input, begin, end):
    """
    Return the feature columns ``begin:end`` of a batched 2D tensor.

    NOTE: deliberately shadows the ``slice`` builtin — the Lambda layers in
    ``get_model`` reference this function by name, so it must not be renamed.
    """
    return input[:, begin:end]
54 
55 
class batch_generator:
    """
    Class to create batches for training the Adversary Network.
    See mva/examples/keras/adversary_network.py for details.
    """
61 
62  def __init__(self, X, Y, Z):
63  """
64  Init the class
65  :param X: Input Features
66  :param Y: Label Data
67  :param Z: Spectaters/Qunatity to be uncorrelated to
68  """
69 
70  self.XX = X
71 
72  self.YY = Y
73 
74  self.ZZ = Z
75 
76  self.lenlen = len(Y)
77 
78  self.index_arrayindex_array = np.arange(self.lenlen)
79  np.random.shuffle(self.index_arrayindex_array)
80 
81  self.pointerpointer = 0
82 
83  def next_batch(self, batch_size):
84  """
85  Getting next batch of training data
86  """
87  if self.pointerpointer + batch_size >= self.lenlen:
88  np.random.shuffle(self.index_arrayindex_array)
89  self.pointerpointer = 0
90 
91  batch_index = self.index_arrayindex_array[self.pointerpointer:self.pointerpointer + batch_size]
92  self.pointerpointer += batch_size
93 
94  return self.XX[batch_index], self.YY[batch_index], self.ZZ[batch_index]
95 
96 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the keras model for training.

    :param number_of_features: Number of input features.
    :param number_of_spectators: Number of spectator variables (adversary targets).
    :param number_of_events: Total number of events (unused here).
    :param training_fraction: Fraction of events used for training (unused here).
    :param parameters: Optional dict overriding the defaults collected in ``param``.
    :return: basf2 mva State object holding the compiled model(s).
    """
    def adversary_loss(signal):
        """
        Loss for adversaries outputs
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        # Masks the loss so only signal (or only background) events contribute.
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # y[:, 0] is isSignal, y[:, 1:] holds the spectator bin label.
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Defaults; any key may be overridden through `parameters`.
    param = {'use_relation_layers': False, 'lambda': 0, 'number_bins': 10, 'adversary_steps': 5}

    if isinstance(parameters, dict):
        param.update(parameters)

    # Restrain training to only one GPU if your machine has multiple GPUs
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Uncomment if you are using GPU and don't want to occupy all GPU resources.
    # gpu_options = tf.GPUOptions(allow_growth=True)
    # s = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))

    # Build classifier (renamed from `input` to avoid shadowing the builtin)
    input_layer = Input((number_of_features,))

    # The slicing in relation layers is only accurate if your are using the variables from
    # choose_input_features(True, True, 1).
    # For an example of Relation Layers see: mva/examples/keras/relational_network.py
    if param['use_relation_layers']:
        low_level_input = Lambda(slice, arguments={'begin': 0, 'end': 560})(input_layer)
        high_level_input = Lambda(slice, arguments={'begin': 560, 'end': 590})(input_layer)
        relations_tracks = Lambda(slice, arguments={'begin': 0, 'end': 340})(low_level_input)
        relations_tracks = Reshape((20, 17))(relations_tracks)
        relations_clusters = Lambda(slice, arguments={'begin': 340, 'end': 560})(low_level_input)
        relations_clusters = Reshape((20, 11))(relations_clusters)

        relations1 = EnhancedRelations(number_features=20, hidden_feature_shape=[
            80, 80, 80])([relations_tracks, high_level_input])
        relations2 = EnhancedRelations(number_features=20, hidden_feature_shape=[
            80, 80, 80])([relations_clusters, high_level_input])

        relations_output1 = GlobalAveragePooling1D()(relations1)
        relations_output2 = GlobalAveragePooling1D()(relations2)

        net = Concatenate()([relations_output1, relations_output2])

        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)
        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)

    else:
        net = Dense(units=50, activation=tanh)(input_layer)
        net = Dense(units=50, activation=tanh)(net)
        net = Dense(units=50, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # Model for applying Data. Loss function will not be used for training, if adversary is used.
    apply_model = Model(input_layer, output)
    apply_model.compile(optimizer=adam(), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=param['lambda'] > 0 and number_of_spectators > 0, preprocessor_state=None,
                  custom_objects={'EnhancedRelations': EnhancedRelations})

    # The following is only relevant when using Adversaries
    # See mva/examples/keras/adversary_network.py for details
    if state.use_adv:
        adversaries, adversary_losses_model = [], []
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                # Adversary layers start frozen; model2 below flips trainability.
                adversary1 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=param['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains first part of the net
        model1 = Model(input_layer, [output] + adversaries)
        # BUGFIX: read 'lambda' from the merged `param` dict (which carries the
        # defaults), not the raw `parameters` argument — the raw dict may lack
        # the key and would raise KeyError.
        model1.compile(optimizer=adam(),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-param['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which train second, adversary part of the net
        model2 = Model(input_layer, adversaries)
        # freeze everything except adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=adam(), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        # BUGFIX: same as above — use the merged defaults for 'adversary_steps'.
        state.K = param['adversary_steps']
        state.number_bins = param['number_bins']

    return state
201 
202 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Save Validation Data for monitoring Training
    """
    # Stash the validation arrays on the state so partial_fit can use them.
    state.Xtest, state.Stest, state.ytest = Xtest, Stest, ytest
    return state
212 
213 
def partial_fit(state, X, S, y, w, epoch):
    """
    Fit the model.
    For every training step of MLP. Adverserial Network (if used) will be trained K times.

    :param state: State object prepared by get_model/begin_fit.
    :param X: Training features.
    :param S: Spectator variables.
    :param y: Training labels.
    :param w: Event weights (unused here).
    :param epoch: basf2 epoch counter (unused; keras manages epochs itself).
    :return: False, so the framework calls partial_fit only once.
    """
    # Apply equal frequency binning for input data
    # See mva/examples/keras/preprocessing.py for details
    preprocessor = fast_equal_frequency_binning()
    preprocessor.fit(X, number_of_bins=500)
    X = preprocessor.apply(X)
    state.Xtest = preprocessor.apply(state.Xtest)
    # save preprocessor state in the State class
    state.preprocessor_state = preprocessor.export_state()

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # One target array per spectator column; duplicated (`* 2`) because there
        # is one adversary set for signal and one for background.
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get bin numbers of S with equal frequency binning
        S_preprocessor = fast_equal_frequency_binning()
        S_preprocessor.fit(S, number_of_bins=state.number_bins)
        # NOTE(review): apply() appears to return values scaled to [0, 1], so
        # multiplying by number_bins yields bin labels — confirm against
        # mva/examples/keras/preprocessing.py.
        S = S_preprocessor.apply(S) * state.number_bins
        state.Stest = S_preprocessor.apply(state.Stest) * state.number_bins
        # Build target for adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build Batch Generator for adversary Callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(keras.callbacks.Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # Validation AUC per epoch, in order.
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
            return

    class Adversary(keras.callbacks.Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # After each classifier batch, train the adversary on K mini-batches:
            # one draw of 500 * K events, fitted with batch_size=500.
            v_X, v_y, v_S = state.batch_gen.next_batch(500 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=500)

    if not state.use_adv:
        state.model.fit(X, y, batch_size=500, epochs=100000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'), AUC_Callback()])
    else:
        state.forward_model.fit(X, [y] + target_array, batch_size=500, epochs=100000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'), AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
279 
280 
def apply(state, X):
    """
    Apply estimator to passed data.
    Has to be overwritten, because also the expert has to apply preprocessing.
    """
    # The preprocessor state is automatically loaded in the load function
    binning = fast_equal_frequency_binning(state.preprocessor_state)
    # Apply preprocessor, then predict
    predictions = state.model.predict(binning.apply(X)).flatten()

    # basf2 mva expects an aligned, writeable, C-contiguous float32 array.
    return np.require(predictions, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
# NOTE(review): removed stray C++ doxygen residue that was accidentally
# appended here during documentation extraction:
#   std::vector<Atom> slice(std::vector<Atom> vec, int s, int e)  (Splitter.h:85)
# It is unrelated to this Python module and made the file unparseable.