import torch
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
from dgl.nn.pytorch.glob import GlobalAttentionPooling

from smartBKG import TOKENIZE_DICT

NUM_PDG = len(TOKENIZE_DICT)
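NUM_PDG is used below as the vocabulary size of the PDG embedding. As a rough sketch of the intended role of TOKENIZE_DICT, assuming it maps raw PDG codes to integer tokens (the exact mapping is defined in smartBKG and the values here are illustrative only):

    # Illustrative only: the PDG codes and the (pdg -> token) mapping are assumptions.
    raw_pdg = [511, -511, 22]                               # example PDG codes from a decay
    tokens = torch.tensor([TOKENIZE_DICT[p] for p in raw_pdg])
    pdg_embedding = torch.nn.Embedding(NUM_PDG + 1, 8)      # one extra slot, e.g. for padding/unknown
    emb = pdg_embedding(tokens)                             # shape (3, 8)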
Apply a GAT layer to node features, flatten outputs of attention heads and update global features.
:param in_feats: Number of features for each node.
:param units: Number of output units for the GAT layer.
:param num_heads: Number of attention heads in the GAT layer.
:param in_feats_glob: Current dimension of global features. Initialized as 0.
:param use_gap: Whether to use Global Attention Pooling (GAP) for the production of global features.
self.gat = dglnn.GATConv(in_feats, units, num_heads)
out_feats = units * num_heads
self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)
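The lines above show only the non-GAP part of the constructor. A minimal sketch of how the optional GAP branch could be wired up, consistent with the `gap_gate` and `gap` members listed at the end of this section (the class name GATModule and the one-unit gate are assumptions made for illustration):

    class GATModule(torch.nn.Module):
        def __init__(self, in_feats, units, num_heads, in_feats_glob, use_gap=False):
            super().__init__()
            self.gat = dglnn.GATConv(in_feats, units, num_heads)
            out_feats = units * num_heads
            self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)
            self.use_gap = use_gap
            if use_gap:
                # gap_gate scores each node; GlobalAttentionPooling uses the scores
                # to aggregate node features into one graph-level vector.
                self.gap_gate = torch.nn.Linear(out_feats, 1)
                self.gap = GlobalAttentionPooling(self.gap_gate)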
def forward(self, graph, feat, feat_glob=None):

    Forward pass of the GAT module.

    Arguments:
        graph (dgl.DGLGraph): DGLGraph representing the decay tree.
        feat (torch.Tensor): Node features attached to the graph.
        feat_glob (torch.Tensor): Global features from previous layers.
            If `None`, the global features are initialized as the global average
            or attention pooling of the whole graph.

    Returns:
        torch.Tensor: updated node features.
        torch.Tensor: updated global features.
    h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
    hg = feat_glob
    if self.use_gap:
        # Global Attention Pooling of the updated node features.
        hmean = self.gap(graph, h)
    else:
        # Average the updated node features over each graph in the batch.
        with graph.local_scope():
            graph.ndata['h'] = h
            hmean = dgl.mean_nodes(graph, 'h')
    if hg is None:
        hg = hmean
    else:
        hg = torch.cat((hg, hmean), axis=1)
    hg = F.leaky_relu(self.fc(hg))
    return h, hg
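For orientation, a hypothetical call of this module on a toy graph; the class name GATModule, the graph, and all sizes are illustrative assumptions, not taken from the source:

    # Toy graph with self-loops so that GATConv sees no zero-in-degree nodes.
    g = dgl.add_self_loop(dgl.graph(([0, 0], [1, 2])))
    module = GATModule(in_feats=16, units=32, num_heads=4, in_feats_glob=0, use_gap=False)
    feat = torch.randn(3, 16)
    h, hg = module(g, feat)      # h: (3, 128) node features, hg: (1, 32) global features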
Input:
    dgl graph built from the decay event

Arguments:
    units (int): Number of units for the output dimension of GAT Convolutional layers
        as well as the dimension of global features.
    num_features (int): Number of features attached to each node or particle as NN input.
    num_pdg (int): Number of all possible PDG IDs.
    emb_size (int): Dimension of embedded PDG space.
    attention_heads (int): Number of attention heads for GAT Convolutional layers.
    n_layers (int): Number of GAT Convolutional layers.
    use_gap (bool): Whether to use Global Attention Pooling (GAP) for the production of global features.

Returns:
    logits (float): Logit indicating the probability of the event passing the corresponding skim;
        apply `sigmoid` to obtain the prediction.
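In other words, the raw output is a logit, so a prediction is obtained roughly as follows (`model` and `graph` are placeholder names for illustration):

    logits = model(graph)            # raw logit from the network
    prob = torch.sigmoid(logits)     # probability of the event passing the skim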
Initialise the class.

:param units: Number of units for the output dimension of GAT Convolutional layers
    as well as the dimension of global features.
:param num_features: Number of features attached to each node or particle as NN input.
:param num_pdg: Number of all possible PDG IDs.
:param emb_size: Dimension of embedded PDG space.
:param attention_heads: Number of attention heads for GAT Convolutional layers.
:param n_layers: Number of GAT Convolutional layers.
:param use_gap: Whether to use Global Attention Pooling (GAP) for the production of global features.
in_feats = num_features + emb_size
in_feats_glob = 0
self.gat_layers = torch.nn.ModuleList()
# One GAT block per layer; the per-layer block is sketched above as GATModule.
for i in range(n_layers):
    self.gat_layers.append(GATModule(
        in_feats=in_feats,
        units=units,
        num_heads=attention_heads,
        in_feats_glob=in_feats_glob,
        use_gap=use_gap,
    ))
    in_feats = units * attention_heads
    in_feats_glob = units
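To make the dimension bookkeeping concrete, with the default hyperparameters (units=128, num_features=8, emb_size=8, attention_heads=4): the first layer receives 8 + 8 = 16 node features and outputs 128 * 4 = 512 (the flattened attention heads), and every later layer maps 512 node features to 512. The global feature vector grows from dimension 0 to 128 after the first layer and stays at 128, since each layer's fully connected update maps (in_feats_glob + 512) inputs to units = 128 outputs.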
Forward pass of the GATGAPModel.

Arguments:
    graph (dgl.DGLGraph): DGLGraph representing the decay tree.

Returns:
    torch.Tensor: the final prediction with size 1.

h_pdg = graph.ndata["x_pdg"]
h_feat = graph.ndata["x_feature"]
h_pdg = self.pdg_embedding(h_pdg)
h = torch.cat((h_pdg, h_feat), axis=1)
hg = None
for layer in self.gat_layers:
    h, hg = layer(graph, h, hg)
return self.fc_output(hg)
GATGAPModel:
    def __init__(self, units=128, num_features=8, num_pdg=NUM_PDG, emb_size=8, attention_heads=4, n_layers=5, use_gap=False)
        pdg_embedding: Embedding layer for PDG IDs.
        gat_layers: List of GAT modules to update node features.
        fc_output: Output layer for the final prediction.

GAT module:
    def __init__(self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
        use_gap: Whether to use Global Attention Pooling (GAP) for the production of global features.
        fc: Fully connected layer for feature aggregation to update global features.
        gat: GAT layer to update node features.
        gap_gate: Linear gate used to produce global features.
        gap: Global Attention Pooling layer to produce global features.
    def forward(self, graph, feat, feat_glob=None)
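Finally, a minimal end-to-end sketch of how the model could be used for inference, with the definitions above in scope. The toy graph, the self-loop preprocessing, the node-data values and the shapes are assumptions made for illustration; only the node-data keys "x_pdg" and "x_feature" and the default hyperparameters come from the code above:

    import dgl
    import torch

    # Toy decay graph: one mother (node 0) with two daughters; self-loops are added
    # so that GATConv has no zero-in-degree nodes (an assumption about preprocessing).
    graph = dgl.add_self_loop(dgl.graph(([0, 0], [1, 2])))
    graph.ndata["x_pdg"] = torch.tensor([1, 2, 3])         # tokenized PDG indices (illustrative values)
    graph.ndata["x_feature"] = torch.randn(3, 8)           # num_features=8 features per particle

    model = GATGAPModel()                                  # default hyperparameters
    model.eval()
    with torch.no_grad():
        logits = model(graph)                              # shape (1, 1)
        prob = torch.sigmoid(logits)                       # probability of passing the skim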