Belle II Software light-2406-ragdoll
relational_network.py
1#!/usr/bin/env python3
2
3
10
# This is a basic example of how to use relational networks in basf2 with TensorFlow/Keras.
# As a toy problem, the network tries to tell whether 2 out of several lines in three-dimensional space intersect.
# Relevant paper: https://arxiv.org/abs/1706.01427
# If you want to try relational networks on your own problem, feel free to import the relational layer classes
# from basf2_mva_extensions.keras_relational into your code.
15
17
18
19from tensorflow.keras.layers import Dense, GlobalAveragePooling1D, Input, Reshape
20from tensorflow.keras.models import Model
21from tensorflow.keras.optimizers import Adam
22from tensorflow.keras.losses import binary_crossentropy
23from tensorflow.keras.activations import sigmoid, tanh
24from tensorflow.keras.callbacks import Callback, EarlyStopping
25import numpy as np
26
27from basf2_mva_extensions.keras_relational import Relations
28
29
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build and compile the keras model.

    With ``parameters['use_relations']`` the flat input is reshaped into groups of
    6 features (one group per line), fed through the ``Relations`` layer
    (configured with ``parameters['number_features']``) and averaged with
    ``GlobalAveragePooling1D``. Otherwise a plain feed-forward network with 6
    hidden tanh layers of width ``2 * number_of_features`` is used. Both variants
    end in a single sigmoid output unit and are compiled with Adam and binary
    cross-entropy.

    ``State`` is not defined in this file; presumably it is provided by the basf2
    MVA keras interface (TODO confirm).

    :raises ValueError: if relations are requested and ``number_of_features`` is
        not a multiple of 6 (the input could not be split into whole lines).
    :return: the State object wrapping the compiled model.
    """
    # Number of input features describing one line in this example.
    features_per_line = 6

    inputs = Input(shape=(number_of_features,))
    net = inputs

    if parameters['use_relations']:
        if number_of_features % features_per_line != 0:
            raise ValueError(
                f'use_relations expects a multiple of {features_per_line} features, got {number_of_features}')
        net = Reshape((number_of_features // features_per_line, features_per_line))(net)
        net = Relations(number_features=parameters['number_features'])(net)
        # average over all permutations
        net = GlobalAveragePooling1D()(net)
    else:
        for _ in range(6):
            net = Dense(units=2 * number_of_features, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    # custom_objects lets the custom Relations layer be resolved when the model is reloaded.
    state = State(Model(inputs, output), custom_objects={'Relations': Relations})

    # `lr` is a deprecated alias that Keras 3 rejects; `learning_rate` is the supported keyword.
    state.model.compile(optimizer=Adam(learning_rate=0.001), loss=binary_crossentropy, metrics=['accuracy'])
    state.model.summary()

    return state
55
56
def begin_fit(state, Xtest, Stest, ytest, wtest, nBatches):
    """
    Attach the held-out test sample to the state and return the state.

    Only the test features (Xtest) and targets (ytest) are kept; the spectators,
    weights and batch count are accepted for interface compatibility but ignored.
    """
    state.Xtest, state.ytest = Xtest, ytest
    return state
65
66
def partial_fit(state, X, S, y, w, epoch, batch):
    """
    Fit the model on the complete training sample in a single call.

    Runs up to 100 epochs with batch size 100, validating on the test sample that
    begin_fit stored on the state. EarlyStopping monitors 'val_loss' (no patience
    is set, so the Keras default applies). After every epoch a short summary is
    printed for the test sample and for the first 10000 training events.
    S, w, epoch and batch are not used.

    Returns False, which presumably tells the calling framework not to call
    partial_fit again (TODO confirm against the basf2 keras interface).
    """
    validation_sample = (state.Xtest, state.ytest)

    class TestCallback(Callback):
        """
        Print small summary.
        Keras only dispatches epoch hooks to subclasses of its Callback base class.
        """

        def on_epoch_end(self, epoch, logs=None):
            """
            Evaluate on the test sample and on the first 10000 training events,
            then print loss and accuracy for both.
            """
            test_loss, test_acc = state.model.evaluate(state.Xtest, state.ytest, verbose=0, batch_size=1000)
            train_loss, train_acc = state.model.evaluate(X[:10000], y[:10000], verbose=0, batch_size=1000)
            print(f'\nTesting loss: {test_loss}, acc: {test_acc}')
            print(f'Training loss: {train_loss}, acc: {train_acc}')

    callbacks = [TestCallback(), EarlyStopping(monitor='val_loss')]
    state.model.fit(X, y, batch_size=100, epochs=100, validation_data=validation_sample,
                    callbacks=callbacks)
    return False
90
91
if __name__ == "__main__":
    import os
    import pandas
    import uproot
    import tempfile
    import json

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions
    # NOTE: do not use testing payloads in production! Any results obtained like this WILL NOT BE PUBLISHED
    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # ############## Building data samples ###########################
    # This is a toy dataset for testing relational nets.
    # Every event consists of number_total_lines lines in three-dimensional space,
    # each described by 6 variables (px, py, pz, dx, dy, dz).
    # In approximately half of the events two of the lines intersect;
    # these events are labelled as signal.
    # How well each network performs depends on the number of lines per event.

    # Try using 10 and 20 lines and see what happens.
    number_total_lines = 5
    # Column names for the data set: 6 variables per line, line by line.
    variables = [f'{name}_{i}' for i in range(number_total_lines)
                 for name in ('px', 'py', 'pz', 'dx', 'dy', 'dz')]
    # Number of events in each of the training and test root files.
    number_of_events = 1000000

    def build_signal_event():
        """
        Build two lines that intersect.

        Each line is returned as 6 numbers (p, d) where p is a random vector and
        d = v_cross + epsilon * p, so the line {d + t * p} passes through the
        common point v_cross (at t = -epsilon) for both lines.
        """
        p_vec1, p_vec2 = np.random.normal(size=3), np.random.normal(size=3)
        v_cross = np.random.normal(size=3)
        epsilon1, epsilon2 = np.random.uniform(-0.1, 0.1, size=2)
        v_vec1 = v_cross + (p_vec1 * epsilon1)
        v_vec2 = v_cross + (p_vec2 * epsilon2)
        return np.concatenate([p_vec1, v_vec1]), np.concatenate([p_vec2, v_vec2])

    def build_dataset():
        """
        Return (data, target) with number_of_events random events.

        data has shape (number_of_events, number_total_lines * 6) and holds
        independent random lines. For roughly half of the events (target True)
        two distinct lines are replaced by an intersecting pair.
        """
        data = np.random.normal(size=[number_of_events, number_total_lines * 6])
        target = np.random.rand(number_of_events) > 0.5  # boolean array; np.bool no longer exists
        for index in np.flatnonzero(target):
            i1 = np.random.randint(number_total_lines)
            # An offset in [1, number_total_lines - 1] guarantees i2 != i1. Otherwise the
            # second track would overwrite the first and the "signal" event would contain
            # no intersecting pair at all.
            i2 = (i1 + np.random.randint(1, number_total_lines)) % number_total_lines
            track1, track2 = build_signal_event()
            data[index, i1 * 6:(i1 + 1) * 6] = track1
            data[index, i2 * 6:(i2 + 1) * 6] = track2
        return data, target

    # This directory deletes itself, together with all data in it, at the end of the program.
    with tempfile.TemporaryDirectory() as path:
        for filename in ['train.root', 'test.root']:
            print('Building ' + filename)
            data, target = build_dataset()

            # Save all variables and the target in a root file.
            df = pandas.DataFrame(data, columns=variables)
            df['isSignal'] = target
            with uproot.recreate(os.path.join(path, filename)) as outfile:
                outfile['variables'] = df

        # ##########################Do Training#################################
        # Compare different nets on this task.

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(os.path.join(path, 'train.root'))
        general_options.m_treename = "variables"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "keras"
        specific_options.m_steering_file = 'mva/examples/keras/relational_network.py'
        specific_options.m_training_fraction = 0.999

        # Train relational net
        print('Train relational net ')
        general_options.m_identifier = os.path.join(path, 'relation.xml')
        specific_options.m_config = json.dumps({'use_relations': True,
                                                'number_features': 3})
        basf2_mva.teacher(general_options, specific_options)

        # Train normal feed-forward net
        print('Train feed forward net')
        general_options.m_identifier = os.path.join(path, 'feed_forward.xml')
        specific_options.m_config = json.dumps({'use_relations': False})
        basf2_mva.teacher(general_options, specific_options)

        # ########################Compare Results####################################
        method1 = basf2_mva_util.Method(os.path.join(path, 'relation.xml'))
        method2 = basf2_mva_util.Method(os.path.join(path, 'feed_forward.xml'))

        test_data = basf2_mva.vector(os.path.join(path, 'test.root'))

        print('Apply relational net')
        p1, t1 = method1.apply_expert(test_data, general_options.m_treename)
        print('Apply feed forward net')
        p2, t2 = method2.apply_expert(test_data, general_options.m_treename)

        print('Relational Net AUC: ', basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p1, t1))
        print('Feed Forward Net AUC: ', basf2_mva_util.calculate_auc_efficiency_vs_background_retention(p2, t2))
def calculate_auc_efficiency_vs_background_retention(p, t, w=None)