Belle II Software  release-06-02-00
relational_network.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # This is a basic example of how to implement relational networks in basf2 with tensorflow/keras.
12 # As a toy problem, the network tries to tell whether two out of several lines in three-dimensional space intersect each other.
13 # Relevant Paper: https://arxiv.org/abs/1706.01427
14 # If you want to try relational networks on your own problem, feel free to import the relational layer classes from basf2_mva_extensions.keras_relational into your code.
15 
17 
18 
19 from keras.layers import Dense, GlobalAveragePooling1D, Input
20 from keras.layers.core import Reshape
21 from keras.models import Model
22 from keras.optimizers import Adam
23 from keras.losses import binary_crossentropy
24 from keras.activations import sigmoid, tanh
25 from keras.callbacks import Callback, EarlyStopping
26 import numpy as np
27 
28 from basf2_mva_extensions.keras_relational import Relations
29 
30 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build and compile the keras model.

    If ``parameters['use_relations']`` is true, the flat feature vector is
    reshaped into groups of 6 values (one group per line). The groups are fed
    through a ``Relations`` layer configured with
    ``parameters['number_features']``, and the layer's outputs are averaged.
    Otherwise six tanh ``Dense`` layers of width ``2 * number_of_features``
    are stacked. Both variants end in a single sigmoid unit. The model is
    compiled with Adam (learning rate 0.001), binary cross-entropy and an
    accuracy metric.

    number_of_spectators, number_of_events and training_fraction belong to the
    MVA interface signature but are not used here.

    Returns the ``State`` wrapping the compiled model.

    Raises ValueError if relations are requested and number_of_features is not
    a multiple of 6.
    """
    # 'inputs' rather than 'input' so the builtin is not shadowed.
    inputs = Input(shape=(number_of_features,))
    net = inputs

    if parameters['use_relations']:
        # Each line is described by 6 consecutive features; fail early with a
        # clear message instead of an obscure Reshape error.
        if number_of_features % 6 != 0:
            raise ValueError(
                'use_relations requires the number of features to be a multiple of 6, got {}'.format(
                    number_of_features))
        net = Reshape((number_of_features // 6, 6))(net)
        net = Relations(number_features=parameters['number_features'])(net)
        # average over all permutations
        net = GlobalAveragePooling1D()(net)
    else:
        for _ in range(6):
            net = Dense(units=2 * number_of_features, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # NOTE(review): State is not imported in the visible part of this file;
    # presumably it comes from basf2's keras MVA python interface -- confirm the import exists.
    # custom_objects registers the custom Relations layer under its name.
    state = State(Model(inputs, output), custom_objects={'Relations': Relations})

    # 'learning_rate' is the supported keyword; 'lr' is deprecated/removed in newer keras.
    state.model.compile(optimizer=Adam(learning_rate=0.001), loss=binary_crossentropy, metrics=['accuracy'])
    state.model.summary()

    return state
56 
57 
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Attach the validation sample to ``state`` and return the same object.

    Only the features (Xtest) and the targets (ytest) are stored, as
    ``state.Xtest`` and ``state.ytest``. The spectators (Stest) and weights
    (wtest) are accepted for interface compatibility but ignored.
    """
    state.Xtest, state.ytest = Xtest, ytest
    return state
66 
67 
def partial_fit(state, X, S, y, w, epoch):
    """
    Train ``state.model`` on the sample X, y.

    The sample stored on ``state`` by begin_fit is passed to keras as
    validation data, which is what the ``val_loss`` EarlyStopping monitor
    reads. After each epoch a short summary is printed: loss and accuracy on
    the validation sample and on the first 10000 training events. Training
    runs for at most 100 epochs with batch size 100. S, w and epoch are
    unused.

    Always returns False.
    """
    model = state.model

    class TestCallback(Callback):
        """
        Per-epoch summary printer.

        Keras callbacks derive from keras.callbacks.Callback; only
        on_epoch_end is overridden here.
        """

        def on_epoch_end(self, epoch, logs=None):
            """
            Evaluate on the validation sample and on the first 10000 training
            events, then print both results.
            """
            eval_kwargs = dict(verbose=0, batch_size=1000)
            test_loss, test_acc = model.evaluate(state.Xtest, state.ytest, **eval_kwargs)
            train_loss, train_acc = model.evaluate(X[:10000], y[:10000], **eval_kwargs)
            print('\nTesting loss: {}, acc: {}'.format(test_loss, test_acc))
            print('Training loss: {}, acc: {}'.format(train_loss, train_acc))

    callbacks = [TestCallback(), EarlyStopping(monitor='val_loss')]
    model.fit(X, y,
              batch_size=100,
              epochs=100,
              validation_data=(state.Xtest, state.ytest),
              callbacks=callbacks)
    return False
91 
92 
def _build_signal_event():
    """
    Build two lines in three-dimensional space that intersect each other.

    Each line is returned as a 6-vector ``(px, py, pz, dx, dy, dz)``. Read such
    a vector as the line through ``(dx, dy, dz)`` with direction
    ``(px, py, pz)``. Both lines then pass through the common random point
    ``v_cross``, because each reference point is ``v_cross`` shifted along the
    line's own direction by a factor drawn from [-0.1, 0.1).
    """
    p_vec1, p_vec2 = np.random.normal(size=3), np.random.normal(size=3)
    v_cross = np.random.normal(size=3)
    epsilon1, epsilon2 = (np.random.rand() * 2 - 1) / 10, (np.random.rand() * 2 - 1) / 10
    v_vec1 = v_cross + (p_vec1 * epsilon1)
    v_vec2 = v_cross + (p_vec2 * epsilon2)
    return np.concatenate([p_vec1, v_vec1]), np.concatenate([p_vec2, v_vec2])


def _pick_two_distinct_lines(number_total_lines):
    """
    Return two *different* line indices in ``[0, number_total_lines)``.

    The first index is uniform. The second is the first plus an offset in
    ``[1, number_total_lines - 1]`` (modulo the count), so the two can never
    coincide.

    Raises ValueError if fewer than two lines are available.
    """
    if number_total_lines < 2:
        raise ValueError('need at least two lines to build a signal event, got {}'.format(number_total_lines))
    first = int(np.random.rand() * number_total_lines)
    # The offset must start at 1; an offset of 0 would select the same line twice.
    offset = 1 + int(np.random.rand() * (number_total_lines - 1))
    return first, (first + offset) % number_total_lines


def _build_dataset(number_of_events, number_total_lines):
    """
    Build the toy sample of random lines.

    Returns ``(data, target)``:
    - ``data`` has shape ``(number_of_events, number_total_lines * 6)`` and is
      filled with standard-normal values.
    - ``target`` is a boolean array of length ``number_of_events``.

    For each event, with probability about 0.5, ``target`` is set True and two
    distinct lines of that event are overwritten with an intersecting pair
    from ``_build_signal_event``.
    """
    data = np.random.normal(size=[number_of_events, number_total_lines * 6])
    target = np.zeros([number_of_events], dtype=bool)
    # Iterating a 2-D array yields row views, so writing into `row` updates `data`.
    for index, row in enumerate(data):
        if np.random.rand() > 0.5:
            target[index] = True
            i1, i2 = _pick_two_distinct_lines(number_total_lines)
            track1, track2 = _build_signal_event()
            row[i1 * 6:(i1 + 1) * 6] = track1
            row[i2 * 6:(i2 + 1) * 6] = track2
    return data, target


if __name__ == "__main__":
    import os
    import pandas
    from root_pandas import to_root
    import tempfile
    import json

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # ############## Building data samples ###########################
    # This is a dataset for testing relational nets.
    # Each event consists of number_total_lines lines in three-dimensional space,
    # each described by 6 variables.
    # In approximately half of the events, two of the lines intersect each other;
    # these are labelled as signal.
    # The training results depend on the total number of lines.

    # try using 10 and 20 lines and see what happens
    number_total_lines = 5
    # Names of the training variables: six per line.
    variables = [component + '_' + str(i)
                 for i in range(number_total_lines)
                 for component in ('px', 'py', 'pz', 'dx', 'dy', 'dz')]
    # Number of events in training and test root file.
    number_of_events = 1000000

    # This path will delete itself with all data in it after end of program.
    with tempfile.TemporaryDirectory() as path:
        for filename in ['train.root', 'test.root']:
            print('Building ' + filename)
            data, target = _build_dataset(number_of_events, number_total_lines)

            # Saving all variables in root files
            columns = {name: data[:, i] for i, name in enumerate(variables)}
            columns['isSignal'] = target

            df = pandas.DataFrame(columns, dtype=np.float32)
            to_root(df, os.path.join(path, filename), key='variables')

        # ########################## Do training #################################
        # Compare different nets on this task.

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(os.path.join(path, 'train.root'))
        general_options.m_treename = "variables"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "contrib_keras"
        # NOTE(review): relative path -- presumably the script is run from the basf2 top directory; confirm.
        specific_options.m_steering_file = 'mva/examples/keras/relational_network.py'
        specific_options.m_training_fraction = 0.999

        # Train relational Net
        print('Train relational net ')
        general_options.m_identifier = os.path.join(path, 'relation.xml')
        specific_options.m_config = json.dumps({'use_relations': True,
                                                'number_features': 3})
        basf2_mva.teacher(general_options, specific_options)

        # Train normal feed forward Net:
        print('Train feed forward net')
        general_options.m_identifier = os.path.join(path, 'feed_forward.xml')
        specific_options.m_config = json.dumps({'use_relations': False})
        basf2_mva.teacher(general_options, specific_options)

        # ######################## Compare results ####################################
        method1 = basf2_mva_util.Method(os.path.join(path, 'relation.xml'))
        method2 = basf2_mva_util.Method(os.path.join(path, 'feed_forward.xml'))

        test_data = basf2_mva.vector(os.path.join(path, 'test.root'))

        print('Apply relational net')
        p1, t1 = method1.apply_expert(test_data, general_options.m_treename)
        print('Apply feed forward net')
        p2, t2 = method2.apply_expert(test_data, general_options.m_treename)

        print('Relational Net AUC: ', basf2_mva_util.calculate_roc_auc(p1, t1))
        print('Feed Forward Net AUC: ', basf2_mva_util.calculate_roc_auc(p2, t2))
def calculate_roc_auc(p, t)