import warnings

import numpy as np
from sklearn.metrics import roc_auc_score
from tensorflow.keras.activations import sigmoid, softmax, tanh
from tensorflow.keras.callbacks import Callback, EarlyStopping
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning
# Keras/TF emits noisy UserWarnings (e.g. about deprecated arguments);
# silence them so the training log stays readable.
warnings.filterwarnings('ignore', category=UserWarning)
37 Class to create batches for training the Adversary Network.
38 Once the steps_per_epoch argument
is available
for the fit function
in keras, this
class will become obsolete.
44 :param X: Input Features
46 :param Z: Spectators/Quantity to be uncorrelated to
63 Getting next batch of training data
72 return self.
X[batch_index], self.
Y[batch_index], self.
Z[batch_index]
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build three keras models:

    1. ``apply_model``: plain MLP without adversary, used for applying the training.
    2. ``model2`` (``state.adv_model``): frozen MLP with unfrozen adversarial
       networks, used to train the adversarial part of the network.
    3. ``model1`` (``state.forward_model``): unfrozen MLP with frozen
       adversaries, trained with the classifier loss combined with the
       (negatively weighted) adversary losses.

    :param number_of_features: number of input features
    :param number_of_spectators: number of spectator variables to decorrelate from
    :param number_of_events: unused here
    :param training_fraction: unused here
    :param parameters: dict with keys 'learning_rate', 'lambda',
        'number_bins' and 'adversary_steps'
    :return: State object holding the compiled models
    """

    def adversary_loss(signal):
        """
        Loss for adversary outputs.

        :param signal: if the signal (True) or background (False) spectator
            distribution should be learned.
        :return: loss function for the discriminator part of the network.
        """
        back_constant = 0 if signal else 1

        def adv_loss(y, p):
            # Column 0 of y is isSignal: events of the "wrong" class get
            # weight 0 (the signal adversary masks background events and
            # vice versa). The remaining columns hold the binned spectator.
            return (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    # Input layer for the features.
    input = Input(shape=(number_of_features,))

    # Plain MLP: three tanh hidden layers, one sigmoid output.
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model used for applying the training (no adversaries attached).
    # NOTE(review): `lr=` is the legacy keras spelling; newer TF versions
    # expect `learning_rate=` — kept as-is to preserve behavior on the
    # TF version this example targets.
    apply_model = Model(input, output)
    apply_model.compile(optimizer=Adam(lr=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    # Only use the adversary when it can have an effect.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    # One adversary network per spectator, once for signal and once for
    # background events.
    adversaries, adversary_losses_model = [], []
    for mode in ['signal', 'background']:
        for i in range(number_of_spectators):
            # Adversary layers start frozen; model2 flips trainability below.
            adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                               trainable=False)(output)
            adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                               trainable=False)(adversary1)
            adversaries.append(Dense(units=parameters['number_bins'], activation=softmax,
                                     trainable=False)(adversary2))

            adversary_losses_model.append(adversary_loss(mode == 'signal'))

    # model1 trains the MLP against the frozen adversaries; the adversary
    # losses enter with negative weight -lambda (gradient reversal trade-off).
    model1 = Model(input, [output] + adversaries)
    model1.compile(optimizer=Adam(lr=parameters['learning_rate']),
                   loss=[binary_crossentropy] + adversary_losses_model, metrics=['accuracy'],
                   loss_weights=[1] + [-parameters['lambda']] * len(adversary_losses_model))

    # model2 trains only the adversaries: invert every layer's trainability
    # so the MLP is frozen and the adversaries are unfrozen.
    model2 = Model(input, adversaries)
    for layer in model2.layers:
        layer.trainable = not layer.trainable

    model2.compile(optimizer=Adam(lr=parameters['learning_rate']), loss=adversary_losses_model,
                   metrics=['accuracy'])

    state.forward_model, state.adv_model = model1, model2
    state.K = parameters['adversary_steps']

    # Save a picture of the network for documentation purposes.
    plot_model(model1, to_file='model.png', show_shapes=True)

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save validation data for monitoring the training.

    The validation sample is stored on the state so that the callbacks in
    partial_fit (AUC printing, adversary validation targets) can access it.

    :param state: State object created in get_model
    :param Xtest: validation features
    :param Stest: validation spectators
    :param ytest: validation targets
    :param wtest: validation weights (unused here)
    :param nBatches: number of batches (unused here)
    :return: the updated state
    """
    # Exactly the attributes partial_fit reads back as state.Xtest /
    # state.Stest / state.ytest.
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Run the whole keras training in this single call from basf2.

    For every training step of the MLP the adversarial network is trained
    K times (see the Adversary callback below).

    :param state: State object created in get_model and filled by begin_fit
    :param X: training features
    :param S: training spectators
    :param y: training targets (isSignal)
    :param w: training weights (unused here)
    :param epoch: current epoch (unused, keras drives its own epochs)
    :param batch: current batch (unused)
    :return: False, so basf2 issues no further partial_fit calls
    """

    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target
        information for the adversary: column 0 carries the signal flag,
        the remaining column the binned spectator (one target per
        spectator, duplicated for the signal and background adversaries).
        """
        return [np.concatenate((p_y, i), axis=1)
                for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    # Quantise the spectators into equal-frequency bins; the adversaries
    # are classifiers over these bin numbers.
    preprocessor = fast_equal_frequency_binning()
    preprocessor.fit(S, number_of_bins=state.number_bins)
    S = preprocessor.apply(S) * state.number_bins
    state.Stest = preprocessor.apply(state.Stest) * state.number_bins

    # Targets for the adversary outputs (training and validation).
    target_array = build_adversary_target(y, S)
    target_val_array = build_adversary_target(state.ytest, state.Stest)

    # Batch source for the adversary updates performed after every MLP batch
    # (required by the Adversary callback below).
    state.batch_gen = batch_generator(X, y, S)

    class AUC_Callback(Callback):
        """
        Callback to print AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)

    class Adversary(Callback):
        """
        Callback to train Adversary: after each classifier batch, fit the
        adversary model on K batches of 400 events each.
        """

        def on_batch_end(self, batch, logs=None):
            v_X, v_y, v_S = state.batch_gen.next_batch(400 * state.K)
            target_adversary = build_adversary_target(v_y, v_S)
            state.adv_model.fit(v_X, target_adversary, verbose=0, batch_size=400)

    if not state.use_adv:
        # Plain training without adversary (lambda == 0 or no spectators).
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                   AUC_Callback()])
    else:
        # Adversarial training: classifier loss combined with the negatively
        # weighted adversary losses; the Adversary callback keeps updating
        # the adversary between classifier batches.
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=2, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    # All epochs were driven inside this single call.
    return False
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE(review): this block also uses `basf2_mva`, which the full example
    # imports at module level (not visible in this chunk) — confirm.

    # Use a local database for the trained payloads of this example.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Full feature set, including variables correlated with the spectators.
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    # Reduced feature set for the feature-drop comparison training.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    training_data = basf2_mva.vector(train_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.9

    # Config for Adversary Networks:
    # lambda: trade-off between classifier performance and non-correlation
    #   between classifier output and spectators. Increase to reduce
    #   correlations (and also classifier performance).
    # adversary_steps: how many batches the discriminator is trained after
    #   one batch of training the classifier. Fewer steps make the training
    #   faster but also unstable. Increase the parameter if something isn't
    #   working.
    # number_bins: number of bins used to quantise the spectators.
    #   10 should be sufficient.
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.001, "lambda": 20.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Baseline: same network, adversary switched off (lambda = 0).
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    # Comparison: drop spectator-correlated features instead of using an
    # adversary.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
# batch_generator member documentation (doxygen residue):
#   pointer      -- pointer to the current start of the batch
#   index_array  -- index array containing indices from 0 to len
#   Z            -- spectators/quantity to be uncorrelated to
#   __init__(self, X, Y, Z)      -- store data and build the shuffled index
#   next_batch(self, batch_size) -- return the next batch of training data