18from keras.layers
import Dense, Input
19from keras.models
import Model
20from keras.optimizers
import Adam
21from keras.losses
import binary_crossentropy, sparse_categorical_crossentropy
22from keras.activations
import sigmoid, tanh, softmax
23from keras.callbacks
import EarlyStopping, Callback
26from basf2_mva_extensions.preprocessing
import fast_equal_frequency_binning
28from sklearn.metrics
import roc_auc_score
# Keras emits noisy UserWarnings during model construction/training; silence them
# so the example's log output stays readable.
warnings.filterwarnings(
'ignore', category=UserWarning)
class batch_generator():
    """
    Class to create batches for training the Adversary Network.
    Once the steps_per_epoch argument is available for the fit function
    in keras, this class will become obsolete.
    """

    def __init__(self, X, Y, Z):
        """
        Init the class.
        :param X: Input Features
        :param Y: Label Data
        :param Z: Spectators/Quantity to be uncorrelated to
        """
        #: Input features
        self.X = X
        #: Label data
        self.Y = Y
        #: Spectators/Quantity to be uncorrelated to
        self.Z = Z
        #: Number of available events
        self.len = len(Y)
        #: Index array containing indices from 0 to len
        self.index_array = np.arange(self.len)
        # Shuffle once up front so the first epoch is already randomized.
        np.random.shuffle(self.index_array)
        #: Pointer to the current start of the batch
        self.pointer = 0

    def next_batch(self, batch_size):
        """
        Getting next batch of training data.
        Reshuffles and wraps around when fewer than batch_size events remain.
        :param batch_size: number of events to return
        :return: tuple (X, Y, Z) restricted to the batch indices
        """
        if self.pointer + batch_size >= self.len:
            # Not enough events left for a full batch: reshuffle and restart.
            np.random.shuffle(self.index_array)
            self.pointer = 0

        batch_index = self.index_array[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size

        return self.X[batch_index], self.Y[batch_index], self.Z[batch_index]
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for apply data.
    2. Frozen MLP with unfrozen Adversarial Network to train adversarial part of network.
    3. Unfrozen MLP with frozen adversarial to train MLP part of the network,
       combined with losses of the adversarial networks.

    :param number_of_features: number of input features
    :param number_of_spectators: number of spectator variables to decorrelate against
    :param number_of_events: unused here, part of the mva interface
    :param training_fraction: unused here, part of the mva interface
    :param parameters: dict with 'learning_rate', 'lambda', 'number_bins', 'adversary_steps'
    :return: State object holding the compiled models
    """

    def adversary_loss(signal):
        """
        Loss for adversaries outputs.
        :param signal: If signal or background distribution should be learned.
        :return: Loss function for the discriminator part of the Network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # y[:, 0] is isSignal: the factor selects only the events of the
            # wanted class (signal or background) for this adversary.
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Input placeholder for the feature vector.
    input = Input(shape=(number_of_features,))

    # Main classifier network (MLP with 3 hidden layers).
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model used for applying data; its loss is only used when no adversary is trained.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    # Adversarial training only makes sense with lambda > 0 and at least one spectator.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # Build one adversary (signal and background) per spectator on top of the classifier output.
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                # Adversary layers start frozen; they are trained through model2 below.
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                                   trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                                   trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax,
                                         trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model which trains the classifier; adversary losses enter with negative
        # weight (-lambda) so the classifier is pushed away from the spectators.
        model1 = Model(input, [output] + adversaries)
        model1.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=[binary_crossentropy] + adversary_losses_model,
                       metrics=['accuracy'] * (len(adversaries) + 1),
                       loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))

        # Model which trains the adversaries: flip every layer's trainable flag so the
        # MLP is frozen and only the adversary layers learn.
        model2 = Model(input, adversaries)
        for layer in model2.layers:
            layer.trainable = not layer.trainable

        model2.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                       loss=adversary_losses_model,
                       metrics=['accuracy'] * len(adversaries))

        state.forward_model, state.adv_model = model1, model2
        # Number of adversary batches trained per classifier batch.
        state.K = parameters['adversary_steps']

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save Validation Data for monitoring Training.

    :param state: State object returned by get_model
    :param Xtest: validation features
    :param Stest: validation spectators
    :param ytest: validation labels
    :param wtest: validation weights (unused here)
    :param nBatches: number of batches (unused here)
    :return: the state object, now carrying the validation data
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Train the classifier on the received data.
    For every training step of MLP, the Adversarial Network will be trained K times.

    :param state: State object with the compiled models and validation data
    :param X: training features
    :param S: training spectators
    :param y: training labels
    :param w: training weights (unused here)
    :param epoch: current epoch (unused here; keras drives the epochs)
    :param batch: current batch (unused here)
    :return: False, signalling that all data has been consumed in this single call
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information
        for the adversary. Duplicated (* 2) for the signal and background adversaries.
        """
        return [np.concatenate((p_y, i), axis=1)
                for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Quantize the spectators with equal-frequency binning into number_bins bins.
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
        # Build targets for the adversaries (train and validation).
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
        # Generator feeding the adversary its own training batches.
        state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # Collected AUC values, one per epoch.
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)

    class Adversary(Callback):
        """
        Callback to train Adversary
        """

        def on_batch_end(self, batch, logs=None):
            # Train the adversary on K batches of 400 events each after every
            # classifier batch.
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # Plain training without decorrelation.
        state.model.fit(X, y, batch_size=400, epochs=1000,
                        validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                   AUC_Callback()])
    else:
        # Adversarial training: classifier and adversaries alternate via the callback.
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE: this example uses testing payloads from a local database; do not use
    # testing payloads for real analyses.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Full feature set, including kinematics of all daughters.
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    # Reduced feature set used for the feature-drop comparison training.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    training_data = basf2_mva.vector(train_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    # Config for Adversary Networks:
    # lambda: Trade off between classifier performance and noncorrelation between
    #         classifier output and spectators. Increase to reduce correlations
    #         (and also classifier performance).
    # adversary_steps: How many batches the discriminator is trained after one batch
    #                  of training the classifier. Less steps make the training faster
    #                  but also unstable. Increase the parameter if something isn't working.
    # number_bins: Number of Bins which are used to quantify the spectators.
    #              10 should be sufficient.
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Baseline: same network with lambda = 0, i.e. no adversarial decorrelation.
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Comparison: drop the most spectator-correlated features instead of using an adversary.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
# --- Extraction residue: batch_generator member reference (tidied) ---
# def __init__(self, X, Y, Z): constructor taking features, labels and spectators.
# def next_batch(self, batch_size): returns the next batch of training data.
# pointer: Pointer to the current start of the batch.
# index_array: Index array containing indices from 0 to len.
# Z: Spectators/Quantity to be uncorrelated to.