# Belle II Software prerelease-11-00-00a
# __init__.py
9import numpy as np
10import torch
11import torch.nn as nn
12import torch.nn.functional as F
13
14
15# \cond suppress doxygen warning
16
def manual_knn(x, k):
    """
    Compute k-nearest neighbours for each node using a full pairwise distance matrix.

    For each node i, returns the indices of its k nearest neighbours (including itself
    at distance 0). Does not support batched graphs - assumes a single graph with no
    batch index vector.

    :param x: Node feature tensor of shape [N, D].
    :param k: Number of nearest neighbours to retrieve per node.
    :return: Edge index tensor of shape [2, N*k]. Row 0 holds the receiving node
        index (each node repeated k times, in ascending order), row 1 the indices
        of its k nearest neighbours. The ascending ordering of row 0 is what allows
        the downstream reshape-based aggregation without a scatter operation.
    """
    # Full O(N^2) pairwise Euclidean distance matrix.
    d = torch.cdist(x, x)
    # Per-node k smallest distances; flattening row-by-row keeps the edges
    # grouped and sorted by receiving node.
    i = d.topk(k, largest=False).indices.reshape(-1)
    # Receiver indices: each node index repeated k times. Created on x's device
    # so the function also works for tensors living on a GPU.
    j = torch.arange(x.shape[0], device=x.device)
    j = torch.broadcast_to(j[:, None], (x.shape[0], k)).reshape(-1)
    return torch.stack([j, i], dim=0)
35
36
class BatchNorm(nn.Module):
    """
    Drop-in substitute for ``torch_geometric.nn.norm.BatchNorm``.

    Wraps a plain ``nn.BatchNorm1d`` behind a forward method that takes a single
    tensor argument, mirroring the PyG interface. This keeps inference free of
    any ``torch_geometric`` dependency.

    Reference:
        https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/norm/batch_norm.html#BatchNorm
    """

    def __init__(self, input_dim, momentum):
        """
        :param input_dim: Number of features (channels) of the input tensor.
        :param momentum: Momentum for the running mean/variance update in ``BatchNorm1d``.
        """
        super().__init__()
        # Attribute name ``module`` matches the PyG wrapper, so state_dict keys
        # stay interchangeable with checkpoints trained against either version.
        self.module = nn.BatchNorm1d(input_dim, momentum=momentum)

    def forward(self, x):
        """
        :param x: Input tensor of shape [N, input_dim].
        :return: Batch-normalised tensor of the same shape.
        """
        normalised = self.module(x)
        return normalised
63
64
class GravNetConv(nn.Module):
    """
    A single GravNet convolution layer.

    Every node is projected into a learned latent *space* (``lin_s``) and into a
    message vector (``lin_h``). The graph is built by connecting each node to its
    k nearest neighbours in the latent space; edge weights fall off exponentially
    with the squared latent distance, so nearby nodes dominate the message
    exchange. Mean- and max-aggregated messages are concatenated, projected to
    the output dimension, and added to a linear residual of the input.

    Unlike the scatter-based PyG ``GravNetConv``, the aggregation here is a plain
    reshape, which keeps the layer exportable to ONNX and runnable in pure
    PyTorch:
    - every node receives exactly k messages (uniform neighbourhood size),
    - ``manual_knn`` returns edges sorted by receiving node, so messages can be
      viewed directly as [N, k, propagate_dimensions].

    Reference:
        Learning Representations of Irregular Particle-detector Geometry
        with Distance-weighted Graph Networks
        https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gravnet_conv.html#GravNetConv
    """

    def __init__(self, in_channels=128, out_channels=128, space_dimensions=3, propagate_dimensions=8, k=4):
        """
        :param in_channels: Number of input features per node.
        :param out_channels: Number of output features per node.
        :param space_dimensions: Dimensionality of the learned latent space used for KNN.
        :param propagate_dimensions: Dimensionality of the message vectors exchanged along edges.
        :param k: Number of nearest neighbours per node.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.lin_s = nn.Linear(in_channels, space_dimensions)
        self.lin_h = nn.Linear(in_channels, propagate_dimensions)
        self.lin_out1 = nn.Linear(in_channels, out_channels, bias=False)
        self.lin_out2 = nn.Linear(2 * propagate_dimensions, out_channels)

    def forward(self, x):
        """
        :param x: Node feature tensor of shape [N, in_channels].
        :return: Updated node feature tensor of shape [N, out_channels].
        """
        # topk needs at least k elements, so clamp k to the node count.
        # Kept as tensor arithmetic for ONNX-export compatibility.
        num_nodes = torch.tensor(x.size(0))
        k = torch.minimum(num_nodes, torch.tensor(self.k))

        messages = self.lin_h(x)
        latent = self.lin_s(x)

        receivers, senders = manual_knn(latent, k)

        # Gaussian kernel on the squared latent distance; the factor of 10
        # sharpens the kernel for a better-spread weight distribution.
        sq_dist = (latent[receivers] - latent[senders]).pow(2).sum(-1)
        kernel = torch.exp(-10. * sq_dist)

        # Weighted messages, one row per edge: [N*k, propagate_dimensions].
        weighted = messages[senders] * kernel.unsqueeze(1)

        # Reshape-based aggregation, valid because each node has exactly k
        # neighbours and edges arrive sorted by receiver (see ``manual_knn``).
        grouped = weighted.reshape(-1, k, weighted.shape[-1])
        pooled = torch.cat(
            [grouped.mean(axis=1), grouped.amax(axis=1)],
            axis=-1,
        )  # [N, 2 * propagate_dimensions]

        # Residual path from the raw input plus the projected pooled messages.
        return self.lin_out1(x) + self.lin_out2(pooled)
136
137
class CDCNet(nn.Module):
    """
    GNN for object condensation and track parameter regression on CDC hits.

    Architecture overview
    ---------------------
    1. Input batch normalisation.
    2. Global mean pooling is concatenated to node features before each block
       (global exchange), giving every node access to event-level context.
    3. A stack of ``nblocks`` GravNet blocks, each consisting of:
       - Two linear + ELU layers,
       - BatchNorm,
       - One linear + ELU layer,
       - GravNetConv (learned latent-space KNN + distance-weighted message passing),
       - BatchNorm.
       The output of each block is collected for skip connections.
    4. All block outputs are concatenated and passed through a dense layer.
    5. Five output heads predict:
       - ``beta``   : condensation score in [0, 1] (sigmoid), shape [N, 1].
       - ``ccoords``: latent clustering coordinates, shape [N, coord_dim].
       - ``p``      : 3-momentum vector (px, py, pz), shape [N, 3].
       - ``vertex`` : production vertex (vx, vy, vz), shape [N, 3].
       - ``charge`` : particle charge probability in [0, 1] (sigmoid), shape [N, 1].

    ONNX export constraints
    -----------------------
    Several design choices were made to allow the model to be exported to ONNX
    and run in pure PyTorch without torch_geometric:
    - Custom ``GravNetConv`` using ``manual_knn`` and reshape-based aggregation
      instead of scatter-based aggregation.
    - Batch size fixed to 1 (no batch index vector).
    - Standard ``nn.BatchNorm1d`` and global mean pooling instead of PyG variants.

    References
    ----------
    - Learning Representations of Irregular Particle-detector Geometry
      with Distance-weighted Graph Networks (https://doi.org/10.1140/epjc/s10052-019-7113-9)
    - End-to-End Multi-track Reconstruction Using Graph Neural Networks
      at Belle II (https://doi.org/10.1007/s41781-025-00135-6)
    """

    def __init__(
        self,
        input_dim,
        k=10,
        dim1=64,
        dim2=32,
        nblocks=4,
        coord_dim=2,
        space_dimensions=4,
        momentum=0.6,
    ):
        """
        :param input_dim: Number of input features per node.
        :param k: Number of nearest neighbours for each GravNetConv layer.
        :param dim1: Hidden feature dimension inside each GravNet block.
        :param dim2: Output feature dimension of each GravNet block
            (collected for skip connections).
        :param nblocks: Number of GravNet blocks stacked in the network.
        :param coord_dim: Dimensionality of the predicted clustering coordinates.
        :param space_dimensions: Dimensionality of the latent space learned by GravNetConv
            for KNN construction.
        :param momentum: Momentum for all BatchNorm layers.
        """
        super().__init__()

        # Fix: previously hard-coded to 0.6, ignoring the ``momentum`` argument.
        # Backward-compatible, since the default value is 0.6.
        self.batch_norm_0 = BatchNorm(input_dim, momentum=momentum)

        # First block input is (node features || global mean) → 2 * input_dim.
        # Subsequent blocks see (dim1*2 features || global mean) → 4 * dim1,
        # because node features are dim1*2 after each GravNetConv.
        self.blocks = nn.ModuleList(
            [
                self._make_block(
                    2 * input_dim if i == 0 else 4 * dim1,
                    dim1,
                    dim2,
                    space_dimensions,
                    k,
                    momentum,
                )
                for i in range(nblocks)
            ]
        )

        # There are skip connections between the blocks,
        # this layer combines them, therefore scales with nblocks
        self.dense_cat = nn.Linear(dim2 * nblocks, dim1)

        # These are the output layers for object condensation
        self.p_beta_layer = nn.Linear(dim1, 1)  # predict condensation point
        self.p_coords_layer = nn.Linear(dim1, coord_dim)  # predict latent space coordinates

        # These are the output layers for the track parameters
        self.p_p_layer = nn.Linear(dim1, 3)  # predict track momentum
        self.p_vertex_layer = nn.Linear(dim1, 3)  # predict track starting point
        self.p_charge_layer = nn.Linear(dim1, 1)  # predict track charge

    @staticmethod
    def _make_block(in_dim, dim1, dim2, space_dimensions, k, momentum):
        """
        Build one GravNet block as an ``nn.ModuleList``.

        Layer order (indices are relied upon by ``forward``):
        0-1 dense layers, 2 BatchNorm, 3 dense layer, 4 GravNetConv,
        5 BatchNorm (dim1*2 because of the GravNetConv output), 6 projection to dim2.

        :param in_dim: Input feature dimension of the block.
        """
        return nn.ModuleList(
            [
                nn.Linear(in_dim, dim1),
                nn.Linear(dim1, dim1),
                BatchNorm(dim1, momentum=momentum),
                nn.Linear(dim1, dim1),
                GravNetConv(
                    in_channels=dim1,
                    out_channels=dim1 * 2,
                    space_dimensions=space_dimensions,
                    k=k,
                    propagate_dimensions=dim1,
                ),
                BatchNorm(dim1 * 2, momentum=momentum),
                nn.Linear(dim1 * 2, dim2),
            ]
        )

    def forward(self, x):
        """
        Forward pass through CDCNet.

        :param x: Node feature tensor of shape [N, input_dim].
        :return: Five-tuple of tensors, all of shape [N, *]:
            - beta [N, 1]: condensation score in [0, 1].
            - coords [N, coord_dim]: latent clustering coordinates.
            - p [N, 3]: predicted 3-momentum (px, py, pz).
            - vertex [N, 3]: predicted production vertex (vx, vy, vz).
            - charge [N, 1]: charge probability in [0, 1].
        """

        features = []
        x = self.batch_norm_0(x)

        # Global exchange before the first block: broadcast event mean to all nodes
        out = x.mean(dim=0, keepdim=True)
        x = torch.cat(torch.broadcast_tensors(x, out), dim=-1)

        for i, block in enumerate(self.blocks):
            if i > 0:
                # Global exchange before each subsequent block
                out = x.mean(dim=0, keepdim=True)
                x = torch.cat(torch.broadcast_tensors(x, out), dim=-1)
            x = F.elu(block[0](x))
            x = F.elu(block[1](x))
            x = block[2](x)  # BatchNorm
            x = F.elu(block[3](x))
            x = block[4](x)  # GravNetConv
            x = block[5](x)  # BatchNorm
            features.append(F.elu(block[6](x)))  # collect for skip connection

        # Concatenate features and put through final dense neural network
        x = torch.cat(features, dim=1)
        x = F.elu(self.dense_cat(x))

        # Here are the networks for object condensation predictions
        p_beta = torch.sigmoid(self.p_beta_layer(x))
        p_coords = self.p_coords_layer(x)

        # Here are the networks for track parameters predictions
        p_p = self.p_p_layer(x)
        p_vertex = self.p_vertex_layer(x)
        p_charge = torch.sigmoid(self.p_charge_layer(x))

        return (
            p_beta,
            p_coords,
            p_p,
            p_vertex,
            p_charge,
        )
327
328# \endcond