Belle II Software development
GATModule Class Reference

Public Member Functions

 __init__ (self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
 
 forward (self, graph, feat, feat_glob=None)
 

Public Attributes

 gat = dglnn.GATConv(in_feats, units, num_heads)
 GAT layer to update node features.
 
 fc = torch.nn.Linear(in_feats_glob + out_feats, units)
 Fully connected layer for feature aggregation to update global features.
 
 use_gap = use_gap
 Whether to use Global Attention Pooling (GAP) for the production of global features.
 
 gap_gate = torch.nn.Linear(out_feats, 1)
 Linear gate that assigns each node an attention score for Global Attention Pooling.
 
 gap = GlobalAttentionPooling(self.gap_gate)
 Global Attention Pooling layer to produce global features.
 

Detailed Description

Apply a GAT layer to the node features, flatten the outputs of the attention heads, and update the global features.

Definition at line 19 of file gatgap.py.
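The head-flattening step can be illustrated without DGL or PyTorch. The sketch below is a hypothetical, pure-Python stand-in: nested lists play the role of the `(num_nodes, num_heads, units)` tensor returned by the GAT layer, and `flatten_heads` mimics `F.leaky_relu(...).flatten(1)`, which concatenates the heads of each node into one vector of length `num_heads * units`. The helper names are illustrative only and are not part of gatgap.py.

```python
# Hypothetical illustration: plain nested lists stand in for the
# (num_nodes, num_heads, units) tensor produced by a multi-head GAT layer.


def leaky_relu(x, negative_slope=0.01):
    """Element-wise LeakyReLU; 0.01 is the default slope of torch.nn.functional.leaky_relu."""
    return x if x >= 0 else negative_slope * x


def flatten_heads(gat_out):
    """Mimic F.leaky_relu(gat_out).flatten(1): shape (N, H, U) -> (N, H * U)."""
    flat = []
    for node in gat_out:              # one entry per node
        row = []
        for head in node:             # heads are concatenated in order
            row.extend(leaky_relu(v) for v in head)
        flat.append(row)
    return flat


# 2 nodes, 2 attention heads, 3 units per head
gat_out = [
    [[1.0, -2.0, 3.0], [0.5, 0.0, -1.0]],
    [[-4.0, 2.0, 1.0], [2.0, 2.0, 2.0]],
]
h = flatten_heads(gat_out)
print(len(h), len(h[0]))  # 2 6
```

Each node ends up with `num_heads * units` features, which is why the constructor sizes both `fc` and `gap_gate` with `out_feats = units * num_heads`.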

Constructor & Destructor Documentation

◆ __init__()

__init__ (self, in_feats, units, num_heads, in_feats_glob, use_gap = False)
Initialise the class.
:param in_feats: Number of features for each node.
:param units: Number of output units per attention head of the GAT layer; also the width of the updated global features.
:param num_heads: Number of attention heads in the GAT layer.
:param in_feats_glob: Current dimension of the global features. Initialized as 0, i.e. no global features exist before the first layer.
:param use_gap: Whether to use Global Attention Pooling (GAP) instead of mean pooling to produce the global features.

Definition at line 25 of file gatgap.py.

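def __init__(
    self,
    in_feats,
    units,
    num_heads,
    in_feats_glob,
    use_gap=False
):
    """
    Initialise the class.

    :param in_feats: Number of features for each node.
    :param units: Number of output units for the GAT layer.
    :param num_heads: Number of attention heads in the GAT layer.
    :param in_feats_glob: Current dimension of global features. Initialized as 0.
    :param use_gap: Whether to use Global Attention Pooling (GAP) for the production of global features.
    """
    super().__init__()

    # GAT layer to update node features
    self.gat = dglnn.GATConv(in_feats, units, num_heads)
    # Width of the node features once the attention heads are flattened
    out_feats = units * num_heads

    # Fully connected layer for feature aggregation to update global features
    self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)

    # Whether to use Global Attention Pooling (GAP) for the production of global features
    self.use_gap = use_gap
    if self.use_gap:
        # Linear gate that scores each node for the attention pooling
        self.gap_gate = torch.nn.Linear(out_feats, 1)
        # Global Attention Pooling layer to produce global features
        self.gap = GlobalAttentionPooling(self.gap_gate)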
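The constructor fixes the tensor widths that successive modules must agree on. The pure-Python sketch below computes them; `gat_module_dims` and `stack_dims` are hypothetical helpers, and the chaining rule in `stack_dims` (each module's flattened node output becomes the next module's `in_feats`, and its `units`-wide global output becomes the next `in_feats_glob`) is an assumption read off the layer shapes, not code taken from the Belle II model.

```python
def gat_module_dims(in_feats, units, num_heads, in_feats_glob):
    """Tensor widths implied by GATModule.__init__ (hypothetical helper)."""
    out_feats = units * num_heads            # flattened node features after the GAT layer
    return {
        "node_in": in_feats,                 # input width of self.gat
        "node_out": out_feats,               # width of the returned node features h
        "fc_in": in_feats_glob + out_feats,  # input width of self.fc
        "glob_out": units,                   # width of the returned global features hg
    }


def stack_dims(in_feats, units, num_heads, num_layers):
    """Chain modules under the assumed scheme: outputs of layer k feed layer k + 1."""
    dims = []
    node_feats, glob_feats = in_feats, 0     # first layer: no global features yet
    for _ in range(num_layers):
        d = gat_module_dims(node_feats, units, num_heads, glob_feats)
        dims.append(d)
        node_feats, glob_feats = d["node_out"], d["glob_out"]
    return dims


dims = stack_dims(in_feats=10, units=16, num_heads=4, num_layers=3)
print([d["fc_in"] for d in dims])  # [64, 80, 80]
```

Under this assumed chaining, only the first module is built with `in_feats_glob = 0`; every later module needs `in_feats_glob` equal to the `units` of its predecessor, otherwise `self.fc` would receive a concatenated tensor of the wrong width.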

Member Function Documentation

◆ forward()

forward (self, graph, feat, feat_glob = None)
Forward pass of the GAT module.
Arguments:
graph (dgl.DGLGraph): DGLGraph representing the decay tree.
feat (torch.Tensor): Node features attached to the graph.
feat_glob (torch.Tensor): Global features from previous layers. If `None`, the global features are initialized as the global average (or, with GAP, the attention pooling) of the node features over the whole graph.
Returns:
torch.Tensor: updated node features.
torch.Tensor: updated global features.

Definition at line 56 of file gatgap.py.

def forward(self, graph, feat, feat_glob=None):
    """
    Forward pass of the GAT module.

    Arguments:
        graph (dgl.DGLGraph): DGLGraph representing the decay tree.
        feat (torch.Tensor): Node features attached to the graph.
        feat_glob (torch.Tensor): Global features from previous layers.
            If `None`, they are initialized as the global average or attention pooling of the whole graph.

    Returns:
        torch.Tensor: updated node features.
        torch.Tensor: updated global features.
    """
    # Apply the GAT layer and merge the heads: (N, num_heads, units) -> (N, num_heads * units)
    h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
    hg = feat_glob
    # Pool the updated node features over each graph: mean pooling or GAP
    if not self.use_gap:
        with graph.local_scope():
            graph.ndata['h'] = h
            hmean = dgl.mean_nodes(graph, 'h')
    else:
        hmean = self.gap(graph, h)
    if hg is None:
        hg = hmean
    else:
        # Concatenate previous global features with new aggregation
        hg = torch.cat((hg, hmean), axis=1)
    # Update global features
    hg = F.leaky_relu(self.fc(hg))
    return h, hg
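The global-feature update in `forward` can be traced by hand. The sketch below is a hypothetical pure-Python version for a single graph with `use_gap=False`: `mean_nodes` plays the role of `dgl.mean_nodes` (which averages node features per graph, one row per graph in a batch), list concatenation plays the role of `torch.cat(..., axis=1)`, and `linear` follows the `torch.nn.Linear` convention of a weight shaped `(out_features, in_features)`. All names are illustrative, not part of gatgap.py.

```python
def leaky_relu(x, negative_slope=0.01):
    """Default slope of torch.nn.functional.leaky_relu."""
    return x if x >= 0 else negative_slope * x


def mean_nodes(h):
    """Average the node features of one graph, column by column."""
    n = len(h)
    return [sum(col) / n for col in zip(*h)]


def linear(x, weight, bias):
    """y = W x + b with weight shaped (out, in), as in torch.nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def update_global(h, hg, weight, bias):
    """Sketch of the global update in GATModule.forward for one graph (mean pooling)."""
    hmean = mean_nodes(h)
    hg = hmean if hg is None else hg + hmean   # concatenation of old globals and new pooling
    return [leaky_relu(v) for v in linear(hg, weight, bias)]


h = [[1.0, 2.0], [3.0, -2.0]]                    # 2 nodes, flattened width 2
first = update_global(h, None, [[1.0, 1.0]], [-5.0])           # fc_in = 0 + 2
later = update_global(h, [1.0], [[1.0, 0.5, 0.0]], [0.0])      # fc_in = 1 + 2
print(first, later)
```

In the second call the weight matrix has three columns because the one previous global feature is concatenated in front of the two pooled features, matching `in_feats_glob + out_feats` in the constructor.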

Member Data Documentation

◆ fc

fc = torch.nn.Linear(in_feats_glob + out_feats, units)

Fully connected layer for feature aggregation to update global features.

Definition at line 47 of file gatgap.py.

◆ gap

gap = GlobalAttentionPooling(self.gap_gate)

Global Attention Pooling layer to produce global features.

Definition at line 54 of file gatgap.py.
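DGL's `GlobalAttentionPooling(gate_nn, feat_nn=None)` computes, for each graph, a softmax over the nodes of the gate scores and returns the attention-weighted sum of the node features; since only `self.gap_gate` is passed here, the features themselves are not transformed. The sketch below reproduces that for one graph in plain Python, with a hypothetical `gate_weight`/`gate_bias` standing in for the `torch.nn.Linear(out_feats, 1)` gate.

```python
import math


def global_attention_pool(h, gate_weight, gate_bias):
    """Attention-weighted sum of node features for one graph (sketch, feat_nn=None)."""
    # Gate: linear map from each node's features to a single score
    scores = [sum(w * x for w, x in zip(gate_weight, node)) + gate_bias for node in h]
    # Softmax over the nodes, shifted by the max score for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    # Weighted sum of the unchanged node features
    return [sum(a * node[j] for a, node in zip(alphas, h)) for j in range(len(h[0]))]


h = [[1.0, 2.0], [3.0, -2.0]]
print(global_attention_pool(h, [0.0, 0.0], 0.0))    # [2.0, 0.0]
print(global_attention_pool(h, [0.0, 100.0], 0.0))  # close to [1.0, 2.0]
```

With a constant gate every node receives the same weight and GAP reduces to the mean pooling used when `use_gap=False`; a learned gate lets the module weight some nodes of the decay tree more strongly than others.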

◆ gap_gate

gap_gate = torch.nn.Linear(out_feats, 1)

Linear gate that assigns each node an attention score for Global Attention Pooling.

Definition at line 52 of file gatgap.py.

◆ gat

gat = dglnn.GATConv(in_feats, units, num_heads)

GAT layer to update node features.

Definition at line 44 of file gatgap.py.

◆ use_gap

use_gap = use_gap

Whether to use Global Attention Pooling (GAP) for the production of global features.

Definition at line 49 of file gatgap.py.


The documentation for this class was generated from the following file: