import basf2_mva
import numpy as np

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.activations import sigmoid, tanh, softmax
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras.utils import plot_model

from basf2_mva_python_interface.keras import State
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings('ignore', category=UserWarning)
class batch_generator():
    """
    Class to create batches for training the Adversary Network.
    Once the steps_per_epoch argument is available for the fit function in keras, this class will become obsolete.
    """

    def __init__(self, X, Y, Z):
        """
        Init the class
        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """
        self.X = X
        self.Y = Y
        self.Z = Z
        #: Length of the training data
        self.len = len(Y)
        #: Index array containing indices from 0 to len, reshuffled between passes
        self.index_array = np.arange(self.len)
        np.random.shuffle(self.index_array)
        #: Pointer to the current start of the batch
        self.pointer = 0

    def next_batch(self, batch_size):
        """
        Getting the next batch of training data
        """
        if self.pointer + batch_size >= self.len:
            np.random.shuffle(self.index_array)
            self.pointer = 0

        batch_index = self.index_array[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size

        return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
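
# Illustrative usage of batch_generator (a sketch, not executed here): the
# generator hands out shuffled batches and reshuffles once a full pass over
# the data would be exceeded.
#
#   gen = batch_generator(X, y, S)        # X: features, y: labels, S: spectators
#   X_batch, y_batch, S_batch = gen.next_batch(400)
#
# In this script the Adversary callback below draws 400 * state.K events per call.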
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for applying data.
    2. Frozen MLP with unfrozen adversarial network, to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial network, to train the MLP part of the network,
       combined with the losses of the adversarial networks.
    """

    def adversary_loss(signal):
        """
        Loss for the adversary outputs.
        :param signal: Whether the signal or the background distribution should be learned.
        :return: Loss function for the adversary part of the network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss
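
    # How the masking in adv_loss works (illustrative, with y[:, 0] = isSignal
    # and y[:, 1:] = spectator bin index): for the signal adversary
    # (back_constant = 0) the factor (y[:, 0] - 0) is 1 for signal events and 0
    # for background events, so only signal events contribute to the loss; for
    # the background adversary (back_constant = 1) the factor is (0 - 1) = -1
    # for background events and 0 for signal events.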
    # Define the input
    input = Input(shape=(number_of_features,))

    # Build the classifier part of the network
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model for applying data; its loss function is not used for training if the adversary is used.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']
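
    # use_adv is False if lambda is 0 or there are no spectators; in that case
    # the adversary graph below is skipped entirely and only apply_model is
    # trained (see partial_fit).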
    # Build the adversary networks on top of the classifier output, which try to predict the spectators
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains the classifier part of the network
        model1 = Model(input, [output] + adversaries)
        model1.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))

        # Model which trains the adversary part of the network
        model2 = Model(input, adversaries)
        # Flip the trainable flags: freeze the classifier layers, unfreeze the adversary layers
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=adversary_losses_model, metrics=['accuracy'])

        state.forward_model, state.adv_model = model1, model2
        state.K = parameters['adversary_steps']

        # Draw the model as a picture
        plot_model(model1, to_file='model.png', show_shapes=True)

    return state
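
# A sketch of the combined objective that model1 minimizes, writing L_clf for
# the binary cross entropy and L_adv_i for the adversary losses:
#
#   L_total = L_clf - lambda * sum_i L_adv_i
#
# The negative loss weights reward classifier outputs from which the frozen
# adversaries cannot infer the spectators, while model2 (the same graph with
# flipped trainable flags) minimizes the plain adversary losses.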
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save the validation data for monitoring the training.
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of the MLP, the adversarial network is trained K times.
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Get the bin numbers of the spectators with equal-frequency binning
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins

        # Build the targets for the adversary loss function
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)

        # Build a batch generator for the Adversary callback
        state.batch_gen = batch_generator(X, y, S)
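
        # Shape sketch for build_adversary_target (illustrative sizes): with N
        # events, p_y of shape (N, 1) and two spectators, p_s has shape (N, 2).
        # np.split yields two (N, 1) columns, each concatenated with p_y into a
        # (N, 2) target; the list is doubled (`* 2`) because each spectator
        # feeds one signal and one background adversary, matching the order of
        # `adversaries` built in get_model.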
    class AUC_Callback(Callback):
        """
        Callback to print the AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
    class Adversary(Callback):
        """
        Callback to train the adversary after every batch of the classifier.
        """

        def on_batch_end(self, batch, logs=None):
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'), AUC_Callback()])
    else:
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
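
# Training interplay (as implemented above): the classifier is fit through
# state.forward_model with the adversary layers frozen; after every batch the
# Adversary callback draws 400 * state.K events and fits state.adv_model (the
# same graph with flipped trainable flags) in batches of 400, i.e. the
# adversary is trained K times per classifier step, as the docstring says.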
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    training_data = basf2_mva.vector(train_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9
    """
    Config for the adversary networks:
    lambda: Trade-off between classifier performance and non-correlation between classifier output and spectators.
            Increase it to reduce the correlations (and also the classifier performance).
    adversary_steps: How many batches the adversary is trained on after one batch of training the classifier.
                     Fewer steps make the training faster but also unstable. Increase the parameter if something isn't working.
    number_bins: Number of bins which are used to quantify the spectators. 10 should be sufficient.
    """
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
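
    # The three trainings above allow a comparison (an interpretive note):
    # "keras_adversary" uses the adversarial loss with lambda = 20.0,
    # "keras_baseline" disables it via lambda = 0.0, and "keras_feature_drop"
    # instead trains on the reduced list variables2, which omits the daughter
    # kinematics from which the spectator invariant masses could be inferred.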