Belle II Software development
gatgap.py
import torch
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
from dgl.nn.pytorch.glob import GlobalAttentionPooling

from smartBKG import TOKENIZE_DICT

NUM_PDG = len(TOKENIZE_DICT)


class GATModule(torch.nn.Module):
    """
    Apply a GAT layer to node features, flatten the outputs of the attention heads
    and update the global features.
    """

    def __init__(
        self,
        in_feats,
        units,
        num_heads,
        in_feats_glob,
        use_gap=False,
    ):
        """
        Initialise the class.

        :param in_feats: Number of features for each node.
        :param units: Number of output units for the GAT layer.
        :param num_heads: Number of attention heads in the GAT layer.
        :param in_feats_glob: Current dimension of the global features. Initialised as 0.
        :param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.
        """
        super().__init__()

        #: GAT layer to update node features
        self.gat = dglnn.GATConv(in_feats, units, num_heads)
        out_feats = units * num_heads

        #: Fully connected layer for feature aggregation to update global features
        self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)

        #: Whether to use Global Attention Pooling (GAP) to produce the global features
        self.use_gap = use_gap
        if self.use_gap:
            #: Linear gate to produce global features
            self.gap_gate = torch.nn.Linear(out_feats, 1)
            #: Global Attention Pooling layer to produce global features
            self.gap = GlobalAttentionPooling(self.gap_gate)

    def forward(self, graph, feat, feat_glob=None):
        """
        Forward pass of the GAT module.

        Arguments:
            graph (dgl.DGLGraph): DGLGraph representing the decay tree.
            feat (torch.Tensor): Node features attached to the graph.
            feat_glob (torch.Tensor): Global features from previous layers.
                If `None`, the global features are initialised as the mean or attention
                pooling of the node features of the whole graph.

        Returns:
            torch.Tensor: Updated node features.
            torch.Tensor: Updated global features.
        """
        # Apply the GAT layer and flatten the outputs of the attention heads
        h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
        hg = feat_glob
        if not self.use_gap:
            # Aggregate node features by averaging over each graph
            with graph.local_scope():
                graph.ndata['h'] = h
                hmean = dgl.mean_nodes(graph, 'h')
        else:
            # Aggregate node features with Global Attention Pooling
            hmean = self.gap(graph, h)
        if hg is None:
            hg = hmean
        else:
            # Concatenate previous global features with new aggregation
            hg = torch.cat((hg, hmean), axis=1)
        # Update global features
        hg = F.leaky_relu(self.fc(hg))
        return h, hg
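
# --- Illustrative usage sketch (not part of the original gatgap.py) ----------
# A minimal sketch of how a single GATModule could be exercised on a toy graph,
# assuming the imports and the GATModule class above. The graph topology, the
# feature sizes and the helper name `_demo_gat_module` are made up for
# illustration only.
def _demo_gat_module():
    # Toy decay tree: 4 particles with mother-to-daughter edges; self-loops
    # avoid zero-in-degree errors in GATConv.
    graph = dgl.add_self_loop(dgl.graph(([0, 0, 1], [1, 2, 3]), num_nodes=4))
    feat = torch.randn(4, 16)  # 16 features per node

    # First layer of a stack: no global features yet, so in_feats_glob=0
    layer = GATModule(in_feats=16, units=8, num_heads=2, in_feats_glob=0)
    h, hg = layer(graph, feat, feat_glob=None)
    # h:  (4, units * num_heads) = (4, 16)  updated node features
    # hg: (1, units)             = (1, 8)   global features of this single graph
    return h.shape, hg.shape
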


class GATGAPModel(torch.nn.Module):
    """
    Input:
        DGL graph built from a decay event.

    Arguments:
        units (int): Number of output units of the GAT convolutional layers
            as well as the dimension of the global features.
        num_features (int): Number of features attached to each node (particle) as NN input.
        num_pdg (int): Number of all possible PDG IDs.
        emb_size (int): Dimension of the embedded PDG space.
        attention_heads (int): Number of attention heads for the GAT convolutional layers.
        n_layers (int): Number of GAT convolutional layers.
        use_gap (bool): Whether to use Global Attention Pooling (GAP) to produce the global features.

    Returns:
        logits (torch.Tensor): Logit indicating the probability that an event passes the
            corresponding skim; apply `sigmoid` to turn it into a prediction.
    """

    def __init__(
        self,
        units=128,
        num_features=8,
        num_pdg=NUM_PDG,
        emb_size=8,
        attention_heads=4,
        n_layers=5,
        use_gap=False
    ):
        """
        Initialise the class.

        :param units: Number of output units of the GAT convolutional layers
            as well as the dimension of the global features.
        :param num_features: Number of features attached to each node (particle) as NN input.
        :param num_pdg: Number of all possible PDG IDs.
        :param emb_size: Dimension of the embedded PDG space.
        :param attention_heads: Number of attention heads for the GAT convolutional layers.
        :param n_layers: Number of GAT convolutional layers.
        :param use_gap: Whether to use Global Attention Pooling (GAP) to produce the global features.
        """
        super().__init__()

        #: Embedding layer for PDG IDs
        self.pdg_embedding = torch.nn.Embedding(num_pdg + 1, emb_size)
        in_feats = num_features + emb_size

        #: List of GAT modules to update node features
        self.gat_layers = torch.nn.ModuleList()
        in_feats_glob = 0
        for i in range(n_layers):
            self.gat_layers.append(
                GATModule(
                    in_feats=in_feats,
                    units=units,
                    num_heads=attention_heads,
                    in_feats_glob=in_feats_glob,
                    use_gap=use_gap
                )
            )
            in_feats = units * attention_heads
            in_feats_glob = units

        #: Output layer for final prediction
        self.fc_output = torch.nn.Linear(units, 1)

    def forward(self, graph):
        """
        Forward pass of the GATGAPModel.

        Arguments:
            graph (dgl.DGLGraph): DGLGraph representing the decay tree.

        Returns:
            torch.Tensor: The final prediction logit of size 1 per graph.
        """
        # Embed the PDG IDs and concatenate them with the other node features
        h_pdg = graph.ndata["x_pdg"]
        h_feat = graph.ndata["x_feature"]
        h_pdg = self.pdg_embedding(h_pdg.long())
        h = torch.cat((h_pdg, h_feat), axis=1)
        # Pass node and global features through the stack of GAT modules
        hg = None
        for layer in self.gat_layers:
            h, hg = layer(graph, h, hg)
        # Project the final global features onto a single logit per graph
        return self.fc_output(hg)
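
# --- Illustrative end-to-end sketch (not part of the original gatgap.py) -----
# A minimal sketch of running GATGAPModel on a toy decay graph, assuming the
# imports and classes above. The node-data keys `x_pdg` and `x_feature` follow
# the usage in GATGAPModel.forward; the toy graph, `num_pdg=10` and the helper
# name `_demo_gatgap_model` are made up for illustration only.
def _demo_gatgap_model():
    # Toy decay tree with self-loops so GATConv sees no zero-in-degree nodes
    graph = dgl.add_self_loop(dgl.graph(([0, 0, 1, 1], [1, 2, 3, 4]), num_nodes=5))
    graph.ndata["x_pdg"] = torch.randint(0, 10, (5,))  # tokenised PDG IDs
    graph.ndata["x_feature"] = torch.randn(5, 8)       # 8 node features (num_features)

    model = GATGAPModel(units=32, num_features=8, num_pdg=10,
                        emb_size=8, attention_heads=2, n_layers=2)
    logit = model(graph)         # shape (1, 1): one logit per graph
    prob = torch.sigmoid(logit)  # probability of passing the skim
    return prob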