6 from keras.layers.core
import Reshape
7 from keras.layers
import activations
8 from keras.activations
import sigmoid, tanh
9 from keras.engine.topology
import Layer
10 from keras
import backend
as K
16 This class implements a Relational Layer for Keras.
17 A Relational Layer compares every combination of two feature groups using shared weights.
18 Use this class like any other Layer in Keras.
19 Relevant Paper: https://arxiv.org/abs/1706.01427
20 RN(O) = f_phi(sum_{i,j} g_theta(o_i, o_j))
21 For flexibility, only the g_theta(o_i, o_j) part is modelled by this layer.
22 f_phi corresponds to an MLP network.
23 To aggregate (average) over all combinations, please use GlobalAveragePooling1D from Keras.
26 def __init__(self, number_features, hidden_feature_shape=[30, 30, 30, 30], activation=tanh, **kwargs):
46 super(Relations, self).
__init__(**kwargs)
50 Build all weights for Relations Layer
51 :param input_shape: Input shape of tensor
55 assert(len(input_shape) == 3)
65 for i
in range(len(dense_shape[:-1])):
66 weights = self.add_weight(name=
'relation_weights_{}'.format(i),
67 shape=list(dense_shape[i:i + 2]), initializer=
'glorot_uniform', trainable=
True)
68 bias = self.add_weight(name=
'relation_weights_{}'.format(i),
69 shape=(dense_shape[i + 1],), initializer=
'zeros', trainable=
True)
73 super(Relations, self).
build(input_shape)
77 Compute Relational Layer
78 :param inputs: input tensor
79 :return: output tensor
81 input_groups = [inputs[:, i, :]
for i
in range(self.
number_groups)]
83 for index, group1
in enumerate(input_groups[:-1]):
84 for group2
in input_groups[index + 1:]:
85 net = K.dot(K.concatenate([group1, group2]), self.
variables[0][0])
86 net = K.bias_add(net, self.
variables[0][1])
89 net = K.dot(net, variables[0])
90 net = K.bias_add(net, variables[1])
91 outputs.append(sigmoid(net))
93 flat_result = K.concatenate(outputs)
102 assert(len(input_shape) == 3)
110 Config required for saving parameters in keras model.
115 'activation': activations.serialize(self.
activation)
117 base_config = super(Relations, self).
get_config()
118 return dict(list(base_config.items()) + list(config.items()))
123 This class implements a Relational Layer for Keras.
124 See class Relations for details.
125 EnhancedRelations uses an additional input to pass event information to every comparison:
126 RN(O) = f_phi(sum_{i,j} g_theta(o_i, o_j, q))
127 q is fed in as a second, one-dimensional input.
130 def __init__(self, number_features, hidden_feature_shape=[30, 30, 30, 30], activation=tanh, **kwargs):
152 super(EnhancedRelations, self).
__init__(**kwargs)
156 Build all weights for Relations Layer
157 :param input_shape: Input shape of tensor
161 assert(len(input_shape) == 2)
163 assert(len(input_shape[0]) == 3)
165 assert(len(input_shape[1]) == 2)
177 for i
in range(len(dense_shape[:-1])):
178 weights = self.add_weight(name=
'relation_weights_{}'.format(i),
179 shape=list(dense_shape[i:i + 2]), initializer=
'glorot_uniform', trainable=
True)
180 bias = self.add_weight(name=
'relation_weights_{}'.format(i),
181 shape=(dense_shape[i + 1],), initializer=
'zeros', trainable=
True)
185 super(EnhancedRelations, self).
build(input_shape)
189 Compute Relational Layer
190 :param inputs: input tensor
191 :return: output tensor
193 input_groups = [inputs[0][:, i, :]
for i
in range(self.
number_groups)]
194 questions = inputs[1]
196 for index, group1
in enumerate(input_groups[:-1]):
197 for group2
in input_groups[index + 1:]:
198 net = K.dot(K.concatenate([group1, group2, questions]), self.
variables[0][0])
199 net = K.bias_add(net, self.
variables[0][1])
202 net = K.dot(net, variables[0])
203 net = K.bias_add(net, variables[1])
204 outputs.append(sigmoid(net))
206 flat_result = K.concatenate(outputs)
212 :return: Output shape
215 assert(len(input_shape) == 2)
217 assert(len(input_shape[0]) == 3)
219 assert(len(input_shape[1]) == 2)
227 Config required for saving parameters in keras model.
232 'activation': activations.serialize(self.
activation)
234 base_config = super(EnhancedRelations, self).
get_config()
235 return dict(list(base_config.items()) + list(config.items()))