Belle II Software development
CATFinderModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/CATFinder/CATFinderModule.h>
10
11#include <tracking/dbobjects/CATFinderParameters.h>
12#include <tracking/gnnFinder/Utils.h>
13#include <tracking/dataobjects/RecoHitInformation.h>
14#include <tracking/dataobjects/RecoTrack.h>
15
16#include <cdc/dataobjects/CDCHit.h>
17#include <cdc/geometry/CDCGeometryPar.h>
18#include <framework/database/DBAccessorBase.h>
19#include <framework/database/DBObjPtr.h>
20#include <framework/logging/Logger.h>
21
22#include <algorithm>
23#include <cmath>
24#include <cstdint>
25#include <numeric>
26#include <utility>
27#include <vector>
28
29#include <Math/Vector3D.h>
30#include <TMatrixDSym.h>
31
32using namespace Belle2;
34
// Register this module with the basf2 framework under the name "CATFinder"
// (REG_MODULE takes the class name without its 'Module' suffix).
35REG_MODULE(CATFinder);
36
37CATFinderModule::CATFinderModule() : Module()
38{
39 setDescription("The GNN-based CDC AI Track Finder, also known as CATFinder.");
40 addParam("recoTracksStoreArrayName", m_CDCRecoTracksName, "Name of the output store array of CDC RecoTrack.",
41 m_CDCRecoTracksName);
42 addParam("catFinderWeightfileName", m_catFinderWeightfileName,
43 "Name of the CATFinder weightfile as stored in the conditions database.",
44 m_catFinderWeightfileName);
45}
46
48{
49 m_CDCHits.isRequired();
50 m_wireHitVector.isRequired();
51 m_recoHitInformations.registerInDataStore();
52 m_CDCRecoTracks.registerInDataStore(m_CDCRecoTracksName);
53 m_CDCHits.registerRelationTo(m_CDCRecoTracks);
54 m_recoHitInformations.registerRelationTo(m_CDCHits);
55 m_CDCRecoTracks.registerRelationTo(m_recoHitInformations);
56}
57
59{
// Called at the start of each run: validate the CATFinderParameters payload and
// the CATFinder weightfile (aborting with B2FATAL if either is unusable), cache
// the preprocessing/clustering constants and ONNX tensor names in members, and
// (re)create the ONNX inference session from the weightfile.
// NOTE(review): original lines 58, 60 and 64 are missing from this listing;
// presumably they hold the signature `void CATFinderModule::beginRun()` and the
// declarations of `parameters` (likely a DBObjPtr<CATFinderParameters>) and
// `weightfile` (likely a DBAccessorBase for a c_RawFile payload named
// m_catFinderWeightfileName) -- confirm against the source file.
61 if (not parameters.isValid())
62 B2FATAL("CATFinderParameters is not valid");
63
// An empty filename is treated as "weightfile not available" and is fatal.
65 const std::string filename = weightfile.getFilename();
66 if (filename == "")
67 B2FATAL(m_catFinderWeightfileName << " is not valid");
68
69 // Get the relevant parameters: feature-scaling constants, clustering thresholds and the ONNX tensor names used in event()
70 m_tdcOffset = parameters->getTDCOffset();
71 m_tdcScale = parameters->getTDCScale();
72 m_adcClip = parameters->getADCClip();
73 m_slayerScale = parameters->getSLayerScale();
74 m_clayerScale = parameters->getCLayerScale();
75 m_layerScale = parameters->getLayerScale();
76 m_spatialCoordinatesScale = parameters->getSpatialCoordinatesScale();
77 m_nInputFeatures = parameters->getNInputFeatures();
// NOTE(review): event() sizes the coordinate tensor with a hard-coded width of 3
// and never reads m_latentSpaceNDim -- confirm the model's latent space is 3-D.
78 m_latentSpaceNDim = parameters->getLatentSpaceNDim();
79 m_tBeta = parameters->getTBeta();
80 m_tDistance = parameters->getTDistance();
81 m_maxRadius = parameters->getMaxRadius();
82 m_minNumberHits = parameters->getMinNumberHits();
83 m_inputTFeaturesName = parameters->getInputTFeaturesName();
84 m_outputTBetaName = parameters->getOutputTBetaName();
85 m_outputTCoordinatesName = parameters->getOutputTCoordinatesName();
86 m_outputTMomentumName = parameters->getOutputTMomentumName();
87 m_outputTVertexName = parameters->getOutputTVertexName();
88 m_outputTChargeName = parameters->getOutputTChargeName();
89
90 // Initialize the ONNX session from the weightfile, replacing any session created for a previous run
91
92 m_session = std::make_unique<MVA::ONNX::Session>(filename.c_str());
93}
94
96{
// Run the CATFinder GNN on the unmasked CDC wire hits of the event and turn its
// output into CDC RecoTracks:
//  1. build one input row per unmasked CDCWireHit and run the ONNX session;
//  2. pick condensation points: hits with beta > m_tBeta, visited in order of
//     decreasing beta, kept only if farther than m_tDistance (in the predicted
//     clustering coordinates) from every condensation point already kept;
//  3. for each condensation point, attach every hit whose predicted coordinates
//     lie within m_maxRadius; skip it if fewer than m_minNumberHits hits were
//     attached or if the predicted position/momentum has a NaN x component;
//     otherwise order the hits and store a seeded RecoTrack in m_CDCRecoTracks.
// NOTE(review): original lines 95, 97 and 255 are missing from this listing;
// presumably the signature `void CATFinderModule::event()`, the declaration of
// `cdcGeometryPar` (likely via CDCGeometryPar::Instance()) and the declaration
// of `hitOrderer` -- confirm against the source file.
98
99 const std::vector<TrackingUtilities::CDCWireHit>& wireHitVector = *m_wireHitVector;
100
101 // Use only "unmasked" hits
102 unsigned int nHits = 0;
103 for (const auto& wireHit : wireHitVector) {
104 if (!wireHit.getAutomatonCell().hasMaskedFlag())
105 nHits++;
106 }
107
108 // Nothing to do if all the hits are already masked
109 if (nHits == 0)
110 return;
111
112 // Input tensor: features per hit (one row per unmasked hit)
// NOTE(review): the loop below fills exactly seven columns (0-6), so this
// presumably requires m_nInputFeatures == 7 -- confirm.
113 auto input_t = Tensor<float>::make_shared({nHits, m_nInputFeatures});
114 // Output tensor: condensation score
115 auto beta_t = Tensor<float>::make_shared({nHits, 1});
116 // Output tensor: predicted coordinates in the latent space
// NOTE(review): the latent width is hard-coded to 3 here and in the distance
// computations below; m_latentSpaceNDim (read in beginRun) is not used.
117 auto coord_t = Tensor<float>::make_shared({nHits, 3});
118 // Output tensor: predicted momentum
119 auto momentum_t = Tensor<float>::make_shared({nHits, 3});
120 // Output tensor: predicted "vertex"
121 auto vertex_t = Tensor<float>::make_shared({nHits, 3});
122 // Output tensor: predicted charge
123 auto charge_t = Tensor<float>::make_shared({nHits, 1});
124
125 // Map from tensor row index back to the original wireHitVector index,
126 // needed later when adding CDC hits to the RecoTrack
127 std::vector<unsigned int> tensorIndexToHitIndex;
128 tensorIndexToHitIndex.reserve(nHits);
129
130 unsigned int iHit = 0;
131 for (unsigned int iWireHit = 0; iWireHit < wireHitVector.size(); ++iWireHit) {
132
133 // Again: skip the already masked hits
134 if (wireHitVector[iWireHit].getAutomatonCell().hasMaskedFlag())
135 continue;
136
137 // Prepare the input features
138 const CDCHit* cdcHit = wireHitVector[iWireHit].getHit();
139 const unsigned short clayer = cdcHit->getICLayer();
140 const unsigned short wire = cdcHit->getIWire();
// Wire endpoints below are taken from the c_Aligned wire-position set
141 const auto wirePos = cdcGeometryPar.c_Aligned;
142
// TDC: shifted by m_tdcOffset and divided by m_tdcScale.
// ADC: divided by m_adcClip, saturating at 1 when the ADC count exceeds m_adcClip.
143 const double tdc_scaled = (static_cast<double>(cdcHit->getTDCCount()) - m_tdcOffset) / m_tdcScale;
144 const double adc_clipped = cdcHit->getADCCount() > m_adcClip
145 ? 1.
146 : static_cast<double>(cdcHit->getADCCount()) / m_adcClip;
147
148 const B2Vector3D posForward = cdcGeometryPar.wireForwardPosition(clayer, wire, wirePos);
149 const B2Vector3D posBackward = cdcGeometryPar.wireBackwardPosition(clayer, wire, wirePos);
150
151 // Prepare the tensor: wire-midpoint x/y divided by m_spatialCoordinatesScale, scaled TDC, clipped ADC,
// and the getISuperLayer / getICLayer / getILayer indices each divided by their own scale factor
152 input_t->at({iHit, 0}) = 0.5 * (posForward.x() + posBackward.x()) / m_spatialCoordinatesScale;
153 input_t->at({iHit, 1}) = 0.5 * (posForward.y() + posBackward.y()) / m_spatialCoordinatesScale;
154 input_t->at({iHit, 2}) = tdc_scaled;
155 input_t->at({iHit, 3}) = adc_clipped;
156 input_t->at({iHit, 4}) = static_cast<double>(cdcHit->getISuperLayer()) / m_slayerScale;
157 input_t->at({iHit, 5}) = static_cast<double>(clayer) / m_clayerScale;
158 input_t->at({iHit, 6}) = static_cast<double>(cdcHit->getILayer()) / m_layerScale;
159
160 tensorIndexToHitIndex.push_back(iWireHit);
161 ++iHit;
162 }
163
// nHits counts the unmasked hits only
164 B2DEBUG(29, "CDCWireHits in the event: " << nHits);
// Both loops skip exactly the masked hits, so this should never trigger;
// note that it only logs and processing continues.
165 if (nHits != iHit)
166 B2ERROR("Different number of hits: something went wrong...");
167
168 // Run the GNN inference: one input tensor in, five named output tensors filled
169 m_session->run(
170 {{m_inputTFeaturesName, input_t}},
171 {{m_outputTBetaName, beta_t}, {m_outputTCoordinatesName, coord_t}, {m_outputTMomentumName, momentum_t}, {m_outputTVertexName, vertex_t}, {m_outputTChargeName, charge_t}}
172 );
173
174 // Build an index array sorted by descending beta so we process the most probable condensation points first
// NOTE(review): std::sort is not stable, so hits with equal beta come out in an unspecified order.
175 std::vector<unsigned int> betaIndices(nHits);
176 std::iota(betaIndices.begin(), betaIndices.end(), 0);
177 std::sort(betaIndices.begin(), betaIndices.end(),
178 [&](unsigned int i1, unsigned int i2) { return beta_t->at({i1, 0}) > beta_t->at({i2, 0}); });
179
180 // A hit is a candidate condensation point only if its beta exceeds the threshold (strictly greater than m_tBeta)
181 std::vector<uint8_t> selectedBetas(nHits);
182 for (unsigned int i = 0; i < nHits; ++i)
183 selectedBetas[i] = static_cast<uint8_t>(beta_t->at({i, 0}) > m_tBeta);
184
185 // A new condensation point is accepted only if it lies farther than m_tDistance
186 // from every already-accepted point in latent coordinate space
187 std::vector<unsigned int> conPointIndices;
188 const float thresholdSquared = m_tDistance * m_tDistance;
189
190 // Returns true if hit i is outside the exclusion radius of every accepted condensation point
// (squared Euclidean distance over the three coordinate columns; a hit exactly at m_tDistance counts as inside)
191 auto isOutOfRadius = [&](unsigned int i) {
192 for (unsigned int iConPoint : conPointIndices) {
193 double d = 0.0;
194 d += (coord_t->at({i, 0}) - coord_t->at({iConPoint, 0})) * (coord_t->at({i, 0}) - coord_t->at({iConPoint, 0}));
195 d += (coord_t->at({i, 1}) - coord_t->at({iConPoint, 1})) * (coord_t->at({i, 1}) - coord_t->at({iConPoint, 1}));
196 d += (coord_t->at({i, 2}) - coord_t->at({iConPoint, 2})) * (coord_t->at({i, 2}) - coord_t->at({iConPoint, 2}));
197 if (d <= thresholdSquared)
198 return false;
199 }
200 return true;
201 };
202
203 for (unsigned int i = 0; i < betaIndices.size(); ++i) {
204 unsigned int iBeta = betaIndices[i];
205 if (!selectedBetas[iBeta])
206 continue; // Below the beta threshold: not a condensation point candidate
// isOutOfRadius already returns true for an empty list; the empty() test only short-circuits it
207 if (conPointIndices.empty() or isOutOfRadius(iBeta))
208 conPointIndices.push_back(iBeta); // Accept as a new condensation point
209 else
210 selectedBetas[iBeta] = 0; // Too close to an existing seed: let's discard it (never read again: each index is visited once)
211 }
212 B2DEBUG(29, "Condensation points in the event: " << conPointIndices.size());
213
214 // Convert the condensation points into RecoTracks: one condensation point -> one RecoTrack
215 for (unsigned int iConPoint : conPointIndices) {
216
217 std::vector<GNNFinder::Utils::KDTHit> kdtHits;
218 kdtHits.reserve(nHits);
219
220 // Collect all hits whose clustering coordinates fall within m_maxRadius of this seed
// NOTE(review): attached hits are not removed from later candidates, so one hit can end up on several RecoTracks.
221 for (unsigned int i = 0; i < nHits; ++i) {
222 const double dx = coord_t->at({iConPoint, 0}) - coord_t->at({i, 0});
223 const double dy = coord_t->at({iConPoint, 1}) - coord_t->at({i, 1});
224 const double dz = coord_t->at({iConPoint, 2}) - coord_t->at({i, 2});
225 if (std::hypot(dx, dy, dz) < m_maxRadius) {
226 // Store the wire X and Y coordinates (the scaled values from input_t columns 0 and 1) together with the tensor row index for later lookup
227 kdtHits.push_back({input_t->at({i, 0}), input_t->at({i, 1}), static_cast<int>(i)});
228 }
229 }
230
231 // Reject tracks with too few hits
232 if (kdtHits.size() < m_minNumberHits)
233 continue;
234
235 // Retrieve predicted momentum, vertex and charge from the condensation point's output
// The vertex is multiplied by m_spatialCoordinatesScale to bring it back from network units
// (the inputs were divided by the same factor); the momentum is used as predicted.
// Charge is +1 if the charge output is >= 0.5, otherwise -1.
236 const ROOT::Math::XYZVector momentum(
237 momentum_t->at({iConPoint, 0}), momentum_t->at({iConPoint, 1}), momentum_t->at({iConPoint, 2}));
238 const ROOT::Math::XYZVector position(
239 vertex_t->at({iConPoint, 0}) * m_spatialCoordinatesScale,
240 vertex_t->at({iConPoint, 1}) * m_spatialCoordinatesScale,
241 vertex_t->at({iConPoint, 2}) * m_spatialCoordinatesScale);
242 const int charge = (charge_t->at({iConPoint, 0}) >= 0.5) ? 1 : -1;
243
244 B2DEBUG(29, LogVar("Condensation point", iConPoint) << LogVar("Attached hits", kdtHits.size())
245 << LogVar("Momentum", std::sqrt(momentum.Mag2())) << LogVar("Vertex", std::sqrt(position.Mag2()))
246 << LogVar("Charge", charge));
247
// NOTE(review): only the x components are checked; NaN in y/z (or infinities) are not caught here.
248 if (std::isnan(position.X()) or std::isnan(momentum.X())) {
249 B2WARNING("Skipping track with NaN values.");
250 continue;
251 }
252
253 // Order hits along the helix from the innermost CDC wall outward,
254 // starting from where the predicted trajectory intersects the inner wall
// NOTE(review): `position` is in basf2 units (rescaled above) and 16 is presumably the
// inner-wall radius in cm, whereas the kdtHits x/y were divided by m_spatialCoordinatesScale;
// unless intersectCylinderXY/orderHits compensate, the start point and the hits use different units -- confirm.
// The returned indices are presumably the stored tensor row indices, as they are used that way below.
256 auto [startingX, startingY] = GNNFinder::Utils::intersectCylinderXY(position, momentum,
257 16); // TODO: find a better way to represent 16...
258 std::vector<int> sortedIndices =
259 hitOrderer.orderHits(startingX, startingY, std::move(kdtHits));
260
261 // Create a new RecoTrack and seed it with the network predictions
262 RecoTrack* cdcRecotrack = m_CDCRecoTracks.appendNew();
263 cdcRecotrack->setPositionAndMomentum(position, momentum);
264 cdcRecotrack->setChargeSeed(charge);
265
266 // Use a loose covariance matrix as the seed uncertainty
// NOTE(review): all 36 entries are set to 0.1, which gives a rank-1 (singular) matrix with every pair of
// seed parameters fully correlated; a diagonal matrix may have been intended -- confirm.
267 auto seedCovariance = TMatrixDSym(6);
268 for (int j = 0; j < 6; ++j)
269 for (int k = 0; k < 6; ++k)
270 seedCovariance[j][k] = 1e-1;
271 cdcRecotrack->setSeedCovariance(seedCovariance);
272
273 // Add CDC hits in the sorted order; the sorting parameter is the hit's position in that order
274 int iRecoTrackHit = 0;
275 for (int tensorIndex : sortedIndices) {
276 cdcRecotrack->addCDCHit(wireHitVector[tensorIndexToHitIndex[tensorIndex]].getHit(), iRecoTrackHit);
277 ++iRecoTrackHit;
278 }
279 }
280}
DataType y() const
access variable Y (= .at(1) without boundary check)
Definition B2Vector3.h:427
DataType x() const
access variable X (= .at(0) without boundary check)
Definition B2Vector3.h:425
std::string m_outputTBetaName
Name of the output tensor carrying the per-hit beta (condensation score) values.
StoreArray< RecoTrack > m_CDCRecoTracks
Output store array of RecoTrack.
float m_adcClip
Maximum ADC value used for normalization; values above are clipped.
std::string m_outputTCoordinatesName
Name of the output tensor carrying the per-hit condensation coordinates.
std::string m_outputTVertexName
Name of the output tensor carrying the predicted vertices.
std::unique_ptr< MVA::ONNX::Session > m_session
ONNX inference session.
std::string m_CDCRecoTracksName
Name of the output store array of CDC RecoTrack.
float m_slayerScale
Scale factor for normalizing superlayer indices.
TrackingUtilities::StoreWrappedObjPtr< std::vector< TrackingUtilities::CDCWireHit > > m_wireHitVector
Input vector of CDCWireHit.
std::string m_outputTChargeName
Name of the output tensor carrying the predicted charges.
void initialize() override
Initializes the module and registers required store arrays and relations.
float m_maxRadius
Maximum radius in latent space to associate hits with a condensation point.
void event() override
Processes a single event in the CATFinderModule.
float m_clayerScale
Scale factor for normalizing cell layer indices.
float m_spatialCoordinatesScale
Scale factor for spatial coordinates (from basf2 units to internal GNN units).
std::string m_inputTFeaturesName
Name of the input tensor carrying the per-hit features.
void beginRun() override
Prepares the CATFinderModule for a new run by initializing the ONNX session from the weight file.
float m_tdcScale
Scale factor for TDC normalization.
float m_tBeta
Threshold for the beta value to select candidate condensation points.
unsigned int m_minNumberHits
Minimum number of associated CDC hits required to form a valid track.
std::string m_outputTMomentumName
Name of the output tensor carrying the predicted momenta.
unsigned int m_latentSpaceNDim
Dimensionality of the latent space used by the GNN.
float m_tDistance
Minimum distance required between condensation points in latent space.
unsigned int m_nInputFeatures
Number of input features per node for the GNN model.
float m_layerScale
Scale factor for normalizing layer indices.
float m_tdcOffset
Offset applied to TDC counts.
StoreArray< RecoHitInformation > m_recoHitInformations
Output store array of RecoHitInformation.
std::string m_catFinderWeightfileName
Name of the CATFinder weightfile as stored in the conditions database.
StoreArray< CDCHit > m_CDCHits
Input store array of CDCHit.
Class containing the result of the unpacker in raw data and the result of the digitizer in simulation...
Definition CDCHit.h:40
unsigned short getICLayer() const
Getter for iCLayer (0-55).
Definition CDCHit.h:178
unsigned short getIWire() const
Getter for iWire.
Definition CDCHit.h:166
short getTDCCount() const
Getter for TDC count.
Definition CDCHit.h:219
unsigned short getADCCount() const
Getter for integrated charge.
Definition CDCHit.h:230
unsigned short getISuperLayer() const
Getter for iSuperLayer.
Definition CDCHit.h:184
unsigned short getILayer() const
Getter for iLayer.
Definition CDCHit.h:172
The Class for CDC Geometry Parameters.
const B2Vector3D wireForwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the forward position of the input sense wire.
const B2Vector3D wireBackwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the backward position of the input sense wire.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
Base class for DBObjPtr and DBArray for easier common treatment.
Class for accessing objects in the database.
Definition DBObjPtr.h:21
@ c_RawFile
Just a plain old file; we don't look at it, we just provide the filename.
Sorts CDC hits spatially using KD-tree nearest neighbor traversal.
Definition Utils.h:123
static std::vector< int > orderHits(const double startingX, const double startingY, std::vector< KDTHit > kdtHits)
Sort hits spatially based on proximity to a starting position.
Definition Utils.cc:159
Represents an input or output tensor for an ONNX model.
Definition ONNX.h:51
auto & at(size_t index)
Accesses the element at the specified flat index.
Definition ONNX.h:181
Base class for Modules.
Definition Module.h:72
This is the Reconstruction Event-Data Model Track.
Definition RecoTrack.h:79
void setChargeSeed(const short int chargeSeed)
Set the charge seed stored in the reco track. ATTENTION: This is not the fitted charge.
Definition RecoTrack.h:597
bool addCDCHit(const UsedCDCHit *cdcHit, const unsigned int sortingParameter, RightLeftInformation rightLeftInformation=RightLeftInformation::c_undefinedRightLeftInformation, OriginTrackFinder foundByTrackFinder=OriginTrackFinder::c_undefinedTrackFinder)
Adds a cdc hit with the given information to the reco track.
Definition RecoTrack.h:243
void setPositionAndMomentum(const ROOT::Math::XYZVector &positionSeed, const ROOT::Math::XYZVector &momentumSeed)
Set the position and momentum seed of the reco track. ATTENTION: This is not the fitted position or m...
Definition RecoTrack.h:590
void setSeedCovariance(const TMatrixDSym &seedCovariance)
Set the covariance of the seed. ATTENTION: This is not the fitted covariance.
Definition RecoTrack.h:614
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
Definition ONNX.h:145
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
B2Vector3< double > B2Vector3D
typedef for common usage with double
Definition B2Vector3.h:516
DataType at(unsigned i) const
safe member access (with boundary check!)
Definition B2Vector3.h:759
Abstract base class for different kinds of events.