# Belle II Software (light-2403-persian)
# geometric_layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter


def _init_weights(layer, normalize):
    """
    Initializes the weights and biases.
    """
    for m in layer.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight.data)
            if not normalize:
                m.bias.data.fill_(0.1)
        elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EdgeLayer(nn.Module):
    """
    Updates edge features in MetaLayer:

    .. math::
        e_{ij}^{'} = \\phi^{e}(e_{ij}, v_{i}, v_{j}, u),

    where :math:`\\phi^{e}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        efeat_hid_dim (int): Edge features dimension in hidden layers.
        efeat_out_dim (int): Edge features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated edge features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        efeat_hid_dim,
        efeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(EdgeLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Normalization
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            efeat_in_dim + 2 * nfeat_in_dim + gfeat_in_dim, efeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(efeat_hid_dim, efeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(efeat_hid_dim, efeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(efeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, src, dest, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - src, dest: [E, F_x], where E is the number of edges.
        - edge_attr: [E, F_e]
        - u: [B, F_u], where B is the number of graphs.
        - batch: [E] with max entry B - 1.
        """
        out = (
            torch.cat([edge_attr, src, dest, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([edge_attr, src, dest], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.nonlin_function(self.lin_out(out))

        return out


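# A minimal smoke-test sketch for EdgeLayer (not part of the original module).
# All dimensions below are assumptions chosen for illustration; eval() mode is
# used so that dropout and batch normalization behave deterministically.
def _demo_edge_layer():
    E, F_x, F_e, F_u = 6, 4, 3, 2  # edge count and feature dimensions (assumed)
    layer = EdgeLayer(
        nfeat_in_dim=F_x,
        efeat_in_dim=F_e,
        gfeat_in_dim=F_u,
        efeat_hid_dim=16,
        efeat_out_dim=5,
        num_hid_layers=2,
        dropout=0.0,
    ).eval()
    src = torch.randn(E, F_x)  # features of the source node of each edge
    dest = torch.randn(E, F_x)  # features of the destination node of each edge
    edge_attr = torch.randn(E, F_e)
    u = torch.randn(1, F_u)  # one graph in the batch
    batch = torch.zeros(E, dtype=torch.long)  # every edge belongs to graph 0
    out = layer(src, dest, edge_attr, u, batch)
    assert out.shape == (E, 5)  # one updated feature vector per edge

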
class NodeLayer(nn.Module):
    """
    Updates node features in MetaLayer:

    .. math::
        v_{i}^{'} = \\phi^{v}(v_{i}, \\rho^{e \\to v}(v_{i}), u)

    with

    .. math::
        \\rho^{e \\to v}(v_{i}) = \\frac{\\sum_{j=1,\\ j \\neq i}^{N} (e_{ji} + e_{ij})}{2 \\cdot (N-1)},

    where :math:`\\phi^{v}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        nfeat_hid_dim (int): Node features dimension in hidden layers.
        nfeat_out_dim (int): Node features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated node features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        nfeat_hid_dim,
        nfeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(NodeLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Normalization
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            gfeat_in_dim + nfeat_in_dim + efeat_in_dim, nfeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(nfeat_hid_dim, nfeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(nfeat_hid_dim, nfeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(nfeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u]
        - batch: [N] with max entry B - 1.

        Edge features are averaged over the incoming edges of each node
        (dim_size = N, the number of nodes in the graph).
        """
        out = scatter(
            edge_attr, edge_index[1], dim=0, dim_size=batch.size(0), reduce="mean"
        )
        out = (
            torch.cat([x, out, u[batch]], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([x, out], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.nonlin_function(self.lin_out(out))

        return out


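# A minimal smoke-test sketch for NodeLayer (not part of the original module).
# The graph below is a fully connected directed triangle; all sizes are
# illustrative assumptions.
def _demo_node_layer():
    N, F_x, F_e, F_u = 3, 4, 3, 2  # node count and feature dimensions (assumed)
    layer = NodeLayer(
        nfeat_in_dim=F_x,
        efeat_in_dim=F_e,
        gfeat_in_dim=F_u,
        nfeat_hid_dim=16,
        nfeat_out_dim=6,
        num_hid_layers=2,
        dropout=0.0,
    ).eval()
    x = torch.randn(N, F_x)
    # Directed edges of a fully connected 3-node graph (no self-loops)
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                               [1, 2, 0, 2, 0, 1]])
    edge_attr = torch.randn(edge_index.size(1), F_e)
    u = torch.randn(1, F_u)
    batch = torch.zeros(N, dtype=torch.long)  # all nodes belong to graph 0
    out = layer(x, edge_index, edge_attr, u, batch)
    assert out.shape == (N, 6)  # one updated feature vector per node

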
class GlobalLayer(nn.Module):
    """
    Updates global features in MetaLayer:

    .. math::
        u_{i}^{'} = \\phi^{u}(\\rho^{e \\to u}(e), \\rho^{v \\to u}(v), u)

    with

    .. math::
        \\rho^{e \\to u}(e) = \\frac{\\sum_{i,j=1,\\ i \\neq j}^{N} e_{ij}}{N \\cdot (N-1)},\\\\
        \\rho^{v \\to u}(v) = \\frac{\\sum_{i=1}^{N} v_{i}}{N},

    where :math:`\\phi^{u}` is a neural network of the form

    .. figure:: figs/MLP_structure.png
        :width: 42em
        :align: center

    Args:
        nfeat_in_dim (int): Node features input dimension (number of node features in input).
        efeat_in_dim (int): Edge features input dimension (number of edge features in input).
        gfeat_in_dim (int): Global features input dimension (number of global features in input).
        gfeat_hid_dim (int): Global features dimension in hidden layers.
        gfeat_out_dim (int): Global features output dimension.
        num_hid_layers (int): Number of hidden layers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        normalize (bool): Whether to apply batch normalization to the output.

    :return: Updated global features tensor.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        gfeat_hid_dim,
        gfeat_out_dim,
        num_hid_layers,
        dropout,
        normalize=True,
    ):
        """
        Initialization.
        """
        super(GlobalLayer, self).__init__()

        #: Non-linear activation
        self.nonlin_function = F.elu
        #: Number of hidden layers
        self.num_hid_layers = num_hid_layers
        #: Dropout probability
        self.dropout_prob = dropout
        #: Normalization
        self.normalize = normalize

        #: Linear input layer
        self.lin_in = nn.Linear(
            nfeat_in_dim + efeat_in_dim + gfeat_in_dim, gfeat_hid_dim
        )
        #: Intermediate linear layers
        self.lins_hid = nn.ModuleList(
            [
                nn.Linear(gfeat_hid_dim, gfeat_hid_dim)
                for _ in range(self.num_hid_layers)
            ]
        )
        #: Output linear layer
        self.lin_out = nn.Linear(gfeat_hid_dim, gfeat_out_dim, bias=not normalize)

        if normalize:
            #: Batch normalization
            self.norm = nn.BatchNorm1d(gfeat_out_dim)

        _init_weights(self, normalize)

    def forward(self, x, edge_index, edge_attr, u, batch):
        """
        Called internally by PyTorch to propagate the input through the network.
        - x: [N, F_x], where N is the number of nodes.
        - edge_index: [2, E] with max entry N - 1.
        - edge_attr: [E, F_e]
        - u: [B, F_u]
        - batch: [N] with max entry B - 1.
        """
        # Nodes are averaged over each graph
        node_mean = scatter(x, batch, dim=0, reduce="mean")
        # Edges are averaged over nodes
        edge_mean = scatter(edge_attr, edge_index[1], dim=0, reduce="mean")
        # ... and then over each graph
        edge_mean = scatter(edge_mean, batch, dim=0, reduce="mean")
        out = (
            torch.cat([u, node_mean, edge_mean], dim=1)
            if u.shape != torch.Size([0])
            else torch.cat([node_mean, edge_mean], dim=1)
        )

        out = self.nonlin_function(self.lin_in(out))
        out = F.dropout(out, self.dropout_prob, training=self.training)

        out_skip = out

        for lin_hid in self.lins_hid:
            out = self.nonlin_function(lin_hid(out))
            out = F.dropout(out, self.dropout_prob, training=self.training)

        if self.num_hid_layers > 1:
            out += out_skip

        if self.normalize:
            out = self.nonlin_function(self.norm(self.lin_out(out)))
        else:
            out = self.nonlin_function(self.lin_out(out))

        return out
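

# A sketch of one full MetaLayer-style update (edge -> node -> global), in the
# order used by torch_geometric.nn.MetaLayer. The dimensions and the
# single-graph batch are assumptions for illustration only.
if __name__ == "__main__":
    N, F_x, F_e, F_u = 3, 4, 3, 2
    x = torch.randn(N, F_x)
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],
                               [1, 2, 0, 2, 0, 1]])
    edge_attr = torch.randn(edge_index.size(1), F_e)
    u = torch.randn(1, F_u)
    batch = torch.zeros(N, dtype=torch.long)  # node-to-graph assignment
    edge_batch = batch[edge_index[0]]  # edge-to-graph assignment

    edge_layer = EdgeLayer(F_x, F_e, F_u, 16, 8, 2, dropout=0.0).eval()
    node_layer = NodeLayer(F_x, 8, F_u, 16, 8, 2, dropout=0.0).eval()
    global_layer = GlobalLayer(8, 8, F_u, 16, 8, 2, dropout=0.0).eval()

    # Update edges from their endpoint features and the global vector, then
    # nodes from the aggregated updated edges, then the global vector itself.
    edge_attr = edge_layer(x[edge_index[0]], x[edge_index[1]], edge_attr, u, edge_batch)
    x = node_layer(x, edge_index, edge_attr, u, batch)
    u = global_layer(x, edge_index, edge_attr, u, batch)
    print(x.shape, edge_attr.shape, u.shape)  # -> [3, 8], [6, 8], [1, 8]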