import numpy as np
import warnings
import time

import basf2_mva
import basf2_mva_util
from basf2_mva_python_interface.keras import State
from basf2_mva_extensions.preprocessing import fast_equal_frequency_binning

from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam
from keras.losses import binary_crossentropy, sparse_categorical_crossentropy
from keras.activations import sigmoid, tanh, softmax
from keras.callbacks import EarlyStopping, Callback

from sklearn.metrics import roc_auc_score

warnings.filterwarnings('ignore', category=UserWarning)
 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Building 3 keras models:
    1. Network without adversary, used for applying data.
    2. Frozen MLP with unfrozen adversarial network, to train the adversarial part of the network.
    3. Unfrozen MLP with frozen adversarial network, to train the MLP part of the network,
       combined with the losses of the adversarial networks.
    """

    def adversary_loss(signal):
        """
        Loss for the adversary outputs.
        :param signal: Whether the signal or the background distribution should be learned.
        :return: Loss function for the adversary part of the network.
        """
        back_constant = 0 if signal else 1
        sign = 1 if signal else -1

        def adv_loss(y, p):
            return sign * (y[:, 0] - back_constant) * sparse_categorical_crossentropy(y[:, 1:], p)
        return adv_loss
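    # How the masking works (explanatory note, not in the original example):
    # each adversary target row is [isSignal, spectator_bin]. For the signal
    # adversary (back_constant=0, sign=1) the factor (y[:, 0] - 0) zeroes out
    # background events; for the background adversary (back_constant=1,
    # sign=-1) the factor -(y[:, 0] - 1) zeroes out signal events. Each
    # adversary therefore only learns the spectator distribution of its class.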
 
    input_layer = Input(shape=(number_of_features,))

    # Build the classifier (MLP) that produces the discriminator output.
    layer1 = Dense(units=number_of_features + 1, activation=tanh)(input_layer)
    layer2 = Dense(units=number_of_features + 1, activation=tanh)(layer1)
    layer3 = Dense(units=number_of_features + 1, activation=tanh)(layer2)
    output = Dense(units=1, activation=sigmoid)(layer3)

    # Model used for applying data. Its loss is only used for training if no adversary is used.
    apply_model = Model(input_layer, output)
    apply_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                        loss=binary_crossentropy, metrics=['accuracy'])

    state = State(apply_model, use_adv=parameters['lambda'] > 0 and number_of_spectators > 0)
    state.number_bins = parameters['number_bins']
 
    # Build one adversary network per spectator and per class (signal/background)
    # on top of the classifier output. The adversary layers are created with
    # trainable=False, so the forward model only updates the classifier.
    adversaries, adversary_losses_model = [], []
    if state.use_adv:
        for mode in ['signal', 'background']:
            for i in range(number_of_spectators):
                adversary1 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(output)
                adversary2 = Dense(units=2 * parameters['number_bins'], activation=tanh, trainable=False)(adversary1)
                adversaries.append(Dense(units=parameters['number_bins'], activation=softmax, trainable=False)(adversary2))

                adversary_losses_model.append(adversary_loss(mode == 'signal'))

        # Model that trains the classifier part of the network.
        forward_model = Model(input_layer, [output] + adversaries)
        forward_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                              loss=[binary_crossentropy] + adversary_losses_model,
                              metrics=['accuracy'] * (len(adversaries) + 1),
                              loss_weights=[1.0] + [float(-parameters['lambda'])] * len(adversary_losses_model))
        forward_model.summary()
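        # Note on the loss weights (explanatory note, not in the original
        # example): the adversary losses enter the forward model with weight
        # -lambda, so minimising the total loss means maximising the adversary
        # losses. This trades classification power against decorrelation from
        # the spectators, in the spirit of adversarial decorrelation
        # (cf. Louppe et al., "Learning to Pivot with Adversarial Networks").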
 
        # Model that trains the adversary part of the network.
        adv_model = Model(input_layer, adversaries)
        # Remember which layers to unfreeze when training in adversarial mode:
        # exactly those that are frozen in the forward model.
        for layer in adv_model.layers:
            layer.train_in_adversarial_mode = not layer.trainable

        adv_model.compile(optimizer=Adam(learning_rate=parameters['learning_rate']),
                          loss=adversary_losses_model, metrics=['accuracy'] * len(adversaries))

        state.forward_model = forward_model
        state.adv_model = adv_model
        state.K = parameters['adversary_steps']

    return state
 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Save validation data for monitoring the training.
    """
    state.Xtest = Xtest
    state.Stest = Stest
    state.ytest = ytest

    return state
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model.
    For every training step of the MLP, the adversarial network is trained K times.
    """

    def build_adversary_target(p_y, p_s):
        """
        Concatenate isSignal and the spectator bins, because both are target information for the adversary.
        """
        return [np.concatenate((p_y, i), axis=1) for i in np.split(p_s, len(p_s[0]), axis=1)] * 2
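    # Shape illustration (hypothetical numbers, not in the original example):
    # with p_y of shape (N, 1) and p_s of shape (N, 2) for two spectators,
    # np.split yields two (N, 1) columns, each concatenated with p_y into an
    # (N, 2) target. The trailing "* 2" repeats the list for the signal and
    # background adversaries, giving 2 * number_of_spectators target arrays
    # that line up one-to-one with the adversary outputs.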
 
    if state.use_adv:
        # Bin the spectators with equal-frequency binning.
        preprocessor = fast_equal_frequency_binning()
        preprocessor.fit(S, number_of_bins=state.number_bins)
        S = preprocessor.apply(S) * state.number_bins
        state.Stest = preprocessor.apply(state.Stest) * state.number_bins
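        # Why the scaling (explanatory note, not in the original example):
        # fast_equal_frequency_binning maps each spectator value to its
        # quantile in [0, 1]; multiplying by number_bins turns that quantile
        # into (approximately) the equal-frequency bin index, which is the
        # class label expected by sparse_categorical_crossentropy.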
 
        target_array = build_adversary_target(y, S)
        target_val_array = build_adversary_target(state.ytest, state.Stest)
 
    class AUC_Callback(Callback):
        """
        Callback to print the AUC after every epoch.
        """

        def on_train_begin(self, logs=None):
            self.val_aucs = []

        def on_epoch_end(self, epoch, logs=None):
            val_y_pred = state.model.predict(state.Xtest).flatten()
            val_auc = roc_auc_score(state.ytest, val_y_pred)
            print(f'\nTest AUC: {val_auc}\n')
            self.val_aucs.append(val_auc)
 
    if not state.use_adv:
        # Plain classifier training without adversary.
        state.model.fit(X, y, batch_size=400, epochs=1000, validation_data=(state.Xtest, state.ytest),
                        callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'), AUC_Callback()])
    else:
 
        class Adversary(Callback):
            """
            Callback to train the adversary after every classifier batch.
            """

            def on_batch_end(self, batch, logs=None):
                # Unfreeze the adversary layers and freeze the classifier layers.
                for layer in state.adv_model.layers:
                    layer.trainable = layer.train_in_adversarial_mode

                state.adv_model.fit(X, target_array, verbose=0, batch_size=400, steps_per_epoch=state.K, epochs=1)

                # Restore the previous state: classifier trainable, adversary frozen.
                for layer in state.adv_model.layers:
                    layer.trainable = not layer.train_in_adversarial_mode
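        # Alternating optimisation (explanatory note, not in the original
        # example): after each batch of classifier training, the Adversary
        # callback trains the adversary layers for state.K steps on the same
        # data, then re-freezes them so that the next forward_model batch again
        # updates only the classifier weights.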
 
        state.forward_model.fit(X, [y] + target_array, batch_size=400, epochs=1000,
                                callbacks=[EarlyStopping(monitor='val_loss', patience=20, mode='min'),
                                           AUC_Callback(), Adversary()],
                                validation_data=(state.Xtest, [state.ytest] + target_val_array))
    return False
 
if __name__ == "__main__":
    from basf2 import conditions, find_file
    # Use testing payloads from a local database; never do this in production.
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]
    variables = ['p', 'pt', 'pz', 'phi',
                 'daughter(0, p)', 'daughter(0, pz)', 'daughter(0, pt)', 'daughter(0, phi)',
                 'daughter(1, p)', 'daughter(1, pz)', 'daughter(1, pt)', 'daughter(1, phi)',
                 'daughter(2, p)', 'daughter(2, pz)', 'daughter(2, pt)', 'daughter(2, phi)',
                 'chiProb', 'dr', 'dz', 'dphi',
                 'daughter(0, dr)', 'daughter(1, dr)', 'daughter(0, dz)', 'daughter(1, dz)',
                 'daughter(0, dphi)', 'daughter(1, dphi)',
                 'daughter(0, chiProb)', 'daughter(1, chiProb)', 'daughter(2, chiProb)',
                 'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                 'daughterAngle(0, 1)', 'daughterAngle(0, 2)', 'daughterAngle(1, 2)',
                 'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                 'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                 'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                 'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
 
    variables2 = ['p', 'pt', 'pz', 'phi',
                  'chiProb', 'dr', 'dz', 'dphi',
                  'daughter(2, chiProb)',
                  'daughter(0, kaonID)', 'daughter(0, pionID)', 'daughter(1, kaonID)', 'daughter(1, pionID)',
                  'daughter(2, daughter(0, E))', 'daughter(2, daughter(1, E))',
                  'daughter(2, daughter(0, clusterTiming))', 'daughter(2, daughter(1, clusterTiming))',
                  'daughter(2, daughter(0, clusterE9E25))', 'daughter(2, daughter(1, clusterE9E25))',
                  'daughter(2, daughter(0, minC2TDist))', 'daughter(2, daughter(1, minC2TDist))']
 
    train_file = find_file("mva/train_D0toKpipi.root", "examples")
    test_file = find_file("mva/test_D0toKpipi.root", "examples")

    training_data = basf2_mva.vector(train_file)
    testing_data = basf2_mva.vector(test_file)

    general_options = basf2_mva.GeneralOptions()
    general_options.m_datafiles = training_data
    general_options.m_treename = "tree"
    general_options.m_variables = basf2_mva.vector(*variables)
    general_options.m_spectators = basf2_mva.vector('daughterInvM(0, 1)', 'daughterInvM(0, 2)')
    general_options.m_target_variable = "isSignal"
    general_options.m_identifier = "keras_adversary"

    specific_options = basf2_mva.PythonOptions()
    specific_options.m_framework = "keras"
    specific_options.m_steering_file = 'mva/examples/keras/adversary_network.py'
    specific_options.m_normalize = True
    specific_options.m_training_fraction = 0.8

    """
    Config for the adversary networks:
    lambda: Trade-off between classifier performance and the decorrelation between classifier output and spectators.
            Increase it to reduce correlations (and with them, classifier performance).
    adversary_steps: How many batches the adversary is trained after one batch of classifier training.
            Fewer steps make the training faster but also less stable. Increase the parameter if something isn't working.
    number_bins: Number of bins used to quantise the spectators. 10 should be sufficient.
    """
    specific_options.m_config = '{"adversary_steps": 5, "learning_rate": 0.01, "lambda": 5.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Adversary:", inference_time, auc)
 
    # Train a baseline network without adversary (lambda = 0) for comparison.
    general_options.m_identifier = "keras_baseline"
    specific_options.m_config = '{"adversary_steps": 1, "learning_rate": 0.001, "lambda": 0.0, "number_bins": 10}'
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Baseline:", inference_time, auc)
 
    # Train again on a reduced variable set for comparison.
    general_options.m_variables = basf2_mva.vector(*variables2)
    general_options.m_identifier = "keras_feature_drop"
    basf2_mva.teacher(general_options, specific_options)

    method = basf2_mva_util.Method(general_options.m_identifier)
    inference_start = time.time()
    p, t = method.apply_expert(testing_data, general_options.m_treename)
    inference_time = time.time() - inference_start
    auc = basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p, t)
    print("Drop features:", inference_time, auc)