Belle II Software development
geometric_layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter


def _init_weights(layer, normalize):
    """
    Initializes the weights and biases.
    """
    for m in layer.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight.data)
            if not normalize:
                m.bias.data.fill_(0.1)
        elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EdgeLayer(nn.Module):
    """
    Updates edge features in MetaLayer:

    .. math::
        e_{ij}^{'} = \\phi^{e}(e_{ij}, v_{i}, v_{j}, u),

    where :math:`\\phi^{e}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        efeat_hid_dim (int): Edge features dimension in hidden layers.
        efeat_out_dim (int): Edge features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated edge features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        efeat_hid_dim,
        efeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(EdgeLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Whether to apply batch normalization to the output
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            efeat_in_dim + 2 * nfeat_in_dim + gfeat_in_dim, efeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(efeat_hid_dim, efeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(efeat_hid_dim, efeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(efeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, src, dest, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - src, dest: [E, F_x], where E is the number of edges.
        - edge_attr: [E, F_e]
        - u: [B, F_u], where B is the number of graphs.
        - batch: [E] with max entry B - 1.
        """
        out = (
            torch.cat([edge_attr, src, dest, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([edge_attr, src, dest], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.nonlin_function(self.lin_out(out))

        return out


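# Illustrative sketch (not part of the original module): a standalone forward pass
# through EdgeLayer. The helper below and all dimensions/tensors are made-up
# assumptions for demonstration; note that ``batch`` is edge-level here (one graph
# index per edge) so that ``u[batch]`` broadcasts the global features to every edge.
def _example_edge_layer_forward():
    E, B = 8, 2  # assumed number of edges and graphs
    n_dim, e_dim, g_dim = 4, 3, 2  # assumed feature dimensions
    layer = EdgeLayer(n_dim, e_dim, g_dim, 16, e_dim, num_hid_layers=2, dropout=0.1)
    src = torch.randn(E, n_dim)  # source-node features, [E, F_x]
    dest = torch.randn(E, n_dim)  # destination-node features, [E, F_x]
    edge_attr = torch.randn(E, e_dim)  # edge features, [E, F_e]
    u = torch.randn(B, g_dim)  # global features, [B, F_u]
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # graph index per edge, [E]
    return layer(src, dest, edge_attr, u, batch)  # updated edge features, [E, e_dim]

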
class NodeLayer(nn.Module):
    """
    Updates node features in MetaLayer:

    .. math::
        v_{i}^{'} = \\phi^{v}(v_{i}, \\rho^{e \\to v}(v_{i}), u)

    with

    .. math::
        \\rho^{e \\to v}(v_{i}) = \\frac{\\sum_{j=1,\\ j \\neq i}^{N} (e_{ji} + e_{ij})}{2 \\cdot (N-1)},

    where :math:`\\phi^{v}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        nfeat_hid_dim (int): Node features dimension in hidden layers.
        nfeat_out_dim (int): Node features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated node features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        nfeat_hid_dim,
        nfeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(NodeLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Whether to apply batch normalization to the output
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            gfeat_in_dim + nfeat_in_dim + efeat_in_dim, nfeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(nfeat_hid_dim, nfeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(nfeat_hid_dim, nfeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(nfeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u]
        - batch: [N] with max entry B - 1.

        Edge features are averaged over the incoming edges of each node
        (dim_size = N, the number of nodes in the graph).
        """
        out = scatter(
            edge_attr, edge_index[1], dim=0, dim_size=batch.size(0), reduce="mean"
        )
        out = (
            torch.cat([x, out, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([x, out], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.nonlin_function(self.lin_out(out))

        return out


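# Illustrative sketch (not part of the original module): the edge-to-node
# aggregation step used in NodeLayer.forward. For each node, the features of its
# incoming edges (edge_index[1] holds the target node) are averaged with
# torch_scatter's ``scatter``; the tiny graph and tensors below are made-up
# assumptions for demonstration.
def _example_edge_to_node_aggregation():
    # Three nodes, fully connected with directed edges (no self-loops).
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                               [1, 2, 0, 2, 0, 1]])
    edge_attr = torch.arange(12, dtype=torch.float).view(6, 2)  # [E, F_e]
    # Mean of incoming-edge features per target node -> [N, F_e].
    return scatter(edge_attr, edge_index[1], dim=0, dim_size=3, reduce="mean")

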
class GlobalLayer(nn.Module):
    """
    Updates global features in MetaLayer:

    .. math::
        u_{i}^{'} = \\phi^{u}(\\rho^{e \\to u}(e), \\rho^{v \\to u}(v), u)

    with

    .. math::
        \\rho^{e \\to u}(e) = \\frac{\\sum_{i,j=1,\\ i \\neq j}^{N} e_{ij}}{N \\cdot (N-1)},\\\\
        \\rho^{v \\to u}(v) = \\frac{\\sum_{i=1}^{N} v_{i}}{N},

    where :math:`\\phi^{u}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        gfeat_hid_dim (int): Global features dimension in hidden layers.
        gfeat_out_dim (int): Global features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated global features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        gfeat_hid_dim,
        gfeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(GlobalLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Whether to apply batch normalization to the output
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            nfeat_in_dim + efeat_in_dim + gfeat_in_dim, gfeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(gfeat_hid_dim, gfeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(gfeat_hid_dim, gfeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(gfeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u]
        - batch: [N] with max entry B - 1.
        """
        # Nodes are averaged over each graph
        node_mean = scatter(
            x, batch, dim=0, reduce="mean"
        )
        # Edges are averaged over nodes
        edge_mean = scatter(
            edge_attr, edge_index[1], dim=0, reduce="mean"
        )
        # Edges are averaged over each graph
        edge_mean = scatter(
            edge_mean, batch, dim=0, reduce="mean"
        )
        out = (
            torch.cat([u, node_mean, edge_mean], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([node_mean, edge_mean], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.lin_out(out)

        return out
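

# Illustrative sketch (not part of the original module): chaining the three layers
# into one graph-network block in the spirit of a MetaLayer update
# (edge update -> node update -> global update). The helper, the two small graphs
# and every dimension below are made-up assumptions; keeping output dimensions equal
# to input dimensions lets the block be applied repeatedly.
def _example_graph_network_block():
    n_dim, e_dim, g_dim, hid = 4, 3, 2, 16  # assumed feature/hidden dimensions

    edge_layer = EdgeLayer(n_dim, e_dim, g_dim, hid, e_dim, num_hid_layers=2, dropout=0.1)
    node_layer = NodeLayer(n_dim, e_dim, g_dim, hid, n_dim, num_hid_layers=2, dropout=0.1)
    global_layer = GlobalLayer(n_dim, e_dim, g_dim, hid, g_dim, num_hid_layers=2, dropout=0.1)

    # Two fully connected directed graphs with 3 and 2 nodes (5 nodes, 8 edges).
    x = torch.randn(5, n_dim)
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],
                               [1, 2, 0, 2, 0, 1, 4, 3]])
    edge_attr = torch.randn(8, e_dim)
    u = torch.randn(2, g_dim)
    batch = torch.tensor([0, 0, 0, 1, 1])  # graph index per node

    # Edge update takes per-edge node features and an edge-level batch vector.
    src, dest = x[edge_index[0]], x[edge_index[1]]
    edge_attr = edge_layer(src, dest, edge_attr, u, batch[edge_index[0]])  # [E, e_dim]
    x = node_layer(x, edge_index, edge_attr, u, batch)                     # [N, n_dim]
    u = global_layer(x, edge_index, edge_attr, u, batch)                   # [B, g_dim]
    return x, edge_attr, u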