Public Member Functions
__init__ (self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
forward (self, graph, feat, feat_glob=None)
Public Attributes
gat = dglnn.GATConv(in_feats, units, num_heads)
GAT layer that updates the node features.
fc = torch.nn.Linear(in_feats_glob + out_feats, units)
Fully connected layer that aggregates the current global features with the pooled node features to update the global features.
use_gap = use_gap
Whether to use Global Attention Pooling (GAP) to produce the global features.
gap_gate = torch.nn.Linear(out_feats, 1)
Linear gate that computes the per-node attention scores used by Global Attention Pooling.
gap = GlobalAttentionPooling(self.gap_gate)
Global Attention Pooling layer that produces the global features.
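The gate-and-pool step behind gap_gate and gap can be sketched in plain PyTorch for a single graph. This is a minimal illustration, not the DGL implementation: the function name global_attention_pool is hypothetical, and it takes one graph's node-feature tensor instead of a batched DGLGraph. It scores each node with the linear gate, normalizes the scores with a softmax over the nodes, and returns the attention-weighted sum of the node features.

```python
import torch


def global_attention_pool(node_feats: torch.Tensor, gate: torch.nn.Linear) -> torch.Tensor:
    """Hypothetical helper: pool the nodes of ONE graph into a single vector.

    node_feats: (num_nodes, out_feats) node features.
    gate: linear layer mapping out_feats -> 1 attention score per node.
    Returns a tensor of shape (out_feats,).
    """
    scores = gate(node_feats)                 # (num_nodes, 1) raw per-node scores
    weights = torch.softmax(scores, dim=0)    # normalize over the nodes of the graph
    return (weights * node_feats).sum(dim=0)  # attention-weighted sum of node features


torch.manual_seed(0)
out_feats = 8
gap_gate = torch.nn.Linear(out_feats, 1)
h = torch.randn(5, out_feats)  # 5 nodes with out_feats features each
g = global_attention_pool(h, gap_gate)
print(g.shape)
```

With a zero gate every node receives the same weight, so the result reduces to the mean of the node features, i.e. the global average pooling that the forward pass uses as the alternative to GAP.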
Apply a GAT layer to the node features, flatten the outputs of the attention heads, and update the global features.
__init__ (self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
Initialise the class.
:param in_feats: Number of features for each node.
:param units: Number of output units of the GAT layer (per attention head).
:param num_heads: Number of attention heads in the GAT layer.
:param in_feats_glob: Current dimension of the global features; 0 before any global features have been computed.
:param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.
Definition at line 25 of file gatgap.py.
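The constructor arguments fix the layer sizes. The hypothetical helper below tracks how they might chain across a stack of these modules. It assumes out_feats = units * num_heads (the flattened head outputs; out_feats is not a constructor parameter, so this is an inference), that every layer uses the same units and num_heads, and that each layer's flattened node features and fc output become the next layer's in_feats and in_feats_glob.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerDims:
    in_feats: int       # node feature size entering the layer
    in_feats_glob: int  # global feature size entering the layer (0 before any exist)
    out_feats: int      # flattened GAT output size, assumed to be units * num_heads
    fc_in: int          # input size of fc: in_feats_glob + out_feats
    glob_out: int       # output size of fc, i.e. the next layer's in_feats_glob


def stack_dims(in_feats: int, units: int, num_heads: int, num_layers: int) -> List[LayerDims]:
    """Hypothetical helper: feature sizes for a stack of identical GAT/GAP layers."""
    dims = []
    feats, feats_glob = in_feats, 0  # no global features before the first layer
    for _ in range(num_layers):
        out_feats = units * num_heads  # heads are flattened into one vector per node
        dims.append(LayerDims(feats, feats_glob, out_feats, feats_glob + out_feats, units))
        # Assumption: flattened node features and the fc output feed the next layer.
        feats, feats_glob = out_feats, units
    return dims


for d in stack_dims(in_feats=4, units=16, num_heads=2, num_layers=3):
    print(d)
```

Under these assumptions, the first layer's fc sees only the 32 pooled node features, while every later layer sees 16 + 32 = 48 inputs.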
forward (self, graph, feat, feat_glob=None)
Forward pass of the GAT module.
Arguments:
graph (dgl.DGLGraph): DGLGraph representing the decay tree.
feat (torch.Tensor): Node features attached to the graph.
feat_glob (torch.Tensor): Global features from previous layers, or `None` if none have been computed yet, in which case they are initialized as the global average pooling or global attention pooling of the whole graph.
Returns:
torch.Tensor: Updated node features.
torch.Tensor: Updated global features.
Definition at line 56 of file gatgap.py.
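The steps in the class description (flatten the heads, pool the nodes, update the global features) can be sketched for one unbatched graph. This is a sketch under stated assumptions, not the module's actual forward: update_features is a hypothetical name; it starts from an already computed GATConv-shaped output of size (num_nodes, num_heads, units) instead of calling dglnn.GATConv; it assumes fc is applied to the concatenation of feat_glob and the pooled features (consistent with fc's input size in_feats_glob + out_feats); and it applies no activation because none is documented.

```python
from typing import Optional, Tuple

import torch


def update_features(
    gat_out: torch.Tensor,
    fc: torch.nn.Linear,
    feat_glob: Optional[torch.Tensor] = None,
    gap_gate: Optional[torch.nn.Linear] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of the post-GAT steps for ONE (unbatched) graph.

    gat_out: (num_nodes, num_heads, units) GAT output.
    feat_glob: (1, in_feats_glob) previous global features, or None.
    """
    h = gat_out.flatten(start_dim=1)  # flatten heads: (num_nodes, num_heads * units)
    if gap_gate is None:
        pooled = h.mean(dim=0, keepdim=True)  # global average pooling: (1, out_feats)
    else:
        weights = torch.softmax(gap_gate(h), dim=0)  # attention weights over nodes
        pooled = (weights * h).sum(dim=0, keepdim=True)  # global attention pooling
    # Concatenate previous global features (if any) with the pooled node features.
    glob_in = pooled if feat_glob is None else torch.cat([feat_glob, pooled], dim=-1)
    return h, fc(glob_in)  # no activation: the original's choice is not documented


torch.manual_seed(0)
num_nodes, num_heads, units = 5, 2, 16
out_feats = num_heads * units
gat_out = torch.randn(num_nodes, num_heads, units)

# First layer: in_feats_glob = 0, so fc only sees the pooled node features.
fc1 = torch.nn.Linear(0 + out_feats, units)
h1, glob1 = update_features(gat_out, fc1)

# Later layer with GAP: fc sees the previous global features plus the pooled features.
fc2 = torch.nn.Linear(units + out_feats, units)
gate = torch.nn.Linear(out_feats, 1)
h2, glob2 = update_features(gat_out, fc2, feat_glob=glob1, gap_gate=gate)
print(h1.shape, glob1.shape, glob2.shape)
```

Setting in_feats_glob to 0 for the first layer lets the same fc construction cover both cases: with no previous global features, the concatenation is just the pooled vector.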