Belle II Software development
GATModule Class Reference
Inheritance diagram for GATModule:

Public Member Functions

def __init__ (self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
 
def forward (self, graph, feat, feat_glob=None)
 

Public Attributes

 gat
 GAT layer to update node features.
 
 fc
 Fully connected layer for feature aggregation to update global features.
 
 use_gap
 Whether to use Global Attention Pooling (GAP) to produce the global features.
 
 gap_gate
 Linear gate to produce global features.
 
 gap
 Global Attention Pooling layer to produce global features.
 

Detailed Description

Apply a GAT layer to node features, flatten outputs of attention heads
and update global features.

Definition at line 19 of file gatgap.py.
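
A minimal usage sketch (not taken from the Belle II sources; the import path, the toy graph and the feature sizes are illustrative assumptions):

import dgl
import torch

from gatgap import GATModule  # assumed import path

# Toy decay tree: node 0 is the root, with parent -> child edges plus the
# reverse edges so that attention can also flow from children to parents.
src = torch.tensor([0, 0, 1, 1])
dst = torch.tensor([1, 2, 3, 4])
graph = dgl.add_reverse_edges(dgl.graph((src, dst), num_nodes=5))

node_feats = torch.randn(5, 10)  # 10 input features per node

# First block of a stack: no global features exist yet, so in_feats_glob=0.
module = GATModule(in_feats=10, units=16, num_heads=4, in_feats_glob=0)

h, hg = module(graph, node_feats)  # feat_glob=None on the first call
print(h.shape)   # torch.Size([5, 64]) -> units * num_heads features per node
print(hg.shape)  # torch.Size([1, 16]) -> one global vector of size units per graph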

Constructor & Destructor Documentation

◆ __init__()

def __init__ (   self,
  in_feats,
  units,
  num_heads,
  in_feats_glob,
  use_gap = False 
)
Initialise the class.

:param in_feats: Number of features for each node.
:param units: Number of output units for the GAT layer.
:param num_heads: Number of attention heads in the GAT layer.
:param in_feats_glob: Current dimension of global features. Initialized as 0.
:param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.

Definition at line 25 of file gatgap.py.

32 ):
33 """
34 Initialise the class.
35
36 :param in_feats: Number of features for each node.
37 :param units: Number of output units for the GAT layer.
38 :param num_heads: Number of attention heads in the GAT layer.
39 :param in_feats_glob: Current dimension of global features. Initialized as 0.
40 :param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.
41 """
42 super().__init__()
43
44 self.gat = dglnn.GATConv(in_feats, units, num_heads)
45 out_feats = units * num_heads
46
47 self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)
48
49 self.use_gap = use_gap
50 if self.use_gap:
51
52 self.gap_gate = torch.nn.Linear(out_feats, 1)
53
54 self.gap = GlobalAttentionPooling(self.gap_gate)
55
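
The constructor fixes the shape bookkeeping for a stack of these blocks: each block emits units * num_heads features per node and units global features. A follow-up block (a sketch with an assumed import path and illustrative sizes) would therefore be configured as:

from gatgap import GATModule  # assumed import path

units, num_heads = 16, 4

# First block: raw node features, no global features yet.
block1 = GATModule(in_feats=10, units=units, num_heads=num_heads, in_feats_glob=0)

# Second block: node input is the flattened multi-head output of block1,
# global input is the units-dimensional global vector produced by block1.
block2 = GATModule(in_feats=units * num_heads, units=units,
                   num_heads=num_heads, in_feats_glob=units)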

Member Function Documentation

◆ forward()

def forward (   self,
  graph,
  feat,
  feat_glob = None 
)
Forward pass of the GAT module.

Arguments:
    graph (dgl.DGLGraph): DGLGraph representing the decay tree.
    feat (torch.Tensor): Node features attached to the graph.
    feat_glob (torch.Tensor): Global features from previous layers.
    If `None`, they are initialized as the global average (or attention) pooling over the whole graph.

Returns:
    torch.Tensor: updated node features.
    torch.Tensor: updated global features.

Definition at line 56 of file gatgap.py.

56 def forward(self, graph, feat, feat_glob=None):
57 """
58 Forward pass of the GAT module.
59
60 Arguments:
61 graph (dgl.DGLGraph): DGLGraph representing the decay tree.
62 feat (torch.Tensor): Node features attached to the graph.
63 feat_glob (torch.Tensor): Global features from previous layers.
64 If `None`, they are initialized as the global average (or attention) pooling over the whole graph.
65
66 Returns:
67 torch.Tensor: updated node features.
68 torch.Tensor: updated global features.
69 """
70 h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
71 hg = feat_glob
72 if not self.use_gap:
73 with graph.local_scope():
74 graph.ndata['h'] = h
75 hmean = dgl.mean_nodes(graph, 'h')
76 else:
77 hmean = self.gap(graph, h)
78 if hg is None:
79 hg = hmean
80 else:
81 # Concatenate previous global features with new aggregation
82 hg = torch.cat((hg, hmean), axis=1)
83 # Update global features
84 hg = F.leaky_relu(self.fc(hg))
85 return h, hg
86
87
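
A sketch of the two aggregation paths on a batched graph (assumed import path, toy graphs and sizes): with use_gap=False the global update is driven by the plain node mean, with use_gap=True by a Global Attention Pooling readout; either way one global vector of size units is returned per graph in the batch.

import dgl
import torch

from gatgap import GATModule  # assumed import path

g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
g2 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
batch = dgl.batch([g1, g2])                 # 2 graphs, 5 nodes in total
feats = torch.randn(batch.num_nodes(), 10)  # 10 input features per node

mean_block = GATModule(10, 16, 4, in_feats_glob=0, use_gap=False)
gap_block = GATModule(10, 16, 4, in_feats_glob=0, use_gap=True)

h_mean, hg_mean = mean_block(batch, feats)  # hg_mean: (2, 16), mean-pooled
h_gap, hg_gap = gap_block(batch, feats)     # hg_gap:  (2, 16), attention-pooled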

Member Data Documentation

◆ fc

fc

Fully connected layer for feature aggregation to update global features.

Definition at line 47 of file gatgap.py.

◆ gap

gap

Global Attention Pooling layer to produce global features.

Definition at line 54 of file gatgap.py.

◆ gap_gate

gap_gate

Linear gate to produce global features.

Definition at line 52 of file gatgap.py.
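
Taken together with the gap member above, the two attributes form a standard DGL attention readout. The following sketch (with an illustrative feature size) mirrors what the constructor does: the linear gate maps each node's flattened feature vector to a single score, and GlobalAttentionPooling turns those scores into per-node weights for the per-graph readout.

import torch
from dgl.nn.pytorch.glob import GlobalAttentionPooling

out_feats = 64                        # units * num_heads in GATModule
gate = torch.nn.Linear(out_feats, 1)  # role of gap_gate: one score per node
pool = GlobalAttentionPooling(gate)   # role of gap: attention-weighted readout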

◆ gat

gat

GAT layer to update node features.

Definition at line 44 of file gatgap.py.

◆ use_gap

use_gap

Whether to use Global Attention Pooling (GAP) to produce the global features.

Definition at line 49 of file gatgap.py.


The documentation for this class was generated from the following file:
 gatgap.py