from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras.callbacks import EarlyStopping, Callback

from sklearn.metrics import roc_auc_score
# Silence noisy UserWarnings emitted by the ML libraries during training.
warnings.filterwarnings('ignore', category=UserWarning)
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for apply data.
    2. Frozen MLP with unfrozen adversarial network to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial part to train the MLP part of the network,
       combined with the losses of the adversarial networks.

    :param number_of_features: Number of input features.
    :param number_of_spectators: Number of spectator variables (one adversary pair per spectator).
    :param number_of_events: Unused here (basf2 mva interface).
    :param training_fraction: Unused here (basf2 mva interface).
    :param parameters: Dict with 'learning_rate', 'lambda', 'number_bins', 'adversary_steps'.
    :return: Configured State object.
    """

    def adversary_loss(signal):
        """
        Loss for the adversaries' outputs.

        :param signal: True if the signal distribution should be learned, False for background.
        :return: Loss function for the discriminator part of the network.
        """
        # BUGFIX: this used to test `signal == 'signal'`, but the caller passes the
        # boolean `mode == 'signal'`, so the comparison was always False and both
        # adversaries behaved like background adversaries. Use truthiness instead.
        # y[:, 0] is isSignal, so (isSignal - back_constant) masks out the events
        # this adversary is not responsible for.
        back_constant = 0 if signal else 1
        # Opposite sign for the background adversary so its loss pulls the right way.
        sign = 1 if signal else -1

        def adv_loss(y, p):
            return sign * (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss

    input_layer = Input(shape=(number_of_features,))

    # Plain MLP classifier: three tanh hidden layers, sigmoid output.
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input_layer)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model used to apply the classifier on data (no adversaries attached).
    apply_model = Model(input_layer, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    # Adversarial training only makes sense with a positive lambda and at least one spectator.
    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']

    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        # Build one adversary head per (signal/background, spectator) combination.
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                                   trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh,
                                   trainable=False)(adversary1)

                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax,
                                         trainable=False)(adversary2))
                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Classifier + adversaries: the adversary losses enter with weight -lambda,
        # penalizing classifier output that is informative about the spectators.
        forward_model = Model(input_layer, [output] + adversaries)
        forward_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                              loss=[binary_crossentropy] + adversary_losses_model,
                              metrics=['accuracy'] * (len(adversaries) + 1),
                              loss_weights=[1.0] + [float(-parameters['lambda'])] * len(adversary_losses_model))
        forward_model.summary()

        # Model used to train only the adversaries (MLP layers stay frozen).
        adv_model = Model(input_layer, adversaries)
        # Record which layers belong to the adversarial part: exactly those that
        # are currently not trainable. Used by the Adversary callback in partial_fit.
        for layer in adv_model.layers:
            layer.train_in_adversarial_mode = not layer.trainable

        adv_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                          loss=adversary_losses_model,
                          metrics=['accuracy'] * len(adversaries))

        state.forward_model = forward_model
        state.adv_model = adv_model
        # Number of adversary training steps per classifier batch.
        state.K = parameters['adversary_steps']

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save validation data for monitoring the training.

    :param state: State object returned by get_model.
    :param Xtest: Validation features.
    :param Stest: Validation spectators.
    :param ytest: Validation targets.
    :param wtest: Validation weights (unused here).
    :param nBatches: Number of batches (unused here).
    :return: The updated state.
    """
    # partial_fit reads these back as state.Xtest / state.Stest / state.ytest.
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the classifier.
    For every training step of the MLP, the adversarial network is trained K times.

    :return: False — all epochs are run inside this single call, so no further
             partial_fit iterations are needed.
    """
    def build_adversary_target(p_y, p_s):
        """
        Concat isSignal and spectator bins, because both are target information for the adversary.
        """
        # Doubled because there is one adversary per spectator for signal AND background.
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2

    if state.use_adv:
        # Quantize the spectators with equal-frequency binning; the adversaries
        # predict the bin number from the classifier output.
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins

        # Build the targets for the adversary losses.
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)

    class AUC_Callback(Callback):
        """
        Callback to print the validation AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            # Must exist before on_epoch_end appends to it.
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)

    if not state.use_adv:
        # Plain classifier training without adversaries.
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'),
                                   AUC_Callback()])
    else:
        class Adversary(Callback):
            """
            Callback to train the adversary after every classifier batch.
            """

            def on_batch_end(self, batch, logs=None):
                # Freeze the MLP layers, unfreeze the adversary layers.
                for layer in state.adv_model.layers:
                    layer.trainable = layer.train_in_adversarial_mode

                # Train the adversary for K steps against the current classifier.
                state.adv_model.fit(X, target_array, verbose=0, batch_size=400,
                                    steps_per_epoch=state.K, epochs=1)

                # Restore: unfreeze the MLP layers, freeze the adversary layers.
                for layer in state.adv_model.layers:
                    layer.trainable = not layer.train_in_adversarial_mode

        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))

    return False
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # NOTE: testing payloads are for this example only — do not use them in production.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    # Reduced variable set for the feature-drop comparison.
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']

    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.8

    def evaluate(label):
        """Apply the most recently trained method; print its inference time and AUC."""
        # NOTE(review): reconstructed from a garbled span — assumes basf2_mva_util
        # is imported at the top of the file; confirm against the original source.
        method = basf2_mva_util.Method(general_options.m_identifier)
        inference_start = time.time()
        p, t = method.apply_expert(testing_data, general_options.m_treename)
        inference_time = time.time() - inference_start
        auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
        print(label, inference_time, auc)

    # Config for Adversary Networks:
    # lambda: Trade-off between classifier performance and non-correlation between
    #     classifier output and spectators. Increase to reduce correlations
    #     (and also classifier performance).
    # adversary_steps: How many batches the adversary is trained after one batch of
    #     training the classifier. Fewer steps make the training faster but also
    #     unstable. Increase the parameter if something isn't working.
    # number_bins: Number of bins used to quantize the spectators. 10 should be sufficient.
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.01, "lambda": 5.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)
    evaluate("Adversary:")

    # Baseline: same network, adversary disabled (lambda = 0).
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)
    evaluate("Baseline:")

    # Baseline with the spectator-correlated features dropped instead of decorrelated.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)
    evaluate("Drop features:")