Belle II Software  release-08-01-10
gatgap.py
import torch
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
from dgl.nn.pytorch.glob import GlobalAttentionPooling
from smartBKG import TOKENIZE_DICT

NUM_PDG = len(TOKENIZE_DICT)


class GATModule(torch.nn.Module):
    """
    Apply a GAT layer to the node features, flatten the outputs of the attention heads
    and update the global features.
    """

    def __init__(
        self,
        in_feats,
        units,
        num_heads,
        in_feats_glob,
        use_gap=False,
    ):
        """
        Initialise the class.

        :param in_feats: dimension of the input node features
        :param units: output dimension of each attention head
        :param num_heads: number of attention heads of the GAT layer
        :param in_feats_glob: dimension of the incoming global features (0 for the first layer)
        :param use_gap: whether to use Global Attention Pooling (GAP) instead of
            average pooling for the node aggregation
        """
        super().__init__()

        self.gat = dglnn.GATConv(in_feats, units, num_heads)
        out_feats = units * num_heads

        self.fc = torch.nn.Linear(in_feats_glob + out_feats, units)

        self.use_gap = use_gap
        if self.use_gap:
            self.gap_gate = torch.nn.Linear(out_feats, 1)
            self.gap = GlobalAttentionPooling(self.gap_gate)
    def forward(self, graph, feat, feat_glob=None):
        """
        Apply the GAT layer to the node features and update the global features.

        :param graph: DGL graph of the decay event
        :param feat: node features
        :param feat_glob: global features of the previous layer (None for the first layer)
        :return: updated node features and updated global features
        """
        h = F.leaky_relu(self.gat(graph, feat)).flatten(1)
        hg = feat_glob
        if not self.use_gap:
            # average the node features over each graph in the batch
            with graph.local_scope():
                graph.ndata['h'] = h
                hmean = dgl.mean_nodes(graph, 'h')
        else:
            hmean = self.gap(graph, h)
        if hg is None:
            hg = hmean
        else:
            # concatenate previous state with new aggregation
            hg = torch.cat((hg, hmean), dim=1)
        # update global state
        hg = F.leaky_relu(self.fc(hg))
        return h, hg
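Taken on its own, a GATModule call maps per-node features to `units * num_heads` flattened attention-head outputs and aggregates them into a `units`-dimensional global state per graph. A minimal sketch on a hypothetical toy graph (the sizes below are illustrative and not taken from the actual training configuration):

toy_graph = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))  # self-loops avoid zero-in-degree errors in GATConv
toy_module = GATModule(in_feats=8, units=16, num_heads=4, in_feats_glob=0, use_gap=True)
node_feat = torch.randn(3, 8)             # 3 nodes with 8 features each
h, hg = toy_module(toy_graph, node_feat)  # first call: no global state yet
# h: (3, 16 * 4) flattened head outputs, hg: (1, 16) global state per graph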
class GATGAPModel(torch.nn.Module):
    """
    Input:
        dgl graph built from decay event

    Arguments:
        units (int): Number of units for the output dimension of the GAT convolutional layers
            as well as the dimension of the global features
        num_features (int): Number of features attached to each node (particle) as NN input
        num_pdg (int): Number of all possible PDG IDs
        emb_size (int): Dimension of the embedded PDG space
        attention_heads (int): Number of attention heads of the GAT convolutional layers
        n_layers (int): Number of GAT convolutional layers
        use_gap (bool): Whether to use Global Attention Pooling (GAP) or global average pooling

    Returns:
        logits (float): Logit indicating the probability of an event passing the
            corresponding skim; apply `sigmoid` to obtain a prediction
    """

    def __init__(
        self,
        units=128,
        num_features=8,
        num_pdg=NUM_PDG,
        emb_size=8,
        attention_heads=4,
        n_layers=5,
        use_gap=False
    ):
        """
        Initialise the class.

        :param units: number of units for the output dimension of the GAT layers
            as well as the dimension of the global features
        :param num_features: number of features attached to each node (particle)
        :param num_pdg: number of all possible PDG IDs
        :param emb_size: dimension of the embedded PDG space
        :param attention_heads: number of attention heads of the GAT layers
        :param n_layers: number of GAT layers
        :param use_gap: whether to use Global Attention Pooling (GAP) or global average pooling
        """
        super().__init__()

        self.pdg_embedding = torch.nn.Embedding(num_pdg + 1, emb_size)
        in_feats = num_features + emb_size

        self.gat_layers = torch.nn.ModuleList()
        in_feats_glob = 0
        for i in range(n_layers):
            self.gat_layers.append(
                GATModule(
                    in_feats=in_feats,
                    units=units,
                    num_heads=attention_heads,
                    in_feats_glob=in_feats_glob,
                    use_gap=use_gap
                )
            )
            in_feats = units * attention_heads
            in_feats_glob = units

        self.fc_output = torch.nn.Linear(units, 1)

    def forward(self, graph):
        """
        Embed the PDG IDs, concatenate them with the node features and pass
        the graph through the stack of GAT layers.

        :param graph: DGL graph built from a decay event with node data
            `x_pdg` (tokenised PDG IDs) and `x_feature` (node features)
        :return: logit for the event passing the corresponding skim
        """
        h_pdg = graph.ndata["x_pdg"]
        h_feat = graph.ndata["x_feature"]
        h_pdg = self.pdg_embedding(h_pdg.long())
        h = torch.cat((h_pdg, h_feat), dim=1)
        hg = None
        for layer in self.gat_layers:
            h, hg = layer(graph, h, hg)
        return self.fc_output(hg)
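A minimal end-to-end sketch on a hypothetical toy graph (in production the graph is built from a decay event, with the tokenised PDG IDs in `x_pdg` and the per-particle features in `x_feature`; the values below are made up for illustration):

num_nodes = 5
graph = dgl.add_self_loop(dgl.graph(([0, 0, 1, 1], [1, 2, 3, 4]), num_nodes=num_nodes))
graph.ndata["x_pdg"] = torch.randint(0, NUM_PDG, (num_nodes,))  # tokenised PDG IDs
graph.ndata["x_feature"] = torch.randn(num_nodes, 8)            # matches num_features=8

model = GATGAPModel(units=32, n_layers=2)
logits = model(graph)             # shape (1, 1): one logit per graph in the batch
prob = torch.sigmoid(logits)      # probability of the event passing the skim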