Source code for grafei.model.geometric_network

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################


import torch
from torch_geometric.nn import MetaLayer
from .geometric_layers import NodeLayer, EdgeLayer, GlobalLayer


class GraFEIModel(torch.nn.Module):
    """
    Actual implementation of the model,
    based on the
    `MetaLayer <https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html>`_
    class.

    .. seealso::
        `Relational inductive biases, deep learning, and graph networks <https://arxiv.org/abs/1806.01261>`_

    The network is composed of:

    1. A first MetaLayer to increase the number of node and edge features;
    2. A number of intermediate MetaLayers (tunable in config file);
    3. A last MetaLayer to decrease the number of node and edge features to the desired output dimension.

    .. figure:: figs/graFEI.png
        :width: 42em
        :align: center

    Each MetaLayer is in turn composed of `EdgeLayer`, `NodeLayer` and `GlobalLayer` sub-blocks.

    Args:
        nfeat_in_dim (int): Node features dimension (number of input node features).
        efeat_in_dim (int): Edge features dimension (number of input edge features).
        gfeat_in_dim (int): Global features dimension (number of input global features).
        edge_classes (int): Edge features output dimension (i.e. number of different edge labels in the LCAS matrix).
        x_classes (int): Node features output dimension (i.e. number of different mass hypotheses).
        hidden_layer_dim (int): Intermediate features dimension (same for node, edge and global).
        num_hid_layers (int): Number of hidden layers in every MetaLayer.
        num_ML (int): Number of intermediate MetaLayers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        global_layer (bool): Whether to use global layer.
        **kwargs: Accepted for configuration convenience and ignored.

    :return: Node, edge and global features after model evaluation.
    :rtype: tuple(`Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_)
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        edge_classes=6,
        x_classes=7,
        hidden_layer_dim=128,
        num_hid_layers=1,
        num_ML=1,
        dropout=0.0,
        global_layer=True,
        **kwargs
    ):
        """
        Initialization.

        Builds the input, intermediate and output MetaLayers. The submodule
        construction order is kept fixed so that parameter initialization
        (and hence reproducibility for a given seed) is stable.
        """
        super().__init__()

        #: First MetaLayer
        self.first_ML = MetaLayer(
            EdgeLayer(
                nfeat_in_dim,
                efeat_in_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            NodeLayer(
                nfeat_in_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            GlobalLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            )
            if global_layer
            else None,
        )

        #: Intermediate MetaLayers
        self.ML_list = torch.nn.ModuleList(
            [
                MetaLayer(
                    EdgeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim if global_layer else 0,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    NodeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim if global_layer else 0,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    GlobalLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    )
                    if global_layer
                    else None,
                )
                for _ in range(num_ML)
            ]
        )

        #: Output MetaLayer
        self.last_ML = MetaLayer(
            EdgeLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                hidden_layer_dim if global_layer else 0,
                hidden_layer_dim,
                edge_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            NodeLayer(
                hidden_layer_dim,
                edge_classes,
                hidden_layer_dim if global_layer else 0,
                hidden_layer_dim,
                x_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            GlobalLayer(
                x_classes,
                edge_classes,
                hidden_layer_dim,
                hidden_layer_dim,
                1,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            )
            if global_layer
            else None,
        )

    @staticmethod
    def _symmetrize_edges(edge_index, edge_attr, num_nodes):
        """
        Average every edge's features with those of its reverse edge.

        The reverse of each edge is located by value rather than by position,
        so the result does not depend on the ordering of ``edge_index``.

        Args:
            edge_index (Tensor): ``[2, E]`` integer tensor of (source, target) node indices.
            edge_attr (Tensor): ``[E, C]`` edge features, row ``e`` belonging to column ``e`` of ``edge_index``.
            num_nodes (int): Number of nodes; must be larger than every index in ``edge_index``.

        Returns:
            Tensor: ``[E, C]`` features in which edges ``(i, j)`` and ``(j, i)``
            both hold the mean of their two original rows (self-loops are unchanged).

        Raises:
            ValueError: If some edge ``(i, j)`` has no reverse edge ``(j, i)``.
        """
        if edge_index.numel() == 0:
            return edge_attr

        row, col = edge_index
        # Encode each (row, col) pair as a unique integer, then look up the
        # integer of the reversed pair among the sorted keys.
        keys = row * num_nodes + col
        reverse_keys = col * num_nodes + row
        sorted_keys, order = torch.sort(keys)
        # Clamp so that a key larger than all others still yields a valid index;
        # the equality check below then catches it as "not found".
        pos = torch.searchsorted(sorted_keys, reverse_keys).clamp(max=keys.numel() - 1)
        if not torch.equal(sorted_keys[pos], reverse_keys):
            raise ValueError(
                "Edge symmetrization requires the reverse of every edge in edge_index"
            )
        reverse = order[pos]
        return (edge_attr + edge_attr[reverse]) / 2.0

    def forward(self, batch):
        """
        Called internally by PyTorch to propagate the input through the network.

        Args:
            batch: Graph batch providing the ``x``, ``u``, ``edge_index``,
                ``edge_attr`` and ``batch`` attributes.

        Returns:
            tuple: ``(x, edge_attr, u)`` after the output MetaLayer. Edge outputs
            are symmetrized: edges ``(i, j)`` and ``(j, i)`` both carry the mean
            of their two predictions, so every edge must have its reverse in
            ``edge_index``.
        """
        x, u, edge_index, edge_attr, torch_batch = (
            batch.x,
            batch.u,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )

        x, edge_attr, u = self.first_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        for ML in self.ML_list:
            x_out, edge_out, u_out = ML(
                x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
            )
            # Skip connections are added out-of-place so that no tensor shared
            # with the caller is modified in place.
            x = x_out + x
            edge_attr = edge_out + edge_attr
            # MetaLayer passes u through untouched when it has no global model;
            # adding the skip there would only double u, so pass it through.
            u = u_out + u if ML.global_model is not None else u_out

        x, edge_attr, u = self.last_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        # Edge labels are symmetrized
        edge_attr = self._symmetrize_edges(edge_index, edge_attr, x.size(0))

        return x, edge_attr, u