import torch
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
from dgl.nn.pytorch.glob import GlobalAttentionPooling

from smartBKG import TOKENIZE_DICT

NUM_PDG = len(TOKENIZE_DICT)
Apply a GAT layer to node features, flatten the outputs of the attention heads
and update the global features (a shape sketch follows the ``forward`` method below).
:param num_heads: Number of attention heads for the GAT Convolutional layer.
:param in_feats_glob: Dimension of the incoming global features.
    # GAT convolution producing ``num_heads`` heads of ``units`` features each
    self.gat = dglnn.GATConv(in_feats, units, num_heads)
    out_feats = units * num_heads
    # linear layer that updates the global features from the pooled node features
    self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)
    # gate network for the pooling used when ``use_gap`` is enabled; the wrapping
    # into DGL's GlobalAttentionPooling as ``self.gap`` is an assumption here
    self.gap_gate = torch.nn.Linear(out_feats, 1)
    self.gap = GlobalAttentionPooling(self.gap_gate)
def forward(self, graph, feat, feat_glob=None):
    # apply the GAT layer and flatten the attention heads:
    # (N, num_heads, units) -> (N, num_heads * units)
    h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
    # Global Average Pooling: mean of the node feature 'h' over each graph
    with graph.local_scope():
        hmean = dgl.mean_nodes(graph, 'h')
    # Global Attention Pooling, used instead of the plain mean when ``use_gap`` is set
    hmean = self.gap(graph, h)
    # append the pooled node features to the global features and update them
    hg = torch.cat((hg, hmean), axis=1)
    hg = F.leaky_relu(self.fc(hg))
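# Not part of the original module: a self-contained sketch of the head flattening
# and of the two pooling options used in the ``forward`` method above. All sizes
# (6 nodes, 10 input features, 4 heads, 16 units) and the helper name are made up.
def _pooling_shape_sketch():
    g = dgl.add_self_loop(dgl.rand_graph(6, 12))   # self-loops avoid zero-in-degree errors in GATConv
    conv = dglnn.GATConv(in_feats=10, out_feats=16, num_heads=4)
    h = F.leaky_relu(conv(g, torch.randn(6, 10))).flatten(1)   # (6, 4, 16) -> (6, 64)

    # Global Average Pooling: mean of the node features, one row per graph
    with g.local_scope():
        g.ndata["h"] = h
        hmean = dgl.mean_nodes(g, "h")             # (1, 64)

    # Global Attention Pooling: a gate network scores each node for the weighted sum
    gap = GlobalAttentionPooling(torch.nn.Linear(64, 1))
    hgap = gap(g, h)                               # (1, 64)
    return h.shape, hmean.shape, hgap.shape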
DGL graph built from the decay event
units(int): Number of units for the output dimension of the GAT Convolutional layers
    as well as the dimension of the global features
num_features(int): Number of features attached to each node (particle) as NN input
num_pdg(int): Number of all possible PDG IDs
emb_size(int): Dimension of the embedded PDG space
attention_heads(int): Number of attention heads for the GAT Convolutional layers
n_layers(int): Number of GAT Convolutional layers
use_gap(bool): Whether to use Global Attention Pooling or Global Average Pooling
logits(float): Indicates the probability of the event passing the corresponding
    skim; a `sigmoid` has to be applied to use it as a prediction (see the usage
    sketch at the end of this listing)
Initialise the class.

:param num_features: Number of features attached to each node (particle) as NN input.
:param emb_size: Dimension of the embedded PDG space.
:param attention_heads: Number of attention heads for the GAT Convolutional layers.
:param n_layers: Number of GAT Convolutional layers.
    # node input size: per-particle features plus the PDG embedding
    in_feats = num_features + emb_size
    for i in range(n_layers):
        # fragment of the per-layer GAT module construction
            num_heads=attention_heads,
            in_feats_glob=in_feats_glob,
        # sizes for the next layer: flattened attention heads for the nodes,
        # ``units`` for the global features
        in_feats = units * attention_heads
        in_feats_glob = units
    h_pdg = graph.ndata["x_pdg"]
    h_feat = graph.ndata["x_feature"]
    # concatenate the PDG information with the other per-particle features as node input
    h = torch.cat((h_pdg, h_feat), axis=1)
    # each GAT module updates the node features ``h`` and the global features ``hg``
    h, hg = layer(graph, h, hg)
def __init__(self, units=128, num_features=8, num_pdg=NUM_PDG, emb_size=8, attention_heads=4, n_layers=5, use_gap=False)
def __init__(self, in_feats, units, num_heads, in_feats_glob, use_gap=False)
def forward(self, graph, feat, feat_glob=None)
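# Not part of the original module: a usage sketch for the model class defined above.
# Its name is not shown in this excerpt, so it is passed in as ``model_cls``. The
# node-data keys ("x_pdg", "x_feature") and the constructor defaults follow the
# listing; the toy graph and random inputs are made up, and the elided parts of the
# class are assumed to embed the PDG tokens and to return one logit per graph.
def _usage_sketch(model_cls):
    graph = dgl.add_self_loop(dgl.rand_graph(10, 30))
    graph.ndata["x_pdg"] = torch.randint(0, NUM_PDG, (10,))    # tokenised PDG IDs
    graph.ndata["x_feature"] = torch.randn(10, 8)              # num_features=8 inputs per node

    model = model_cls()    # defaults: units=128, attention_heads=4, n_layers=5, use_gap=False
    logits = model(graph)
    # the model returns logits; apply a sigmoid to read them as the probability
    # that the event passes the corresponding skim
    return torch.sigmoid(logits)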