import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
def _init_weights(layer, normalize):
    """
    Initializes the weights and biases.
    """
    for m in layer.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight.data)
            if not normalize:
                m.bias.data.fill_(0.1)
        elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1)
class EdgeLayer(nn.Module):
    """
    Updates edge features in MetaLayer:

    .. math::
        e_{ij}^{'} = \\phi^{e}(e_{ij}, v_{i}, v_{j}, u),

    where :math:`\\phi^{e}` is a neural network of the form

    .. figure:: figs/MLP_structure.png

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        efeat_hid_dim (int): Edge features dimension in hidden layers.
        efeat_out_dim (int): Edge features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (str): Type of normalization (batch/layer).

    :return: Updated edge features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html>`_
    """
    def __init__(self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, efeat_hid_dim,
                 efeat_out_dim, num_hid_layers, dropout, normalize=True):
        super().__init__()
        self.nonlin_function = F.elu  # non-linear activation (assumed ELU)
        self.num_hid_layers = num_hid_layers  # number of hidden layers
        self.dropout_prob = dropout  # dropout probability
        # Input linear layer: concatenated (edge, src node, dest node, global) features
        self.lin_in = nn.Linear(
            efeat_in_dim + 2 * nfeat_in_dim + gfeat_in_dim, efeat_hid_dim
        )
        # Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            nn.Linear(efeat_hid_dim, efeat_hid_dim) for _ in range(num_hid_layers)
        )
        # Output linear layer (bias is dropped when followed by normalization)
        self.lin_out = nn.Linear(efeat_hid_dim, efeat_out_dim, bias=not normalize)
        if normalize:
            self.norm = nn.BatchNorm1d(efeat_out_dim)
        _init_weights(self, normalize)
    def forward(self, src, dest, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - src, dest: [E, F_x], where E is the number of edges.
        - edge_attr: [E, F_e]
        - u: [B, F_u], where B is the number of graphs.
        - batch: [E] with max entry B - 1.
        """
        out = (
            torch.cat([edge_attr, src, dest, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([edge_attr, src, dest], dim=1)
        )
        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)
        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)
        out = self.lin_out(out)
        if hasattr(self, "norm"):
            out = self.norm(out)
        return out
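
# Usage sketch for EdgeLayer on its own (all dimensions below are assumed for
# illustration, not fixed by this module). Note that `batch` here maps each
# *edge* to its graph, which is what torch_geometric's MetaLayer passes in:
#
#   edge_model = EdgeLayer(nfeat_in_dim=8, efeat_in_dim=4, gfeat_in_dim=2,
#                          efeat_hid_dim=16, efeat_out_dim=4,
#                          num_hid_layers=2, dropout=0.1, normalize=True)
#   src = torch.randn(4, 8)                 # sender-node features, one row per edge
#   dest = torch.randn(4, 8)                # receiver-node features
#   edge_attr = torch.randn(4, 4)
#   u = torch.randn(2, 2)                   # one global vector per graph
#   batch = torch.tensor([0, 0, 1, 1])      # edge -> graph assignment
#   new_edge_attr = edge_model(src, dest, edge_attr, u, batch)   # shape [4, 4]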
class NodeLayer(nn.Module):
    """
    Updates node features in MetaLayer:

    .. math::
        v_{i}^{'} = \\phi^{v}(v_{i}, \\rho^{e \\to v}(v_{i}), u)

    with

    .. math::
        \\rho^{e \\to v}(v_{i}) = \\frac{\\sum_{j=1,\\ j \\neq i}^{N} (e_{ji} + e_{ij})}{2 \\cdot (N-1)},

    where :math:`\\phi^{v}` is a neural network of the form

    .. figure:: figs/MLP_structure.png

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        nfeat_hid_dim (int): Node features dimension in hidden layers.
        nfeat_out_dim (int): Node features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (str): Type of normalization (batch/layer).

    :return: Updated node features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html>`_
    """
    def __init__(self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, nfeat_hid_dim,
                 nfeat_out_dim, num_hid_layers, dropout, normalize=True):
        super().__init__()
        self.nonlin_function = F.elu  # non-linear activation (assumed ELU)
        self.num_hid_layers = num_hid_layers  # number of hidden layers
        self.dropout_prob = dropout  # dropout probability
        # Input linear layer: concatenated (node, aggregated edge, global) features
        self.lin_in = nn.Linear(
            gfeat_in_dim + nfeat_in_dim + efeat_in_dim, nfeat_hid_dim
        )
        # Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            nn.Linear(nfeat_hid_dim, nfeat_hid_dim) for _ in range(num_hid_layers)
        )
        # Output linear layer (bias is dropped when followed by normalization)
        self.lin_out = nn.Linear(nfeat_hid_dim, nfeat_out_dim, bias=not normalize)
        if normalize:
            self.norm = nn.BatchNorm1d(nfeat_out_dim)
        _init_weights(self, normalize)
    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u], where B is the number of graphs.
        - batch: [N] with max entry B - 1.
        """
        # Edge features are averaged onto the receiving nodes
        # (dim_size = N: number of nodes in the graph)
        out = scatter(
            edge_attr, edge_index[1], dim=0, dim_size=batch.size(0), reduce="mean"
        )
        out = (
            torch.cat([x, out, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([x, out], dim=1)
        )
        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)
        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)
        out = self.lin_out(out)
        if hasattr(self, "norm"):
            out = self.norm(out)
        return out
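
# Worked toy example of the edge-to-node aggregation used above (values are
# assumed, purely illustrative): each node receives the mean of its incoming
# edge features, and nodes with no incoming edges get zeros.
#
#   edge_attr = torch.tensor([[2.0], [4.0], [6.0]])
#   edge_index = torch.tensor([[1, 2, 2],
#                              [0, 0, 1]])            # receivers: 0, 0, 1
#   scatter(edge_attr, edge_index[1], dim=0, dim_size=3, reduce="mean")
#   # -> tensor([[3.], [6.], [0.]])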
class GlobalLayer(nn.Module):
    """
    Updates global features in MetaLayer:

    .. math::
        u_{i}^{'} = \\phi^{u}(\\rho^{e \\to u}(e), \\rho^{v \\to u}(v), u)

    with

    .. math::
        \\rho^{e \\to u}(e) = \\frac{\\sum_{i,j=1,\\ i \\neq j}^{N} e_{ij}}{N \\cdot (N-1)},\\\\
        \\rho^{v \\to u}(v) = \\frac{\\sum_{i=1}^{N} v_{i}}{N},

    where :math:`\\phi^{u}` is a neural network of the form

    .. figure:: figs/MLP_structure.png

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        gfeat_hid_dim (int): Global features dimension in hidden layers.
        gfeat_out_dim (int): Global features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (str): Type of normalization (batch/layer).

    :return: Updated global features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html>`_
    """
    def __init__(self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, gfeat_hid_dim,
                 gfeat_out_dim, num_hid_layers, dropout, normalize=True):
        super().__init__()
        self.nonlin_function = F.elu  # non-linear activation (assumed ELU)
        self.num_hid_layers = num_hid_layers  # number of hidden layers
        self.dropout_prob = dropout  # dropout probability
        # Input linear layer: concatenated (aggregated node, aggregated edge, global) features
        self.lin_in = nn.Linear(
            nfeat_in_dim + efeat_in_dim + gfeat_in_dim, gfeat_hid_dim
        )
        # Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            nn.Linear(gfeat_hid_dim, gfeat_hid_dim) for _ in range(num_hid_layers)
        )
        # Output linear layer (bias is dropped when followed by normalization)
        self.lin_out = nn.Linear(gfeat_hid_dim, gfeat_out_dim, bias=not normalize)
        if normalize:
            self.norm = nn.BatchNorm1d(gfeat_out_dim)
        _init_weights(self, normalize)
    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u], where B is the number of graphs.
        - batch: [N] with max entry B - 1.
        """
        # Node features are averaged over each graph
        node_mean = scatter(
            x, batch, dim=0, reduce="mean"
        )
        # Edge features are first averaged onto the receiving nodes
        # (dim_size keeps a row for nodes without incoming edges),
        # then averaged over each graph
        edge_mean = scatter(
            edge_attr, edge_index[1], dim=0, dim_size=batch.size(0), reduce="mean"
        )
        edge_mean = scatter(
            edge_mean, batch, dim=0, reduce="mean"
        )
        out = (
            torch.cat([u, node_mean, edge_mean], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([node_mean, edge_mean], dim=1)
        )
        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)
        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)
        out = self.lin_out(out)
        if hasattr(self, "norm"):
            out = self.norm(out)
        return out
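
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: wire the
    # three update blocks into torch_geometric's MetaLayer and push through a
    # toy batch of two graphs (5 nodes, 6 directed edges in total). All
    # dimensions are assumed; output sizes match input sizes so the layer
    # could be stacked.
    from torch_geometric.nn import MetaLayer

    model = MetaLayer(
        EdgeLayer(8, 4, 2, 16, 4, num_hid_layers=2, dropout=0.1, normalize=True),
        NodeLayer(8, 4, 2, 16, 8, num_hid_layers=2, dropout=0.1, normalize=True),
        GlobalLayer(8, 4, 2, 16, 2, num_hid_layers=2, dropout=0.1, normalize=True),
    )
    x = torch.randn(5, 8)                                # node features
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                               [1, 0, 2, 1, 4, 3]])      # directed edges
    edge_attr = torch.randn(6, 4)                        # edge features
    u = torch.randn(2, 2)                                # one global vector per graph
    batch = torch.tensor([0, 0, 0, 1, 1])                # node -> graph assignment
    x_new, edge_attr_new, u_new = model(x, edge_index, edge_attr, u, batch)
    print(x_new.shape, edge_attr_new.shape, u_new.shape)  # [5, 8] [6, 4] [2, 2]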