def forward(self, graph, feat, feat_glob=None):
    """
    Forward pass of the GAT module.

    Arguments:
        graph (DGLGraph): DGLGraph representing the decay tree.
        feat (torch.Tensor): Node features attached to the graph.
        feat_glob (torch.Tensor, optional): Global features from previous
            layers. If ``None``, the global features are initialized from the
            pooled node features of the whole graph (global average pooling,
            or the ``self.gap`` attention pooling when ``self.use_gap`` is set).

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - updated node features.
            - updated global features.
    """
    # Graph attention layer. flatten(1) merges every dimension after the first
    # (presumably the attention heads) into a single feature axis per node.
    h = F.leaky_relu(self.gat(graph, feat)).flatten(1)

    # Graph-level readout of the updated node features.
    if self.use_gap:
        hmean = self.gap(graph, h)
    else:
        # local_scope() keeps the temporary 'h' node field from leaking into
        # the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = h
            hmean = dgl.mean_nodes(graph, 'h')

    # With no incoming global features the readout itself becomes the global
    # feature; otherwise it is appended along the feature dimension, so
    # self.fc must expect the concatenated width in that case.
    hg = hmean if feat_glob is None else torch.cat((feat_glob, hmean), dim=1)

    hg = F.leaky_relu(self.fc(hg))
    return h, hg
86
87