11from torch_geometric.nn
import MetaLayer
12from .geometric_layers
import NodeLayer, EdgeLayer, GlobalLayer
17 Actual implementation of the model,
19 `MetaLayer <https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html>`_
23 `Relational inductive biases, deep learning,
and graph networks <https://arxiv.org/abs/1806.01261>`_
25 The network
is composed of:
27 1. A first MetaLayer to increase the number of node
and edge features;
28 2. A number of intermediate MetaLayers (tunable
in config file);
29 3. A last MetaLayer to decrease the number of node
and edge features to the desired output dimension.
31 .. figure:: figs/graFEI.png
35 Each MetaLayer
is in turn composed of `EdgeLayer`, `NodeLayer`
and `GlobalLayer` sub-blocks.
38 nfeat_in_dim (int): Node features dimension (number of input node features).
39 efeat_in_dim (int): Edge features dimension (number of input edge features).
40 gfeat_in_dim (int): Global features dimension (number of input
global features).
41 edge_classes (int): Edge features output dimension (i.e. number of different edge labels
in the LCAS matrix).
42 x_classes (int): Node features output dimension (i.e. number of different mass hypotheses).
43 hidden_layer_dim (int): Intermediate feature dimension (shared by
node, edge and global features).
44 num_hid_layers (int): Number of hidden layers
in every MetaLayer.
45 num_ML (int): Number of intermediate MetaLayers.
46 dropout (float): Dropout rate
:math:`r \\in [0,1]`.
47 global_layer (bool): Whether to use
the global layer.
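49 :return: Node, edge and global features after model evaluation.
50 :rtype: tuple(`Tensor <https://pytorch.org/docs/stable/tensors.html>`_, `Tensor <https://pytorch.org/docs/stable/tensors.html>`_, `Tensor <https://pytorch.org/docs/stable/tensors.html>`_)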
112 hidden_layer_dim
if global_layer
else 0,
121 hidden_layer_dim
if global_layer
else 0,
139 for _
in range(num_ML)
148 hidden_layer_dim
if global_layer
else 0,
158 hidden_layer_dim
if global_layer
else 0,
181 Called internally by PyTorch to propagate the input through the network.
183 x, u, edge_index, edge_attr, torch_batch = (
192 x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
199 edge_skip = edge_attr
202 x, edge_attr, u = ML(
203 x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
208 edge_attr += edge_skip
211 del x_skip, edge_skip, u_skip
213 x, edge_attr, u = self.
last_ML(
214 x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
218 edge_index_t = edge_index[[1, 0]]
220 for i
in range(edge_attr.shape[1]):
222 edge_matrix = torch.sparse_coo_tensor(
223 edge_index, edge_attr[:, i]
226 edge_matrix_t = torch.sparse_coo_tensor(
227 edge_index_t, edge_attr[:, i]
232 ((edge_matrix + edge_matrix_t) / 2.0).coalesce()
235 return x, edge_attr, u
def __init__(self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, edge_classes=6, x_classes=7, hidden_layer_dim=128, num_hid_layers=1, num_ML=1, dropout=0.0, global_layer=True, **kwargs)
ML_list
Intermediate MetaLayers.