Belle II Software development
Public Member Functions

    __init__(self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
    forward(self, graph, feat, feat_glob=None)

Public Attributes

    gat = dglnn.GATConv(in_feats, units, num_heads)
        GAT layer to update node features.
    fc = torch.nn.Linear(in_feats_glob + out_feats, units)
        Fully connected layer that aggregates features to update the global features.
    use_gap = use_gap
        Whether to use Global Attention Pooling (GAP) to produce the global features.
    gap_gate = torch.nn.Linear(out_feats, 1)
        Linear gate used to produce the global features.
    gap = GlobalAttentionPooling(self.gap_gate)
        Global Attention Pooling layer to produce the global features.
Apply a GAT layer to node features, flatten outputs of attention heads and update global features.
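The attention pooling mentioned above can be illustrated with a plain-torch sketch of what a global-attention pool computes for a single graph: a linear gate scores each node, a softmax over the nodes turns the scores into attention weights, and the pooled global feature is the weighted sum of node features. The function and variable names below are hypothetical, not from gatgap.py.

```python
import torch

def global_attention_pool(node_feats, gate):
    """Soft-attention pooling over the nodes of a single graph.

    A minimal sketch of the computation behind global attention pooling:
    score each node with a linear gate, softmax across nodes, then take
    the attention-weighted sum of the node features.
    """
    weights = torch.softmax(gate(node_feats), dim=0)  # (num_nodes, 1)
    return (weights * node_feats).sum(dim=0)          # (feat_dim,)

torch.manual_seed(0)
node_feats = torch.randn(5, 8)   # 5 nodes, 8 features each
gate = torch.nn.Linear(8, 1)     # plays the role of gap_gate
pooled = global_attention_pool(node_feats, gate)
print(tuple(pooled.shape))       # (8,)
```

Because the weights come from a softmax, they are non-negative and sum to one, so the pooled vector is a convex combination of the node features.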
__init__(self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
Initialise the class.

:param in_feats: Number of features for each node.
:param units: Number of output units of the GAT layer.
:param num_heads: Number of attention heads in the GAT layer.
:param in_feats_glob: Current dimension of the global features; initialised as 0.
:param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.
Definition at line 25 of file gatgap.py.
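The shape bookkeeping implied by the attributes above can be sketched in plain torch. The concrete numbers are hypothetical, and `out_feats = units * num_heads` is an assumption based on the class description saying the attention-head outputs are flattened; the dgl GATConv itself is omitted.

```python
import torch

# Hypothetical dimensions; out_feats assumes the GAT head outputs
# are flattened, i.e. out_feats = units * num_heads.
in_feats, units, num_heads, in_feats_glob = 16, 32, 4, 0
out_feats = units * num_heads  # 128 after flattening the heads

# The two torch layers created in __init__ (sketch only):
fc = torch.nn.Linear(in_feats_glob + out_feats, units)
gap_gate = torch.nn.Linear(out_feats, 1)
print(fc.in_features, fc.out_features)  # 128 32
```

Note that `fc` consumes the previous global features concatenated with the pooled node features, which is why its input width is `in_feats_glob + out_feats`.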
forward(self, graph, feat, feat_glob=None)
Forward pass of the GAT module.

Arguments:
    graph (dgl.DGLGraph): DGLGraph representing the decay tree.
    feat (torch.Tensor): Node features attached to the graph.
    feat_glob (torch.Tensor): Global features from previous layers.
        If `None`, they are initialised as the global average (or attention) pooling of the whole graph.

Returns:
    torch.Tensor: Updated node features.
    torch.Tensor: Updated global features.
Definition at line 56 of file gatgap.py.
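The global-feature update in the forward pass can be sketched in plain torch: pool the (flattened) GAT node output over the graph, concatenate it with the incoming global features, and pass the result through the fully connected layer. All names and dimensions here are hypothetical, average pooling stands in for the `use_gap=False` case, and the dgl graph machinery is omitted.

```python
import torch

units, num_heads, in_feats_glob = 8, 2, 4
out_feats = units * num_heads  # width of the flattened GAT output

fc = torch.nn.Linear(in_feats_glob + out_feats, units)

node_feats = torch.randn(5, out_feats)        # GAT output, heads flattened
feat_glob = torch.randn(1, in_feats_glob)     # global features so far
pooled = node_feats.mean(dim=0, keepdim=True) # average pooling (use_gap=False)
new_glob = fc(torch.cat([feat_glob, pooled], dim=1))
print(tuple(new_glob.shape))                  # (1, 8)
```

With `use_gap=True`, the mean over nodes would be replaced by the attention pooling produced by `gap`; the concatenation and `fc` step stay the same.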