# Belle II Software  release-08-01-10
# B2A714-DeepContinuumSuppression_MVAModel.py
1 #!/usr/bin/env python3
2 
3 
10 
11 
29 
30 from tensorflow.keras.layers import Input, Dense, Concatenate, Dropout, Lambda, GlobalAveragePooling1D, Reshape
31 from tensorflow.keras.models import Model
32 from tensorflow.keras.optimizers import Adam
33 from tensorflow.keras.losses import binary_crossentropy, sparse_categorical_crossentropy
34 from tensorflow.keras.activations import sigmoid, tanh, softmax
35 from tensorflow.keras.callbacks import EarlyStopping, Callback
36 import numpy as np
37 
38 from basf2_mva_python_interface.keras import State
39 from basf2_mva_extensions.keras_relational import EnhancedRelations
40 from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
41 import basf2_mva_util
42 
43 import warnings
44 warnings.filterwarnings('ignore', category=UserWarning)
45 
46 
def slice(input, begin, end):
    """
    Return the feature columns ``[begin, end)`` of a batched tensor.

    NOTE: the name deliberately matches the references inside the Lambda
    layers of ``get_model`` even though it shadows the builtin ``slice``;
    ``begin``/``end`` are passed by keyword there and must keep their names.
    """
    return input[:, begin:end]
52 
53 
55  """
56  Class to create batches for training the Adversary Network.
57  See mva/examples/keras/adversary_network.py for details.
58  """
59 
    def __init__(self, X, Y, Z):
        """
        Init the class
        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """

        # input features, indexed per sample
        self.XX = X

        # labels, indexed per sample
        self.YY = Y

        # spectator quantities, indexed per sample
        self.ZZ = Z

        # total number of samples; used to detect when a pass is exhausted
        self.lenlen = len(Y)

        # shuffled sample indices; reshuffled whenever a pass is exhausted
        self.index_arrayindex_array = np.arange(self.lenlen)
        np.random.shuffle(self.index_arrayindex_array)

        # current read position inside the shuffled index array
        self.pointerpointer = 0
80 
81  def next_batch(self, batch_size):
82  """
83  Getting next batch of training data
84  """
85  if self.pointerpointer + batch_size >= self.lenlen:
86  np.random.shuffle(self.index_arrayindex_array)
87  self.pointerpointer = 0
88 
89  batch_index = self.index_arrayindex_array[self.pointerpointer:self.pointerpointer + batch_size]
90  self.pointerpointer += batch_size
91 
92  return self.XX[batch_index], self.YY[batch_index], self.ZZ[batch_index]
93 
94 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the keras model for training.

    :param number_of_features: number of input features
    :param number_of_spectators: number of spectator variables (adversary targets)
    :param number_of_events: total number of training events (unused here)
    :param training_fraction: fraction of events used for training (unused here)
    :param parameters: optional dict overriding the defaults in ``param`` below
    :return: State object wrapping the compiled model(s)
    """
    def adversary_loss(signal):
        """
        Loss for adversaries outputs
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # y[:, 0] is isSignal, y[:, 1:] the spectator bin target;
            # the prefactor restricts this loss to the wanted class only
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # defaults; entries of ``parameters`` override them below
    param = {'use_relation_layers': False, 'lambda': 0, 'number_bins': 10, 'adversary_steps': 5}

    if isinstance(parameters, dict):
        param.update(parameters)

    # Restrain training to only one GPU if your machine has multiple GPUs
    # insert tensorflow as tf
    # gpus = tf.config.list_physical_devices('GPU')
    # if gpus:
    #     for gpu in gpus:
    #         tf.config.experimental.set_memory_growth(gpu, True)

    # Build classifier (renamed from ``input`` to avoid shadowing the builtin)
    model_input = Input((number_of_features,))

    # The slicing in relation layers is only accurate if your are using the variables from
    # choose_input_features(True, True, 1).
    # For an example of Relation Layers see: mva/examples/keras/relational_network.py
    if param['use_relation_layers']:
        low_level_input = Lambda(slice, arguments={'begin': 0, 'end': 560})(model_input)
        high_level_input = Lambda(slice, arguments={'begin': 560, 'end': 590})(model_input)
        relations_tracks = Lambda(slice, arguments={'begin': 0, 'end': 340})(low_level_input)
        relations_tracks = Reshape((20, 17))(relations_tracks)
        relations_clusters = Lambda(slice, arguments={'begin': 340, 'end': 560})(low_level_input)
        relations_clusters = Reshape((20, 11))(relations_clusters)

        relations1 = EnhancedRelations(number_features=20, hidden_feature_shape=[
                                       80, 80, 80])([relations_tracks, high_level_input])
        relations2 = EnhancedRelations(number_features=20, hidden_feature_shape=[
                                       80, 80, 80])([relations_clusters, high_level_input])

        relations_output1 = GlobalAveragePooling1D()(relations1)
        relations_output2 = GlobalAveragePooling1D()(relations2)

        net = Concatenate()([relations_output1, relations_output2])

        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)
        net = Dense(units=100, activation=tanh)(net)
        net = Dropout(0.5)(net)

    else:
        net = Dense(units=50, activation=tanh)(model_input)
        net = Dense(units=50, activation=tanh)(net)
        net = Dense(units=50, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # Model for applying Data. Loss function will not be used for training, if adversary is used.
    apply_model = Model(model_input, output)
    apply_model.compile(optimizer=Adam(), loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=param['lambda'] > 0 and number_of_spectators > 0, preprocessor_state=None)

    # The following is only relevant when using Adversaries
    # See mva/examples/keras/adversary_network.py for details
    if state.use_adv:
        adversaries, adversary_losses_model = [], []
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * param['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=param['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains first part of the net
        model1 = Model(model_input, [output] + adversaries)
        # BUGFIX: read 'lambda' from ``param`` (defaults applied), not the raw
        # ``parameters`` dict, for consistency with the use_adv decision above
        model1.compile(optimizer=Adam(),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-param['lambda']] * len(adversary_losses_model))
        model1.summary()

        # Model which train second, adversary part of the net
        model2 = Model(model_input, adversaries)
        # freeze everything except adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=Adam(), loss=adversary_losses_model,
                       metrics=['accuracy'])
        model2.summary()

        state.forward_model, state.adv_model = model1, model2
        # BUGFIX: ``parameters['adversary_steps']`` raised KeyError when the
        # caller relied on the default; ``param`` always carries the key
        state.K = param['adversary_steps']
        state.number_bins = param['number_bins']

    return state
199 
200 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Stash the validation sample on the state so training can monitor it.

    :param state: State object returned by get_model
    :param Xtest: validation features
    :param Stest: validation spectators
    :param ytest: validation labels
    :param wtest: validation weights (unused)
    :param nBatches: number of batches (unused)
    :return: the updated state
    """
    state.Xtest, state.Stest, state.ytest = Xtest, Stest, ytest

    return state
210 
211 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of MLP. Adverserial Network (if used) will be trained K times.

    :param state: State object prepared by get_model/begin_fit
    :param X: training features
    :param S: training spectators
    :param y: training labels
    :param w: training weights (unused here)
    :param epoch: current epoch index (unused here)
    :param batch: current batch index (unused here)
    :return: False — presumably tells the caller no further calls are needed; confirm against the basf2_mva keras interface
    """
    # Apply equal frequency binning for input data
    # See mva/examples/keras/preprocessing.py for details
    preprocessor = fast_equal_frequency_binning()
    preprocessor.fit(X, number_of_bins=500)
    X = preprocessor.apply(X)
    state.Xtest = preprocessor.apply(state.Xtest)
    # save preprocessor state in the State class
    state.preprocessor_state = preprocessor.export_state()

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # one target per spectator column, duplicated by "* 2" because
        # adversaries exist for both the signal and the background set
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get bin numbers of S with equal frequency binning
        S_preprocessor = fast_equal_frequency_binning()
        S_preprocessor.fit(S, number_of_bins=state.number_bins)
        # preprocessor output appears to be normalized to [0, 1]; scaling by
        # number_bins recovers bin numbers — TODO confirm in preprocessing.py
        S = S_preprocessor.apply(S) * state.number_bins
        state.Stest = S_preprocessor.apply(state.Stest) * state.number_bins
        # Build target for adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Build Batch Generator for adversary Callback
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # history of validation AUCs, one entry per epoch
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(val_y_pred, state.ytest)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
            return

    class Adversary(Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # after every classifier batch, train the adversary on
            # 500 * state.K samples in sub-batches of 500
            v_X, v_y, v_S = state.batch_gen.next_batch(500 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=500)

    if not state.use_adv:
        # plain classifier training; EarlyStopping bounds the huge epoch count
        state.model.fit(X, y, batch_size=500, epochs=100000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'), AUC_Callback()])
    else:
        # adversarial training: classifier targets plus one target per adversary
        state.forward_model.fit(X, [y] + target_array, batch_size=500, epochs=100000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=10, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
277 
278 
def apply(state, X):
    """
    Apply estimator to passed data.
    Has to be overwritten, because also the expert has to apply preprocessing.
    """
    # Re-create the binning preprocessor from the stored state (loaded
    # automatically by the load function) and transform exactly as in training.
    binner = fast_equal_frequency_binning(state.preprocessor_state)
    prediction = state.model.predict(binner.apply(X)).flatten()

    # hand back an aligned, writeable, C-contiguous, owned float32 array
    return np.require(prediction, dtype=np.float32, requirements=['A', 'W', 'C', 'O'])
# NOTE: doxygen cross-reference residue from documentation extraction, not part of this example:
#   def calculate_auc_efficiency_vs_background_retention(p, t, w=None)
#   std::vector< Atom > slice(std::vector< Atom > vec, int s, int e)
#   Slice the vector to contain only elements with indexes s .. e (included)
#   Definition: Splitter.h:85