Belle II Software  release-08-01-10
relational_network.py
1 #!/usr/bin/env python3
2 
3 
10 
11 # This example serves as a basic example of implementing Relational networks into basf2 with tensorflow.
12 # As a toy example, it tries to decide whether two out of several lines intersect each other in three-dimensional space.
13 # Relevant Paper: https://arxiv.org/abs/1706.01427
14 # If you want to try relational networks on your own problem, feel free to import the two classes into your code.
15 
16 from basf2_mva_python_interface.keras import State
17 
18 
19 from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input, Reshape
20 from tensorflow.keras.models import Model
21 from tensorflow.keras.optimizers import Adam
22 from tensorflow.keras.losses import binary_crossentropy
23 from tensorflow.keras.activations import sigmoid, tanh
24 from tensorflow.keras.callbacks import Callback, EarlyStopping
25 import numpy as np
26 
27 from basf2_mva_extensions.keras_relational import Relations
28 
29 
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build and compile the keras model.

    Parameters:
        number_of_features: length of the flat input feature vector.
        number_of_spectators, number_of_events, training_fraction: not used here.
        parameters: configuration dict. ``parameters['use_relations']`` selects the
            architecture; when it is true, ``parameters['number_features']`` is
            forwarded to the ``Relations`` layer.

    With ``use_relations`` the input is reshaped into groups of 6 values (one
    group per line), passed through a ``Relations`` layer and averaged with
    ``GlobalAveragePooling1D``. Otherwise six ``tanh`` dense layers of width
    ``2 * number_of_features`` are stacked. Both variants end in a single
    sigmoid unit and are compiled with Adam and binary cross-entropy.

    Returns:
        State wrapping the compiled model, with ``Relations`` registered as a
        custom object.

    Raises:
        ValueError: if ``use_relations`` is set and ``number_of_features`` is not
            a multiple of 6.
    """
    inputs = Input(shape=(number_of_features,))
    net = inputs

    if parameters['use_relations']:
        # Each line contributes exactly 6 input values; fail early with a clear message
        # instead of an opaque reshape error deep inside keras.
        if number_of_features % 6 != 0:
            raise ValueError(
                f'use_relations requires a multiple of 6 features, got {number_of_features}')
        net = Reshape((number_of_features // 6, 6))(net)
        net = Relations(number_features=parameters['number_features'])(net)
        # average over all permutations
        net = GlobalAveragePooling1D()(net)
    else:
        for _ in range(6):
            net = Dense(units=2 * number_of_features, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    state = State(Model(inputs, output), custom_objects={'Relations': Relations})

    # ``lr`` is a deprecated alias that newer keras versions reject; ``learning_rate`` is portable.
    state.model.compile(optimizer=Adam(learning_rate=0.001), loss=binary_crossentropy, metrics=['accuracy'])
    state.model.summary()

    return state
55 
56 
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Attach the test sample to the state and hand the state back.

    Only the test features (``Xtest``) and targets (``ytest``) are kept, for
    later evaluation; the spectators, weights and ``nBatches`` are ignored.
    """
    # Tuple assignment stores Xtest first, then ytest.
    state.Xtest, state.ytest = Xtest, ytest
    return state
65 
66 
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Train the model on the whole training sample in one keras ``fit`` call.

    Keras runs the epoch loop itself: up to 100 epochs with batch size 100,
    validated on the test sample stored on ``state`` and stopped early on
    ``val_loss``. After every epoch a short loss/accuracy summary for the test
    sample and the first 10000 training events is printed.

    Returns:
        False (presumably telling the basf2 mva interface that no further
        partial_fit calls are needed; confirm against the interface).
    """
    class TestCallback(Callback):
        """
        Print a short loss/accuracy summary after every epoch.

        Keras callbacks have to derive from the abstract ``Callback`` class.
        """

        def on_epoch_end(self, epoch, logs=None):
            """
            Evaluate the model on the test sample and on a training subset, then print both.

            Other hooks are available in the abstract ``Callback`` class.
            """
            test_loss, test_acc = state.model.evaluate(state.Xtest, state.ytest, verbose=0, batch_size=1000)
            train_loss, train_acc = state.model.evaluate(X[:10000], y[:10000], verbose=0, batch_size=1000)
            print(f'\nTesting loss: {test_loss}, acc: {test_acc}')
            print(f'Training loss: {train_loss}, acc: {train_acc}')

    validation = (state.Xtest, state.ytest)
    callbacks = [TestCallback(), EarlyStopping(monitor='val_loss')]
    state.model.fit(X, y,
                    batch_size=100,
                    epochs=100,
                    validation_data=validation,
                    callbacks=callbacks)
    return False
90 
91 
def _build_signal_event():
    """
    Build two random lines that intersect each other.

    Each line is returned as 6 numbers ``(p, v)`` (3 + 3 components). Because
    ``v = v_cross + epsilon * p``, the common point ``v_cross`` lies on both
    lines ``{v + t * p}`` (at ``t = -epsilon``), so the two lines cross there.

    Returns:
        Tuple of two numpy arrays of length 6, one per line.
    """
    p_vec1, p_vec2 = np.random.normal(size=3), np.random.normal(size=3)
    v_cross = np.random.normal(size=3)
    epsilon1, epsilon2 = (np.random.rand() * 2 - 1) / 10, (np.random.rand() * 2 - 1) / 10
    v_vec1 = v_cross + (p_vec1 * epsilon1)
    v_vec2 = v_cross + (p_vec2 * epsilon2)
    return np.concatenate([p_vec1, v_vec1]), np.concatenate([p_vec2, v_vec2])


def _pick_two_distinct_lines(number_total_lines):
    """
    Draw two *different* random line indices in ``[0, number_total_lines)``.

    The second index is the first one shifted by 1 .. number_total_lines - 1
    (modulo the number of lines), so the two can never coincide. Previously an
    offset of 0 was possible, which made both signal tracks overwrite the same
    line and produced signal-labelled events without any intersecting pair.

    Raises:
        ValueError: if fewer than two lines are available.
    """
    if number_total_lines < 2:
        raise ValueError(f'Need at least two lines per event, got {number_total_lines}')
    first = int(np.random.randint(number_total_lines))
    offset = int(np.random.randint(1, number_total_lines))
    return first, (first + offset) % number_total_lines


def _build_dataset(number_of_events, number_total_lines):
    """
    Generate the toy data set of random lines.

    Every event consists of ``number_total_lines`` lines with 6 standard-normal
    values each. For each event, with probability 0.5, two distinct lines are
    replaced by an intersecting pair from ``_build_signal_event`` and the event
    is labelled as signal.

    Returns:
        data: float array of shape (number_of_events, number_total_lines * 6).
        target: bool array of shape (number_of_events,), True for signal events.
    """
    data = np.random.normal(size=[number_of_events, number_total_lines * 6])
    # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
    target = np.zeros([number_of_events], dtype=bool)

    for index in range(number_of_events):
        if np.random.rand() > 0.5:
            target[index] = True
            i1, i2 = _pick_two_distinct_lines(number_total_lines)
            track1, track2 = _build_signal_event()
            data[index, i1 * 6:(i1 + 1) * 6] = track1
            data[index, i2 * 6:(i2 + 1) * 6] = track2
    return data, target


if __name__ == "__main__":
    import os
    import pandas
    import uproot
    import tempfile
    import json

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # ############## Building Data samples ###########################
    # This is a dataset for testing relational nets.
    # Each event consists of number_total_lines lines in three-dimensional space,
    # each line described by 6 variables.
    # In approximately half of the events, two of the lines intersect each other;
    # these are the signal events.
    # The training results depend on the total number of lines.

    # try using 10 and 20 lines and see what happens
    number_total_lines = 5
    # Names for the training data set: px, py, pz, dx, dy, dz for every line, line by line.
    variables = [f'{prefix}_{i}'
                 for i in range(number_total_lines)
                 for prefix in ('px', 'py', 'pz', 'dx', 'dy', 'dz')]
    # Number of events in training and test root file.
    number_of_events = 1000000

    # This path will delete itself with all data in it after end of program.
    with tempfile.TemporaryDirectory() as path:
        for filename in ['train.root', 'test.root']:
            print('Building ' + filename)
            data, target = _build_dataset(number_of_events, number_total_lines)

            # Saving all variables in root files
            columns = {name: data[:, i] for i, name in enumerate(variables)}
            columns['isSignal'] = target

            df = pandas.DataFrame(columns)
            with uproot.recreate(os.path.join(path, filename)) as outfile:
                outfile['variables'] = df

        # ########################## Do Training #################################
        # Do a comparison of different Nets for this task.

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(os.path.join(path, 'train.root'))
        general_options.m_treename = "variables"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "keras"
        specific_options.m_steering_file = 'mva/examples/keras/relational_network.py'
        specific_options.m_training_fraction = 0.999

        # Train relational Net
        print('Train relational net ')
        general_options.m_identifier = os.path.join(path, 'relation.xml')
        specific_options.m_config = json.dumps({'use_relations': True,
                                                'number_features': 3})
        basf2_mva.teacher(general_options, specific_options)

        # Train normal feed forward Net:
        print('Train feed forward net')
        general_options.m_identifier = os.path.join(path, 'feed_forward.xml')
        specific_options.m_config = json.dumps({'use_relations': False})
        basf2_mva.teacher(general_options, specific_options)

        # ######################## Compare Results ####################################
        method1 = basf2_mva_util.Method(os.path.join(path, 'relation.xml'))
        method2 = basf2_mva_util.Method(os.path.join(path, 'feed_forward.xml'))

        test_data = basf2_mva.vector(os.path.join(path, 'test.root'))

        print('Apply relational net')
        p1, t1 = method1.apply_expert(test_data, general_options.m_treename)
        print('Apply feed forward net')
        p2, t2 = method2.apply_expert(test_data, general_options.m_treename)

        print('Relational Net AUC: ', basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p1, t1))
        print('Feed Forward Net AUC: ', basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p2, t2))
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)