import torch
import torch.nn as nn
import torch.nn.functional as F
def manual_knn(x, k):
    """
    Compute k-nearest neighbours for each node using a full pairwise distance matrix.

    For each node i, returns the indices of its k nearest neighbours (including itself
    at distance 0). Does not support batched graphs - assumes a single graph with no
    batch index.

    :param x: Node feature tensor of shape [N, D].
    :param k: Number of nearest neighbours to retrieve per node. May be a Python
        int or a 0-dim integer tensor (as produced by ``torch.minimum`` in
        ``GravNetConv.forward``).
    :return: Edge index tensor of shape [2, N*k]. Row 0 holds each query node's
        own index repeated k times (sorted ascending), row 1 the indices of its
        k nearest neighbours ordered by increasing distance.
    """
    # Dense [N, N] Euclidean distance matrix. O(N^2) memory, but avoids
    # scatter/sparse ops so the graph stays ONNX-exportable.
    d = torch.cdist(x, x)
    # k smallest distances per row; each row is sorted ascending, so node i
    # itself (distance 0) is always its own first neighbour.
    i = d.topk(k, largest=False).indices.reshape(-1)
    # Query index repeated k times: [0,...,0, 1,...,1, ...]. Pure torch
    # (no numpy) keeps export compatibility and the correct device.
    j = torch.arange(x.shape[0], device=x.device)
    j = j.unsqueeze(1).expand(x.shape[0], k).reshape(-1)
    return torch.stack([j, i], dim=0)
class BatchNorm(nn.Module):
    """
    Thin wrapper around nn.BatchNorm1d to match the interface of
    ``torch_geometric.nn.norm.BatchNorm``, which accepts a single tensor argument.

    This allows ``BatchNorm`` to be used interchangeably with the PyG version
    without requiring ``torch_geometric`` as a dependency at inference time.

    Reference:
        https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/norm/batch_norm.html#BatchNorm
    """

    def __init__(self, input_dim, momentum):
        """
        :param input_dim: Number of features (channels) of the input tensor.
        :param momentum: Momentum for the running mean/variance update in ``BatchNorm1d``.
        """
        super().__init__()
        self.module = nn.BatchNorm1d(input_dim, momentum=momentum)

    def forward(self, x):
        """
        :param x: Input tensor of shape [N, input_dim].
        :return: Batch-normalised tensor of the same shape.
        """
        return self.module(x)
class GravNetConv(nn.Module):
    """
    Single GravNet convolutional layer.

    Each node learns a position in a latent *space* (via ``lin_s``) and a message
    vector (via ``lin_h``). Edges are formed by connecting every node to its k
    nearest neighbours in that latent space. Edge weights decay exponentially
    with squared latent-space distance, so geometrically close nodes exchange
    stronger messages. Aggregated messages (``mean`` and ``max``) are concatenated
    and projected to the output dimension, with a residual path from the input.

    This implementation replaces the scatter-based aggregation of the original
    PyG ``GravNetConv`` with a reshape-based aggregation that is compatible with
    ONNX export and pure-PyTorch inference:
    - Every node has exactly k neighbours (uniform neighbourhood size).
    - The KNN result is sorted by receiving node index, so messages can be
      reshaped directly to [N, k, propagate_dimensions] without scatter.

    References:
        Learning Representations of Irregular Particle-detector Geometry
        with Distance-weighted Graph Networks
        https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gravnet_conv.html#GravNetConv
    """

    def __init__(self, in_channels=128, out_channels=128, space_dimensions=3, propagate_dimensions=8, k=4):
        """
        :param in_channels: Number of input features per node.
        :param out_channels: Number of output features per node.
        :param space_dimensions: Dimensionality of the learned latent space used for KNN.
        :param propagate_dimensions: Dimensionality of the message vectors exchanged along edges.
        :param k: Number of nearest neighbours per node.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k

        self.lin_s = nn.Linear(in_channels, space_dimensions)
        self.lin_h = nn.Linear(in_channels, propagate_dimensions)
        # Residual projection of the input. Bias omitted: lin_out2 already
        # contributes a bias to the summed output.
        self.lin_out1 = nn.Linear(in_channels, out_channels, bias=False)
        # Projects the concatenated [mean, max] aggregations
        # (2 * propagate_dimensions features) to the output dimension.
        self.lin_out2 = nn.Linear(2 * propagate_dimensions, out_channels)

    def forward(self, x):
        """
        :param x: Node feature tensor of shape [N, in_channels].
        :return: Updated node feature tensor of shape [N, out_channels].
        """
        # Clamp k to the number of nodes so tiny graphs still work. Kept as a
        # tensor op (not Python min) to remain traceable for ONNX export.
        k = torch.minimum(torch.tensor(x.size(0)), torch.tensor(self.k))

        # Latent coordinates used for neighbour search and the per-node
        # message vectors propagated along the resulting edges.
        s = self.lin_s(x)
        h = self.lin_h(x)
        ei, ej = manual_knn(s, k)

        # Squared latent-space distance per edge, converted to an
        # exponentially decaying weight (factor 10 as in the GravNet paper).
        edge_weight = (s[ei] - s[ej]).pow(2).sum(-1)
        edge_weight = torch.exp(-10. * edge_weight)
        msg = h[ej] * edge_weight.unsqueeze(1)

        # Edges come back sorted by node index with exactly k entries each,
        # so a plain reshape to [N, k, propagate_dimensions] replaces
        # scatter-based aggregation.
        out = torch.cat(
            [
                msg.reshape(-1, k, msg.shape[-1]).mean(dim=1),
                msg.reshape(-1, k, msg.shape[-1]).amax(dim=1),
            ],
            dim=1,
        )

        # Residual input projection plus projected aggregated messages.
        return self.lin_out1(x) + self.lin_out2(out)
138class CDCNet(nn.Module):
140 GNN for object condensation and track parameter regression on CDC hits.
142 Architecture overview
143 ---------------------
144 1. Input batch normalisation.
145 2. Global mean pooling is concatenated to node features before each block
146 (global exchange), giving every node access to event-level context.
147 3. A stack of ``nblocks`` GravNet blocks, each consisting of:
148 - Two linear + ELU layers,
150 - One linear + ELU layer,
151 - GravNetConv (learned latent-space KNN + distance-weighted message passing),
153 The output of each block is collected for skip connections.
154 4. All block outputs are concatenated and passed through a dense layer.
155 5. Five output heads predict:
156 - ``beta`` : condensation score in [0, 1] (sigmoid), shape [N, 1].
157 - ``ccoords``: latent clustering coordinates, shape [N, coord_dim].
158 - ``p`` : 3-momentum vector (px, py, pz), shape [N, 3].
159 - ``vertex`` : production vertex (vx, vy, vz), shape [N, 3].
160 - ``charge`` : particle charge probability in [0, 1] (sigmoid), shape [N, 1].
162 ONNX export constraints
163 -----------------------
164 Several design choices were made to allow the model to be exported to ONNX
165 and run in pure PyTorch without torch_geometric:
166 - Custom ``GravNetConv`` using ``manual_knn`` and reshape-based aggregation
167 instead of scatter-based aggregation.
168 - Batch size fixed to 1 (no batch index vector).
169 - Standard ``nn.BatchNorm1d`` and global mean pooling instead of PyG variants.
173 - Learning Representations of Irregular Particle-detector Geometry
174 with Distance-weighted Graph Networks (https://doi.org/10.1140/epjc/s10052-019-7113-9)
175 - End-to-End Multi-track Reconstruction Using Graph Neural Networks
176 at Belle II (https://doi.org/10.1007/s41781-025-00135-6)
191 :param input_dim: Number of input features per node.
192 :param k: Number of nearest neighbours for each GravNetConv layer.
193 :param dim1: Hidden feature dimension inside each GravNet block.
194 :param dim2: Output feature dimension of each GravNet block
195 (collected for skip connections).
196 :param nblocks: Number of GravNet blocks stacked in the network.
197 :param coord_dim: Dimensionality of the predicted clustering coordinates.
198 :param space_dimensions: Dimensionality of the latent space learned by GravNetConv
199 for KNN construction.
200 :param momentum: Momentum for all BatchNorm layers.
204 self.batch_norm_0 = BatchNorm(input_dim, momentum=0.6)
208 self.blocks = nn.ModuleList(
213 nn.Linear(2 * input_dim, dim1),
214 nn.Linear(dim1, dim1),
215 BatchNorm(dim1, momentum=momentum),
216 nn.Linear(dim1, dim1),
219 out_channels=dim1 * 2,
220 space_dimensions=space_dimensions,
222 propagate_dimensions=dim1,
224 BatchNorm(dim1 * 2, momentum=momentum),
225 nn.Linear(dim1 * 2, dim2),
240 nn.Linear(4 * dim1, dim1),
241 nn.Linear(dim1, dim1),
242 BatchNorm(dim1, momentum=momentum),
243 nn.Linear(dim1, dim1),
246 out_channels=dim1 * 2,
247 space_dimensions=space_dimensions,
249 propagate_dimensions=dim1,
252 BatchNorm(dim1 * 2, momentum=momentum),
253 nn.Linear(dim1 * 2, dim2),
256 for _
in range(nblocks - 1)
263 self.dense_cat = nn.Linear(dim2 * (nblocks), dim1)
266 self.p_beta_layer = nn.Linear(dim1, 1)
267 self.p_coords_layer = nn.Linear(dim1, coord_dim)
270 self.p_p_layer = nn.Linear(dim1, 3)
271 self.p_vertex_layer = nn.Linear(dim1, 3)
272 self.p_charge_layer = nn.Linear(dim1, 1)
274 def forward(self, x):
276 Forward pass through CDCNet.
278 :param x: Node feature tensor of shape [N, input_dim].
279 :return: Five-tuple of tensors, all of shape [N, *]:
280 - beta [N, 1]: condensation score in [0, 1].
281 - coords [N, coord_dim]: latent clustering coordinates.
282 - p [N, 3]: predicted 3-momentum (px, py, pz).
283 - vertex [N, 3]: predicted production vertex (vx, vy, vz).
284 - charge [N, 1]: charge probability in [0, 1].
288 x = self.batch_norm_0(x)
291 out = x.mean(axis=0, keepdim=
True)
292 x = torch.cat(torch.broadcast_tensors(x, out), dim=-1)
294 for i, block
in enumerate(self.blocks):
297 out = x.mean(axis=0, keepdim=
True)
298 x = torch.cat(torch.broadcast_tensors(x, out), dim=-1)
299 x = F.elu(block[0](x))
300 x = F.elu(block[1](x))
302 x = F.elu(block[3](x))
305 features.append(F.elu(block[6](x)))
308 x = torch.cat(features, dim=1)
309 x = F.elu(self.dense_cat(x))
312 p_beta = torch.sigmoid(self.p_beta_layer(x))
313 p_coords = self.p_coords_layer(x)
316 p_p = self.p_p_layer(x)
317 p_vertex = self.p_vertex_layer(x)
318 p_charge = torch.sigmoid(self.p_charge_layer(x))