Belle II Software  release-05-01-25
relational_network.py
1 #!/usr/bin/env python3
2 # -*- coding: utf-8 -*-
3 
4 # Dennis Weyland 2017
5 
6 # This example is a basic demonstration of how to use relational networks in basf2 with tensorflow.
7 # As a toy problem, it tries to tell whether 2 out of several lines intersect each other in three-dimensional space.
8 # Relevant paper: https://arxiv.org/abs/1706.01427
9 # If you want to try relational networks on your own problem, feel free to import the relational layer classes (e.g. Relations) from basf2_mva_extensions.keras_relational into your code.
10 
12 import h5py
13 
14 import tensorflow as tf
15 import tensorflow.contrib.keras as keras
16 
17 from keras.layers import Input, Dense, Concatenate, Flatten, Dropout, GlobalAveragePooling1D
18 from keras.layers.merge import Average
19 from keras.layers.core import Reshape
20 from keras.layers import activations
21 from keras.models import Model, load_model
22 from keras.optimizers import adam
23 from keras.losses import binary_crossentropy, mean_squared_error
24 from keras.activations import sigmoid, tanh
25 from keras.engine.topology import Layer
26 from keras import backend as K
27 from keras.callbacks import Callback, EarlyStopping
28 import numpy as np
29 
30 from basf2_mva_extensions.keras_relational import Relations
31 
32 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build and compile the keras model for this example.

    Depending on ``parameters['use_relations']`` the model is either

    * a relational net: the flat feature vector is reshaped into groups of 6 variables
      (one group per line), fed through the ``Relations`` layer and averaged with
      ``GlobalAveragePooling1D``, or
    * a plain feed-forward net of 6 ``tanh`` layers with ``2 * number_of_features`` units each.

    Both variants end in a single sigmoid output unit. They are compiled with Adam (lr=0.001),
    binary cross-entropy loss and an accuracy metric.

    :param number_of_features: length of the flat input feature vector
    :param number_of_spectators: not used by this model
    :param number_of_events: not used by this model
    :param training_fraction: not used by this model
    :param parameters: configuration dict (presumably the decoded ``m_config`` JSON).
        Keys read here: ``'use_relations'`` and, only if that is true, ``'number_features'``,
        which is passed on to the ``Relations`` layer.
    :return: ``State`` wrapping the compiled model, with ``Relations`` registered as a custom object
    :raises ValueError: if relations are requested and ``number_of_features`` is not a multiple of 6
    """
    variables_per_line = 6
    number_of_hidden_layers = 6

    inputs = Input(shape=(number_of_features,))
    net = inputs

    if parameters['use_relations']:
        # Fail early with a readable message instead of an opaque Reshape error.
        if number_of_features % variables_per_line != 0:
            raise ValueError('Relational net expects a multiple of {} features (one group per line), got {}'
                             .format(variables_per_line, number_of_features))
        # (number_of_features,) -> (number_of_lines, 6): one row per line object.
        net = Reshape((number_of_features // variables_per_line, variables_per_line))(net)
        net = Relations(number_features=parameters['number_features'])(net)
        # average over all permutations
        net = GlobalAveragePooling1D()(net)
    else:
        for _ in range(number_of_hidden_layers):
            net = Dense(units=2 * number_of_features, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # NOTE(review): State is not imported in the visible part of this file; presumably it is
    # basf2_mva_python_interface.contrib_keras.State -- confirm the import exists.
    state = State(Model(inputs, output), custom_objects={'Relations': Relations})

    state.model.compile(optimizer=adam(lr=0.001), loss=binary_crossentropy, metrics=['accuracy'])
    state.model.summary()

    return state
58 
59 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Attach the independent test sample to the state and hand the state back.

    Only the test features and targets are kept (as ``state.Xtest`` / ``state.ytest``);
    partial_fit later uses them for validation. Test spectators (``Stest``) and test
    weights (``wtest``) are ignored by this example.

    :return: the same ``state`` object that was passed in
    """
    state.Xtest, state.ytest = Xtest, ytest
    return state
68 
69 
def partial_fit(state, X, S, y, w, epoch):
    """
    Train the model on the full training sample in one call.

    keras runs all epochs itself: up to 100 epochs with batch size 100, validated on the
    test sample stored by begin_fit. Training stops early via EarlyStopping on ``val_loss``
    (with keras' default patience). After every epoch a short loss/accuracy summary is
    printed for the test sample and for the first 10000 training events.

    :param state: state from get_model/begin_fit; ``model``, ``Xtest`` and ``ytest`` are read
    :param X: training features
    :param S: training spectators, unused
    :param y: training targets
    :param w: training weights, unused
    :param epoch: epoch counter passed in by the caller, unused (keras handles epochs itself)
    :return: ``False`` -- presumably tells the basf2 python interface not to call
        partial_fit again; verify against the interface implementation
    """
    class TestCallback(Callback):
        """
        Print a small loss/accuracy summary after every epoch.
        Class has to inherit from abstract Callback class
        """

        def on_epoch_end(self, epoch, logs=None):
            """
            Print summary at the end of epoch.
            For other possibilities look at the abstract Callback class.

            ``logs`` defaults to ``None`` (not a shared mutable ``{}``); it is unused here.
            """
            loss, acc = state.model.evaluate(state.Xtest, state.ytest, verbose=0, batch_size=1000)
            # Only a subset of the training sample to keep the per-epoch overhead small.
            train_loss, train_acc = state.model.evaluate(X[:10000], y[:10000], verbose=0, batch_size=1000)
            print('\nTesting loss: {}, acc: {}'.format(loss, acc))
            print('Training loss: {}, acc: {}'.format(train_loss, train_acc))

    state.model.fit(X, y, batch_size=100, epochs=100, validation_data=(state.Xtest, state.ytest),
                    callbacks=[TestCallback(), EarlyStopping(monitor='val_loss')])
    return False
93 
94 
def draw_crossing_line_indices(number_total_lines):
    """
    Draw the indices of two *different* lines out of ``number_total_lines``.

    The first index is uniform in ``[0, number_total_lines)``. The second one is the first
    shifted by a non-zero offset in ``[1, number_total_lines - 1]`` (modulo
    ``number_total_lines``). It is therefore uniform over the remaining lines and never
    equals the first one. A zero offset would let the second crossing line overwrite the
    first, leaving a "signal" event without any intersecting pair.

    :param number_total_lines: number of lines per event, must be at least 2
    :return: tuple ``(first, second)`` of two distinct ints
    :raises ValueError: if fewer than 2 lines are available
    """
    if number_total_lines < 2:
        raise ValueError('Need at least 2 lines to place an intersecting pair, got {}'.format(number_total_lines))
    first = int(np.random.rand() * number_total_lines)
    offset = 1 + int(np.random.rand() * (number_total_lines - 1))
    return first, (first + offset) % number_total_lines


if __name__ == "__main__":
    import os
    import pandas
    from root_pandas import to_root
    import tempfile
    import json

    import basf2
    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # ############## Building data samples ###########################
    # This is a dataset for testing relational nets.
    # Every event consists of number_total_lines lines in 3 dimensional space,
    # each described by 6 variables (see build_signal_event for their meaning).
    # In approx. half of the events, two of the lines intersect each other;
    # such an event is considered a signal event.
    # The training results depend on the total number of lines.

    # try using 10 and 20 lines and see what happens
    number_total_lines = 5
    # Names for the training data set: px_i, py_i, pz_i, dx_i, dy_i, dz_i for every line i
    variables = [coordinate + '_' + str(i)
                 for i in range(number_total_lines)
                 for coordinate in ('px', 'py', 'pz', 'dx', 'dy', 'dz')]
    # Number of events in training and test root file.
    number_of_events = 1000000

    def build_signal_event():
        """
        Build two lines which are hitting each other.

        Each line is returned as 6 numbers (px, py, pz, dx, dy, dz). The construction treats
        p as the direction and d as a point on the line. Line k is d_k + t * p_k with
        d_k = v_cross + epsilon_k * p_k, so both lines pass through the common point v_cross
        (at t = -epsilon_k). epsilon_k lies in [-0.1, 0.1).
        """
        p_vec1, p_vec2 = np.random.normal(size=3), np.random.normal(size=3)
        v_cross = np.random.normal(size=3)
        epsilon1, epsilon2 = (np.random.rand() * 2 - 1) / 10, (np.random.rand() * 2 - 1) / 10
        v_vec1 = v_cross + (p_vec1 * epsilon1)
        v_vec2 = v_cross + (p_vec2 * epsilon2)
        return np.concatenate([p_vec1, v_vec1]), np.concatenate([p_vec2, v_vec2])

    # This path will delete itself with all data in it after end of program.
    with tempfile.TemporaryDirectory() as path:
        for filename in ['train.root', 'test.root']:
            print('Building ' + filename)
            # Start from pure noise: every line variable is drawn from a standard normal.
            data = np.random.normal(size=[number_of_events, number_total_lines * 6])
            target = np.zeros([number_of_events], dtype=bool)

            # For about half of the events, overwrite two *distinct* lines with a pair that
            # intersects and label the event as signal.
            for index in range(number_of_events):
                if np.random.rand() > 0.5:
                    target[index] = True
                    i1, i2 = draw_crossing_line_indices(number_total_lines)
                    track1, track2 = build_signal_event()
                    data[index, i1 * 6:(i1 + 1) * 6] = track1
                    data[index, i2 * 6:(i2 + 1) * 6] = track2

            # Saving all variables in root files
            dic = {name: data[:, i] for i, name in enumerate(variables)}
            dic['isSignal'] = target

            df = pandas.DataFrame(dic, dtype=np.float32)
            to_root(df, os.path.join(path, filename), key='variables')

        # ##########################Do Training#################################
        # Do a comparison of different Nets for this task.

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(os.path.join(path, 'train.root'))
        general_options.m_treename = "variables"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "contrib_keras"
        # NOTE(review): relative path -- presumably resolved against the basf2 release /
        # working directory; verify when running this script from elsewhere.
        specific_options.m_steering_file = 'mva/examples/keras/relational_network.py'
        specific_options.m_training_fraction = 0.999

        # Train relational Net
        print('Train relational net ')
        general_options.m_identifier = os.path.join(path, 'relation.xml')
        specific_options.m_config = json.dumps({'use_relations': True,
                                                'number_features': 3})
        basf2_mva.teacher(general_options, specific_options)

        # Train normal feed forward Net:
        print('Train feed forward net')
        general_options.m_identifier = os.path.join(path, 'feed_forward.xml')
        specific_options.m_config = json.dumps({'use_relations': False})
        basf2_mva.teacher(general_options, specific_options)

        # ########################Compare Results####################################
        method1 = basf2_mva_util.Method(os.path.join(path, 'relation.xml'))
        method2 = basf2_mva_util.Method(os.path.join(path, 'feed_forward.xml'))

        test_data = basf2_mva.vector(os.path.join(path, 'test.root'))

        print('Apply relational net')
        p1, t1 = method1.apply_expert(test_data, general_options.m_treename)
        print('Apply feed forward net')
        p2, t2 = method2.apply_expert(test_data, general_options.m_treename)

        print('Relational Net AUC: ', basf2_mva_util.calculate_roc_auc(p1, t1))
        print('Feed Forward Net AUC: ', basf2_mva_util.calculate_roc_auc(p2, t2))
basf2_mva_util.calculate_roc_auc
def calculate_roc_auc(p, t)
Definition: basf2_mva_util.py:39
basf2_mva_util.Method
Definition: basf2_mva_util.py:81
basf2_mva_python_interface.contrib_keras
Definition: contrib_keras.py:1