Belle II Software development
geometric_network.py
1
8
9
10import torch
11from torch_geometric.nn import MetaLayer
12from .geometric_layers import NodeLayer, EdgeLayer, GlobalLayer
13
14
class GraFEIModel(torch.nn.Module):
    """
    Actual implementation of the model,
    based on the
    `MetaLayer <https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html>`_
    class.

    .. seealso::
        `Relational inductive biases, deep learning, and graph networks <https://arxiv.org/abs/1806.01261>`_

    The network is composed of:

        1. A first MetaLayer to increase the number of node and edge features;
        2. A number of intermediate MetaLayers (tunable in config file);
        3. A last MetaLayer to decrease the number of node and edge features to the desired output dimension.

    .. figure:: figs/graFEI.png
        :width: 42em
        :align: center

    Each MetaLayer is in turn composed of `EdgeLayer`, `NodeLayer` and `GlobalLayer` sub-blocks.

    Intermediate MetaLayers use skip connections (their input is added to their output).
    The skip connection on the global features is only applied when ``global_layer`` is True;
    otherwise the global features are passed through without being modified.

    After the last MetaLayer the edge features are symmetrized: edges ``(i, j)`` and ``(j, i)``
    end up with the average of their two feature vectors.

    Args:
        nfeat_in_dim (int): Node features dimension (number of input node features).
        efeat_in_dim (int): Edge features dimension (number of input edge features).
        gfeat_in_dim (int): Global features dimension (number of input global features).
        edge_classes (int): Edge features output dimension (i.e. number of different edge labels in the LCAS matrix).
        x_classes (int): Node features output dimension (i.e. number of different mass hypotheses).
        hidden_layer_dim (int): Intermediate features dimension (same for node, edge and global).
        num_hid_layers (int): Number of hidden layers in every MetaLayer.
        num_ML (int): Number of intermediate MetaLayers.
        dropout (float): Dropout rate :math:`r \\in [0,1]`.
        global_layer (bool): Whether to use global layer.

    :return: Node, edge and global features after model evaluation.
    :rtype: tuple(`Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_)
    """

    def __init__(
        self,
        nfeat_in_dim,
        efeat_in_dim,
        gfeat_in_dim,
        edge_classes=6,
        x_classes=7,
        hidden_layer_dim=128,
        num_hid_layers=1,
        num_ML=1,
        dropout=0.0,
        global_layer=True,
        **kwargs
    ):
        """
        Initialization.

        Extra keyword arguments (``**kwargs``) are accepted and ignored, so that a
        configuration dictionary containing additional keys can be unpacked directly.
        """
        super().__init__()

        ## Whether the MetaLayers contain a GlobalLayer (also decides the global skip connection)
        self.global_layer = global_layer

        # Global feature dimension seen by the edge/node layers after the first MetaLayer:
        # zero when no GlobalLayer produces hidden global features.
        hidden_global_dim = hidden_layer_dim if global_layer else 0

        ## First MetaLayer
        self.first_ML = MetaLayer(
            EdgeLayer(
                nfeat_in_dim,
                efeat_in_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            NodeLayer(
                nfeat_in_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            ),
            GlobalLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                gfeat_in_dim,
                hidden_layer_dim,
                hidden_layer_dim,
                num_hid_layers,
                dropout,
            )
            if global_layer
            else None,
        )

        ## Intermediate MetaLayers
        self.ML_list = torch.nn.ModuleList(
            [
                MetaLayer(
                    EdgeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_global_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    NodeLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_global_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    ),
                    GlobalLayer(
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        hidden_layer_dim,
                        num_hid_layers,
                        dropout,
                    )
                    if global_layer
                    else None,
                )
                for _ in range(num_ML)
            ]
        )

        ## Last MetaLayer
        self.last_ML = MetaLayer(
            EdgeLayer(
                hidden_layer_dim,
                hidden_layer_dim,
                hidden_global_dim,
                hidden_layer_dim,
                edge_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            NodeLayer(
                hidden_layer_dim,
                edge_classes,
                hidden_global_dim,
                hidden_layer_dim,
                x_classes,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            ),
            GlobalLayer(
                x_classes,
                edge_classes,
                hidden_layer_dim,
                hidden_layer_dim,
                1,
                num_hid_layers,
                dropout,
                normalize=False,  # Do not normalize output layer
            )
            if global_layer
            else None,
        )

    def forward(self, batch):
        """
        Called internally by PyTorch to propagate the input through the network.

        Args:
            batch: Object exposing ``x``, ``u``, ``edge_index``, ``edge_attr`` and ``batch``
                attributes (presumably a ``torch_geometric`` batch — confirm with caller).

        Returns:
            tuple: ``(x, edge_attr, u)`` produced by the last MetaLayer, with ``edge_attr``
            symmetrized over each pair of edges ``(i, j)`` / ``(j, i)``.
        """
        x, u, edge_index, edge_attr, torch_batch = (
            batch.x,
            batch.u,
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )

        x, edge_attr, u = self.first_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        for ML in self.ML_list:
            x_out, edge_out, u_out = ML(
                x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
            )

            # Skip connections are added out of place: ``+=`` would overwrite tensors
            # that autograd (or the caller, e.g. ``batch.u``) may still reference.
            x = x_out + x
            edge_attr = edge_out + edge_attr
            # Without a GlobalLayer the MetaLayer hands ``u`` back unchanged, so a skip
            # connection would only double the input global features at every layer.
            u = u_out + u if self.global_layer else u_out

        x, edge_attr, u = self.last_ML(
            x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
        )

        # Edge labels are symmetrized
        edge_attr = self._symmetrize_edge_attr(edge_index, edge_attr, x.size(0))

        return x, edge_attr, u

    @staticmethod
    def _symmetrize_edge_attr(edge_index, edge_attr, num_nodes):
        """
        Average every edge's features with those of its reverse edge.

        For each edge ``e = (i, j)`` the result is ``(A[i, j] + A[j, i]) / 2``, where ``A`` is
        the sparse matrix of edge features and a missing reverse edge counts as zero.
        The output stays aligned with the ordering of ``edge_index``, which therefore
        does not need to be sorted. Edges are assumed to be unique (no repeated ``(i, j)``).

        Args:
            edge_index (Tensor): Integer tensor of shape ``[2, num_edges]``.
            edge_attr (Tensor): Edge features whose first dimension is ``num_edges``.
            num_nodes (int): Number of nodes; every index in ``edge_index`` must be smaller.

        Returns:
            Tensor: Symmetrized edge features with the same shape as ``edge_attr``.
        """
        num_edges = edge_index.size(1)
        if num_edges == 0:
            return edge_attr

        # Encode each (row, col) pair and its transpose (col, row) as one integer key
        row, col = edge_index[0], edge_index[1]
        key = row * num_nodes + col
        key_t = col * num_nodes + row

        # Binary-search the position of every reverse edge among the sorted keys
        sorted_key, perm = torch.sort(key)
        pos = torch.searchsorted(sorted_key, key_t).clamp(max=num_edges - 1)
        has_reverse = sorted_key[pos] == key_t

        reverse_attr = edge_attr[perm[pos]]
        mask = has_reverse.reshape((-1,) + (1,) * (edge_attr.dim() - 1))
        reverse_attr = torch.where(mask, reverse_attr, torch.zeros_like(reverse_attr))

        return (edge_attr + reverse_attr) / 2.0
def __init__(self, nfeat_in_dim, efeat_in_dim, gfeat_in_dim, edge_classes=6, x_classes=7, hidden_layer_dim=128, num_hid_layers=1, num_ML=1, dropout=0.0, global_layer=True, **kwargs)
ML_list
Intermediate MetaLayers.