14 import tensorflow
as tf
15 import tensorflow.contrib.keras
as keras
17 from keras.layers
import Input, Dense, Concatenate, Flatten, Dropout, GlobalAveragePooling1D
18 from keras.layers.merge
import Average
19 from keras.layers.core
import Reshape
20 from keras.layers
import activations
21 from keras.models
import Model, load_model
22 from keras.optimizers
import adam
23 from keras.losses
import binary_crossentropy, mean_squared_error
24 from keras.activations
import sigmoid, tanh
25 from keras.engine.topology
import Layer
26 from keras
import backend
as K
27 from keras.callbacks
import Callback, EarlyStopping
30 from basf2_mva_extensions.keras_relational
import Relations
def get_model(number_of_features, number_of_spectators, number_of_events, training_fraction, parameters):
    """
    Build the keras classifier and wrap it in a ``State``.

    The architecture is selected by ``parameters['use_relations']``:

    * truthy: the flat feature vector is reshaped into ``number_of_features // 6``
      groups of 6 features, passed through the ``Relations`` layer and averaged over
      the group axis with ``GlobalAveragePooling1D``;
    * falsy: the flat feature vector is used as is.

    Both variants then go through a ``tanh`` dense layer with ``2 * number_of_features``
    units and a single-unit ``sigmoid`` output.

    :param number_of_features: length of the flat input feature vector; must be a
        multiple of 6 when relations are used
    :param number_of_spectators: not used by this model
    :param number_of_events: not used by this model
    :param training_fraction: not used by this model
    :param parameters: dict with key ``'use_relations'`` and, when that is truthy,
        ``'number_features'`` (forwarded unchanged to ``Relations``)
    :return: ``State`` holding the model compiled with Adam (lr 0.001), binary
        cross-entropy loss and the accuracy metric; ``Relations`` is registered
        in ``custom_objects``
    :raises ValueError: if relations are requested and ``number_of_features`` is not
        a multiple of 6
    """
    inputs = Input(shape=(number_of_features,))
    # Every layer below is stacked on top of ``net``, starting from the raw inputs.
    net = inputs

    if parameters['use_relations']:
        if number_of_features % 6 != 0:
            raise ValueError(
                'use_relations requires number_of_features to be a multiple of 6, got {}'.format(number_of_features))
        # Group the flat vector into (number_of_features // 6) objects of 6 features each.
        net = Reshape((number_of_features // 6, 6))(net)
        net = Relations(number_features=parameters['number_features'])(net)
        # GlobalAveragePooling1D expects a 3-D (batch, steps, features) tensor, so it can
        # only be applied in this branch (assumes Relations keeps the tensor 3-D — TODO confirm).
        net = GlobalAveragePooling1D()(net)

    # NOTE(review): applied to both architectures; confirm whether this hidden layer was
    # meant only for the plain feed-forward variant.
    net = Dense(units=2 * number_of_features, activation=tanh)(net)

    output = Dense(units=1, activation=sigmoid)(net)

    state = State(Model(inputs, output), custom_objects={'Relations': Relations})

    state.model.compile(optimizer=adam(lr=0.001), loss=binary_crossentropy, metrics=['accuracy'])

    return state
def begin_fit(state, Xtest, Stest, ytest, wtest):
    """
    Attach the validation sample to the state and return it.

    Only the features and targets are kept; they are read back as ``state.Xtest`` and
    ``state.ytest`` during training. ``Stest`` (spectators) and ``wtest`` (weights)
    are ignored.

    :param state: training state object; it is mutated in place
    :param Xtest: validation features
    :param Stest: validation spectators (unused)
    :param ytest: validation targets
    :param wtest: validation weights (unused)
    :return: the same ``state`` object
    """
    state.Xtest = Xtest
    state.ytest = ytest
    return state
def partial_fit(state, X, S, y, w, epoch):
    """
    Train ``state.model`` on the whole training sample in one call.

    Runs ``fit`` for up to 100 epochs (batch size 100) with ``EarlyStopping`` on the
    validation loss, and prints test/training loss and accuracy after every epoch.

    :param state: object providing ``model``, ``Xtest`` and ``ytest``
    :param X: training features
    :param S: training spectators (unused)
    :param y: training targets
    :param w: training weights (unused — the fit is unweighted)
    :param epoch: index of this call (unused)
    :return: always ``False`` — presumably tells the framework that no further
        ``partial_fit`` calls are needed; confirm against the basf2 python interface
    """

    class TestCallback(Callback):
        """
        Print test and training loss/accuracy at the end of every epoch.

        Class has to inherit from abstract Callback class
        """

        def on_epoch_end(self, epoch, logs=None):
            """
            Print summary at the end of epoch.
            For other possibilities look at the abstract Callback class.
            """
            loss, acc = state.model.evaluate(state.Xtest, state.ytest, verbose=0, batch_size=1000)
            # Only the first 10000 training events are evaluated to keep the per-epoch cost low.
            loss2, acc2 = state.model.evaluate(X[:10000], y[:10000], verbose=0, batch_size=1000)
            print('\nTesting loss: {}, acc: {}'.format(loss, acc))
            print('Training loss: {}, acc: {}'.format(loss2, acc2))

    state.model.fit(X, y, batch_size=100, epochs=100, validation_data=(state.Xtest, state.ytest),
                    callbacks=[TestCallback(), EarlyStopping(monitor='val_loss')])
    return False
def build_signal_event():
    """
    Building two lines which are hitting each other.

    Each returned track is a 6-vector ``[px, py, pz, dx, dy, dz]``: a random
    direction ``p`` followed by a point ``d`` on the line ``d + t * p``. Both points
    are a common random point ``v_cross`` shifted along their own direction by a
    factor drawn uniformly from ``[-0.1, 0.1)``, so the two lines meet at ``v_cross``.

    :return: tuple of two numpy arrays of shape (6,)
    """
    import numpy as np

    p_vec1, p_vec2 = np.random.normal(size=3), np.random.normal(size=3)
    v_cross = np.random.normal(size=3)
    epsilon1, epsilon2 = (np.random.rand() * 2 - 1) / 10, (np.random.rand() * 2 - 1) / 10
    v_vec1 = v_cross + (p_vec1 * epsilon1)
    v_vec2 = v_cross + (p_vec2 * epsilon2)
    return np.concatenate([p_vec1, v_vec1]), np.concatenate([p_vec2, v_vec2])


if __name__ == "__main__":
    import json
    import os
    import tempfile

    import numpy as np
    import pandas
    from root_pandas import to_root

    import basf2_mva
    import basf2_mva_util
    from basf2 import conditions

    conditions.testing_payloads = [
        'localdb/database.txt'
    ]

    # Each "line" contributes 6 columns, in the same order as build_signal_event() returns them.
    number_total_lines = 5
    variables = [prefix + str(i)
                 for i in range(number_total_lines)
                 for prefix in ('px_', 'py_', 'pz_', 'dx_', 'dy_', 'dz_')]

    number_of_events = 1000000

    with tempfile.TemporaryDirectory() as path:
        for filename in ['train.root', 'test.root']:
            print('Building ' + filename)
            # Background: every column is pure noise.
            data = np.random.normal(size=[number_of_events, number_total_lines * 6])
            # About half of the events are signal.
            target = np.random.rand(number_of_events) > 0.5

            # In each signal event two *distinct* line slots are replaced by a pair of intersecting lines.
            for index in np.flatnonzero(target):
                i1 = np.random.randint(number_total_lines)
                # An offset in [1, number_total_lines - 1] guarantees i2 != i1.
                i2 = (i1 + np.random.randint(1, number_total_lines)) % number_total_lines
                track1, track2 = build_signal_event()
                data[index, i1 * 6:(i1 + 1) * 6] = track1
                data[index, i2 * 6:(i2 + 1) * 6] = track2

            dic = {name: data[:, i] for i, name in enumerate(variables)}
            dic.update({'isSignal': target})

            df = pandas.DataFrame(dic, dtype=np.float32)
            to_root(df, os.path.join(path, filename), key='variables')

        general_options = basf2_mva.GeneralOptions()
        general_options.m_datafiles = basf2_mva.vector(os.path.join(path, 'train.root'))
        general_options.m_treename = "variables"
        general_options.m_variables = basf2_mva.vector(*variables)
        general_options.m_target_variable = "isSignal"

        specific_options = basf2_mva.PythonOptions()
        specific_options.m_framework = "contrib_keras"
        specific_options.m_steering_file = 'mva/examples/keras/relational_network.py'
        specific_options.m_training_fraction = 0.999

        print('Train relational net ')
        general_options.m_identifier = os.path.join(path, 'relation.xml')
        specific_options.m_config = json.dumps({'use_relations': True,
                                                'number_features': 3})
        basf2_mva.teacher(general_options, specific_options)

        print('Train feed forward net')
        general_options.m_identifier = os.path.join(path, 'feed_forward.xml')
        specific_options.m_config = json.dumps({'use_relations': False})
        basf2_mva.teacher(general_options, specific_options)

        # Load both trained experts from the identifiers written above.
        method1 = basf2_mva_util.Method(os.path.join(path, 'relation.xml'))
        method2 = basf2_mva_util.Method(os.path.join(path, 'feed_forward.xml'))

        test_data = basf2_mva.vector(os.path.join(path, 'test.root'))

        print('Apply relational net')
        p1, t1 = method1.apply_expert(test_data, general_options.m_treename)
        print('Apply feed forward net')
        p2, t2 = method2.apply_expert(test_data, general_options.m_treename)