Belle II Software light-2406-ragdoll
Relations Class Reference

Public Member Functions

def __init__ (self, number_features, hidden_feature_shape=[30, 30, 30, 30], activation=tanh, **kwargs)
 
def build (self, input_shape)
 
def call (self, inputs)
 
def compute_output_shape (self, input_shape)
 
def get_config (self)
 

Public Attributes

 number_features
 Number of features.
 
 number_groups
 Number of groups in input.
 
 hidden_feature_shape
 Shape of the hidden layers used for extracting relations.
 
 activation
 Activation used for the hidden layers of the shared network.
 
 group_len
 Number of values (neurons) in each comparable object.
 
 weightvariables
 Weights and biases of the shared network, stored for use in call().
 
 combinations
 Number of relation combinations (pairs of groups).
 

Detailed Description

This class implements a relational layer for Keras.
The layer applies one network with shared weights to every combination of two feature groups.
Use it like any other Keras layer.
Relevant paper: https://arxiv.org/abs/1706.01427
RN(O) = f_phi(sum_{i,j} g_theta(o_i, o_j))
For flexibility, only the part g_theta(o_i, o_j) is modelled by this layer.
f_phi corresponds to an MLP, which is not part of this layer.
To aggregate over all pairs, apply GlobalAveragePooling1D from Keras to the output. It averages rather than sums, which differs from the sum only by the constant factor 1/combinations.
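
The structure described above can be sketched without Keras. This is an illustrative sketch in plain Python, not the layer's implementation: `g_theta` is a hypothetical stand-in for the shared network (here simply the dot product of the two groups), and `relation_network_mean` is a made-up helper that averages over all unordered pairs, mirroring what GlobalAveragePooling1D would do on the layer's output.

```python
from itertools import combinations


def g_theta(o_i, o_j):
    """Hypothetical stand-in for the shared network: dot product of two groups."""
    return sum(a * b for a, b in zip(o_i, o_j))


def relation_network_mean(objects):
    """Apply g_theta to every unordered pair (i < j) and average the results."""
    pair_outputs = [g_theta(o_i, o_j) for o_i, o_j in combinations(objects, 2)]
    return pair_outputs, sum(pair_outputs) / len(pair_outputs)


# Three groups with two values each give 3 * 2 / 2 = 3 pairs.
groups = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pair_outputs, pooled = relation_network_mean(groups)
```

In a real model the pooled value would then be passed to f_phi, for example a stack of Dense layers.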

Definition at line 18 of file keras_relational.py.

Constructor & Destructor Documentation

◆ __init__()

def __init__(self, number_features, hidden_feature_shape=[30, 30, 30, 30], activation=tanh, **kwargs)
Init class.

Definition at line 30 of file keras_relational.py.

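    def __init__(self, number_features, hidden_feature_shape=[30, 30, 30, 30], activation=tanh, **kwargs):
        """
        Init class.
        """
        # number of output features computed for each relation
        self.number_features = number_features
        # number of groups in input, set in build()
        self.number_groups = 0
        # shape of hidden layers used for extracting relations
        self.hidden_feature_shape = hidden_feature_shape
        # activation used for hidden layers; the output always uses sigmoid
        self.activation = activations.get(activation)
        # length of one comparable object, set in build()
        self.group_len = 0
        # weights and biases of the shared network, filled in build()
        self.weightvariables = []
        # number of relation combinations, set in build()
        self.combinations = 0

        super().__init__(**kwargs)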

Member Function Documentation

◆ build()

def build(self, input_shape)
Build all weights for the Relations layer.
:param input_shape: Input shape of tensor
:return: Nothing

Definition at line 52 of file keras_relational.py.

    def build(self, input_shape):
        """
        Build all weights for Relations Layer
        :param input_shape: Input shape of tensor
        :return: Nothing
        """
        # only accept 2D layers, i.e. a 3D input shape (batch, number_groups, group_len)
        assert(len(input_shape) == 3)

        self.number_groups = input_shape[1]
        self.group_len = input_shape[2]

        # number of unordered pairs of groups: n! / (2 * (n - 2)!) = n * (n - 1) / 2
        self.combinations = np.int32(np.math.factorial(self.number_groups) / (2 * np.math.factorial(self.number_groups - 2)))

        # layer sizes of the shared network: concatenated pair -> hidden layers -> number_features
        dense_shape = [2 * self.group_len] + self.hidden_feature_shape + [self.number_features]

        for i in range(len(dense_shape[:-1])):
            weights = self.add_weight(name=f'relation_weights_{i}',
                                      shape=list(dense_shape[i:i + 2]), initializer='glorot_uniform', trainable=True)
            # note: the bias is registered under the same name string as the weights
            bias = self.add_weight(name=f'relation_weights_{i}',
                                   shape=(dense_shape[i + 1],), initializer='zeros', trainable=True)

            self.weightvariables.append([weights, bias])

        super().build(input_shape)
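
The bookkeeping done in build() can be checked in isolation. The following is a minimal sketch in plain Python, assuming only the standard library; `count_pairs` and `layer_shapes` are hypothetical helper names that do not exist in keras_relational.py. It uses `math.comb`, which gives the same pair count as the factorial expression for two or more groups and does not depend on the `np.math` alias, which newer NumPy releases may no longer provide.

```python
import math


def count_pairs(number_groups):
    """Number of unordered pairs of groups, n * (n - 1) / 2."""
    return math.comb(number_groups, 2)


def layer_shapes(group_len, hidden_feature_shape, number_features):
    """Return (weight_shape, bias_shape) per dense layer of the shared network."""
    dense_shape = [2 * group_len] + list(hidden_feature_shape) + [number_features]
    return [((dense_shape[i], dense_shape[i + 1]), (dense_shape[i + 1],))
            for i in range(len(dense_shape) - 1)]


# Hypothetical input: 5 groups of 4 values, default hidden layers, 8 output features.
shapes = layer_shapes(group_len=4, hidden_feature_shape=[30, 30, 30, 30], number_features=8)
n_pairs = count_pairs(5)
```

One difference worth noting: `math.comb` returns 0 for fewer than two groups, whereas the factorial expression raises an error because the factorial of a negative number is undefined.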

◆ call()

def call(self, inputs)
Compute the relational layer.
:param inputs: input tensor
:return: output tensor

Definition at line 79 of file keras_relational.py.

    def call(self, inputs):
        """
        Compute Relational Layer
        :param inputs: input tensor
        :return: output tensor
        """
        # split the input into one tensor per group, each of shape (batch, group_len)
        input_groups = [inputs[:, i, :] for i in range(self.number_groups)]
        outputs = []
        # loop over all unordered pairs, group1 always preceding group2
        for index, group1 in enumerate(input_groups[:-1]):
            for group2 in input_groups[index + 1:]:
                # the first layer acts directly on the concatenated pair
                net = K.dot(K.concatenate([group1, group2]), self.weightvariables[0][0])
                net = K.bias_add(net, self.weightvariables[0][1])
                for variables in self.weightvariables[1:]:
                    net = self.activation(net)
                    net = K.dot(net, variables[0])
                    net = K.bias_add(net, variables[1])
                # the output of each pair always passes through a sigmoid
                outputs.append(sigmoid(net))

        flat_result = K.concatenate(outputs)
        return Reshape((self.combinations, self.number_features,))(flat_result)
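
To make the data flow of call() easier to follow, here is a sketch of the same computation for a single sample using plain Python lists instead of Keras tensors. It is an illustration under stated assumptions, not the Belle II implementation: `dense`, `relate_pair` and `relations_forward` are hypothetical helpers, and the all-zero weights are chosen only so that the result is easy to verify by hand.

```python
import math
from itertools import combinations


def dense(vector, weights, bias):
    """Compute vector @ weights + bias for plain Python lists."""
    return [sum(v * w_row[k] for v, w_row in zip(vector, weights)) + bias[k]
            for k in range(len(bias))]


def relate_pair(group1, group2, layers, activation=math.tanh):
    """Apply the shared network to one concatenated pair of groups."""
    weights, bias = layers[0]
    net = dense(group1 + group2, weights, bias)
    for weights, bias in layers[1:]:
        net = [activation(x) for x in net]  # hidden activation
        net = dense(net, weights, bias)
    return [1.0 / (1.0 + math.exp(-x)) for x in net]  # sigmoid on the output


def relations_forward(groups, layers):
    """Return one output vector per unordered pair, in the order (0,1), (0,2), ..."""
    return [relate_pair(g1, g2, layers) for g1, g2 in combinations(groups, 2)]


# Hypothetical all-zero weights: 2 * group_len = 4 inputs -> 3 hidden -> 2 features.
layers = [
    ([[0.0] * 3 for _ in range(4)], [0.0] * 3),
    ([[0.0] * 2 for _ in range(3)], [0.0] * 2),
]
groups = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
result = relations_forward(groups, layers)
```

With three groups the result holds three output vectors of length two, matching the (combinations, number_features) shape that the layer produces per sample.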

◆ compute_output_shape()

def compute_output_shape(self, input_shape)
Compute the output shape.
:param input_shape: Input shape of tensor
:return: Output shape

Definition at line 100 of file keras_relational.py.

    def compute_output_shape(self, input_shape):
        """
        Compute Output shape
        :return: Output shape
        """
        # only 2D layers, i.e. a 3D input shape (batch, number_groups, group_len)
        assert(len(input_shape) == 3)

        # uses self.number_groups as set in build(), not input_shape[1]
        self.combinations = np.int32(np.math.factorial(self.number_groups) / (2 * np.math.factorial(self.number_groups - 2)))

        return (input_shape[0], self.combinations, self.number_features)
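
The listed method reads number_groups from the layer state, which is 0 until build() has run. As a hedged alternative, the output shape can be derived from the input shape alone. The sketch below is standalone Python; `relations_output_shape` is a hypothetical function name and not part of the class.

```python
import math


def relations_output_shape(input_shape, number_features):
    """Output shape for an input of shape (batch, number_groups, group_len)."""
    if len(input_shape) != 3:
        raise ValueError("expected (batch, number_groups, group_len)")
    batch, number_groups, _group_len = input_shape
    return (batch, math.comb(number_groups, 2), number_features)


# Hypothetical example: unknown batch size, 6 groups of 10 values, 8 features.
shape = relations_output_shape((None, 6, 10), number_features=8)
```

Here six groups give 6 * 5 / 2 = 15 pairs, so the second axis of the output has length 15.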

◆ get_config()

def get_config(self)
Config required for saving parameters in a Keras model.

Definition at line 112 of file keras_relational.py.

    def get_config(self):
        """
        Config required for saving parameters in keras model.
        """
        # constructor arguments only; the activation is stored in serialized form
        config = {
            'number_features': self.number_features,
            'hidden_feature_shape': self.hidden_feature_shape,
            'activation': activations.serialize(self.activation)
        }
        base_config = super().get_config()
        # entries of config override equal keys in base_config
        return dict(list(base_config.items()) + list(config.items()))
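
The config contains only the constructor arguments; number_groups, group_len and combinations are recomputed in build(). The merge at the end of get_config() can be illustrated with plain dictionaries. This is a sketch with made-up values: `merge_config` is a hypothetical helper, and the entries of `base_config` merely stand in for whatever the Keras base class reports.

```python
def merge_config(base_config, layer_config):
    """Merge two config dicts; keys in layer_config win on conflict."""
    return {**base_config, **layer_config}


# Hypothetical values standing in for what Keras and the layer would report.
base_config = {"name": "relations", "trainable": True}
layer_config = {
    "number_features": 8,
    "hidden_feature_shape": [30, 30, 30, 30],
    "activation": "tanh",
}
config = merge_config(base_config, layer_config)
same_as_listing = dict(list(base_config.items()) + list(layer_config.items()))
```

Both spellings produce the same dictionary; the unpacking form is simply the more common idiom in current Python.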

Member Data Documentation

◆ activation

activation

Activation used for the hidden layers of the shared network.

The output always uses a sigmoid.

Definition at line 42 of file keras_relational.py.

◆ combinations

combinations

Number of relation combinations (pairs of groups).

Definition at line 48 of file keras_relational.py.

◆ group_len

group_len

Number of values (neurons) in each comparable object.

Definition at line 44 of file keras_relational.py.

◆ hidden_feature_shape

hidden_feature_shape

Shape of the hidden layers used for extracting relations.

Definition at line 40 of file keras_relational.py.

◆ number_features

number_features

Number of features.

Number of output features that the shared network computes for each relation.

Definition at line 36 of file keras_relational.py.

◆ number_groups

number_groups

Number of groups in input.

Definition at line 38 of file keras_relational.py.

◆ weightvariables

weightvariables

Weights and biases of the shared network, stored for use in call().

Definition at line 46 of file keras_relational.py.


The documentation for this class was generated from the following file: