9#include <tracking/modules/CATFinder/CATFinderModule.h>
11#include <tracking/dbobjects/CATFinderParameters.h>
12#include <tracking/gnnFinder/Utils.h>
13#include <tracking/dataobjects/RecoHitInformation.h>
14#include <tracking/dataobjects/RecoTrack.h>
16#include <cdc/dataobjects/CDCHit.h>
17#include <cdc/geometry/CDCGeometryPar.h>
18#include <framework/database/DBAccessorBase.h>
19#include <framework/database/DBObjPtr.h>
20#include <framework/logging/Logger.h>
29#include <Math/Vector3D.h>
30#include <TMatrixDSym.h>
37CATFinderModule::CATFinderModule() :
Module()
39 setDescription(
"The GNN-based CDC AI Track Finder, also known as CATFinder.");
40 addParam(
"recoTracksStoreArrayName", m_CDCRecoTracksName,
"Name of the output store array of CDC RecoTrack.",
42 addParam(
"catFinderWeightfileName", m_catFinderWeightfileName,
43 "Name of the CATFinder weightfile as stored in the conditions database.",
44 m_catFinderWeightfileName);
61 if (not parameters.isValid())
62 B2FATAL(
"CATFinderParameters is not valid");
65 const std::string filename = weightfile.getFilename();
79 m_tBeta = parameters->getTBeta();
92 m_session = std::make_unique<MVA::ONNX::Session>(filename.c_str());
99 const std::vector<TrackingUtilities::CDCWireHit>& wireHitVector = *
m_wireHitVector;
102 unsigned int nHits = 0;
103 for (
const auto& wireHit : wireHitVector) {
104 if (!wireHit.getAutomatonCell().hasMaskedFlag())
127 std::vector<unsigned int> tensorIndexToHitIndex;
128 tensorIndexToHitIndex.reserve(nHits);
130 unsigned int iHit = 0;
131 for (
unsigned int iWireHit = 0; iWireHit < wireHitVector.size(); ++iWireHit) {
134 if (wireHitVector[iWireHit].getAutomatonCell().hasMaskedFlag())
138 const CDCHit* cdcHit = wireHitVector[iWireHit].getHit();
139 const unsigned short clayer = cdcHit->
getICLayer();
140 const unsigned short wire = cdcHit->
getIWire();
141 const auto wirePos = cdcGeometryPar.c_Aligned;
154 input_t->at({iHit, 2}) = tdc_scaled;
155 input_t->at({iHit, 3}) = adc_clipped;
157 input_t->at({iHit, 5}) =
static_cast<double>(clayer) /
m_clayerScale;
160 tensorIndexToHitIndex.push_back(iWireHit);
164 B2DEBUG(29,
"CDCWireHits in the event: " << nHits);
166 B2ERROR(
"Different number of hits: something went wrong...");
175 std::vector<unsigned int> betaIndices(nHits);
176 std::iota(betaIndices.begin(), betaIndices.end(), 0);
177 std::sort(betaIndices.begin(), betaIndices.end(),
178 [&](
unsigned int i1,
unsigned int i2) { return beta_t->at({i1, 0}) > beta_t->
at({i2, 0}); });
181 std::vector<uint8_t> selectedBetas(nHits);
182 for (
unsigned int i = 0; i < nHits; ++i)
183 selectedBetas[i] =
static_cast<uint8_t
>(beta_t->at({i, 0}) > m_tBeta);
187 std::vector<unsigned int> conPointIndices;
188 const float thresholdSquared = m_tDistance * m_tDistance;
191 auto isOutOfRadius = [&](
unsigned int i) {
192 for (
unsigned int iConPoint : conPointIndices) {
194 d += (coord_t->at({i, 0}) - coord_t->at({iConPoint, 0})) * (coord_t->at({i, 0}) - coord_t->at({iConPoint, 0}));
195 d += (coord_t->at({i, 1}) - coord_t->at({iConPoint, 1})) * (coord_t->at({i, 1}) - coord_t->at({iConPoint, 1}));
196 d += (coord_t->at({i, 2}) - coord_t->at({iConPoint, 2})) * (coord_t->at({i, 2}) - coord_t->at({iConPoint, 2}));
197 if (d <= thresholdSquared)
203 for (
unsigned int i = 0; i < betaIndices.size(); ++i) {
204 unsigned int iBeta = betaIndices[i];
205 if (!selectedBetas[iBeta])
207 if (conPointIndices.empty() or isOutOfRadius(iBeta))
208 conPointIndices.push_back(iBeta);
210 selectedBetas[iBeta] = 0;
212 B2DEBUG(29,
"Condensation points in the event: " << conPointIndices.size());
215 for (
unsigned int iConPoint : conPointIndices) {
217 std::vector<GNNFinder::Utils::KDTHit> kdtHits;
218 kdtHits.reserve(nHits);
221 for (
unsigned int i = 0; i < nHits; ++i) {
222 const double dx = coord_t->at({iConPoint, 0}) - coord_t->at({i, 0});
223 const double dy = coord_t->at({iConPoint, 1}) - coord_t->at({i, 1});
224 const double dz = coord_t->at({iConPoint, 2}) - coord_t->at({i, 2});
225 if (std::hypot(dx, dy, dz) < m_maxRadius) {
227 kdtHits.push_back({input_t->at({i, 0}), input_t->at({i, 1}),
static_cast<int>(i)});
232 if (kdtHits.size() < m_minNumberHits)
236 const ROOT::Math::XYZVector momentum(
237 momentum_t->at({iConPoint, 0}), momentum_t->at({iConPoint, 1}), momentum_t->at({iConPoint, 2}));
238 const ROOT::Math::XYZVector position(
239 vertex_t->at({iConPoint, 0}) * m_spatialCoordinatesScale,
240 vertex_t->at({iConPoint, 1}) * m_spatialCoordinatesScale,
241 vertex_t->at({iConPoint, 2}) * m_spatialCoordinatesScale);
242 const int charge = (charge_t->at({iConPoint, 0}) >= 0.5) ? 1 : -1;
244 B2DEBUG(29, LogVar(
"Condensation point", iConPoint) << LogVar(
"Attached hits", kdtHits.size())
245 << LogVar(
"Momentum", std::sqrt(momentum.Mag2())) << LogVar(
"Vertex", std::sqrt(position.Mag2()))
246 << LogVar(
"Charge", charge));
248 if (std::isnan(position.X()) or std::isnan(momentum.X())) {
249 B2WARNING(
"Skipping track with NaN values.");
256 auto [startingX, startingY] = GNNFinder::Utils::intersectCylinderXY(position, momentum,
258 std::vector<int> sortedIndices =
259 hitOrderer.
orderHits(startingX, startingY, std::move(kdtHits));
262 RecoTrack* cdcRecotrack = m_CDCRecoTracks.appendNew();
267 auto seedCovariance = TMatrixDSym(6);
268 for (
int j = 0; j < 6; ++j)
269 for (
int k = 0; k < 6; ++k)
270 seedCovariance[j][k] = 1e-1;
274 int iRecoTrackHit = 0;
275 for (
int tensorIndex : sortedIndices) {
276 cdcRecotrack->
addCDCHit(wireHitVector[tensorIndexToHitIndex[tensorIndex]].getHit(), iRecoTrackHit);
DataType y() const
access variable Y (= .at(1) without boundary check)
DataType x() const
access variable X (= .at(0) without boundary check)
std::string m_outputTBetaName
Name of the output tensor carrying the per-hit beta (condensation score) values.
StoreArray< RecoTrack > m_CDCRecoTracks
Output store array of RecoTrack.
float m_adcClip
Maximum ADC value used for normalization; values above are clipped.
std::string m_outputTCoordinatesName
Name of the output tensor carrying the per-hit condensation coordinates.
std::string m_outputTVertexName
Name of the output tensor carrying the predicted vertices.
std::unique_ptr< MVA::ONNX::Session > m_session
ONNX inference session.
std::string m_CDCRecoTracksName
Name of the output store array of CDC RecoTrack.
float m_slayerScale
Scale factor for normalizing superlayer indices.
TrackingUtilities::StoreWrappedObjPtr< std::vector< TrackingUtilities::CDCWireHit > > m_wireHitVector
Input vector of CDCWireHit.
std::string m_outputTChargeName
Name of the output tensor carrying the predicted charges.
void initialize() override
Initializes the module and registers required store arrays and relations.
float m_maxRadius
Maximum radius in latent space to associate hits with a condensation point.
void event() override
Processes a single event in the CATFinderModule.
float m_clayerScale
Scale factor for normalizing cell layer indices.
float m_spatialCoordinatesScale
Scale factor for spatial coordinates (from basf2 units to internal GNN units).
std::string m_inputTFeaturesName
Name of the input tensor carrying the per-hit features.
void beginRun() override
Prepares the CATFinderModule for a new run by initializing the ONNX session from the weight file.
float m_tdcScale
Scale factor for TDC normalization.
float m_tBeta
Threshold for the beta value to select candidate condensation points.
unsigned int m_minNumberHits
Minimum number of associated CDC hits required to form a valid track.
std::string m_outputTMomentumName
Name of the output tensor carrying the predicted momenta.
unsigned int m_latentSpaceNDim
Dimensionality of the latent space used by the GNN.
float m_tDistance
Minimum distance required between condensation points in latent space.
unsigned int m_nInputFeatures
Number of input features per node for the GNN model.
float m_layerScale
Scale factor for normalizing layer indices.
float m_tdcOffset
Offset applied to TDC counts.
StoreArray< RecoHitInformation > m_recoHitInformations
Output store array of RecoHitInformation.
std::string m_catFinderWeightfileName
Name of the CATFinder weightfile as stored in the conditions database.
StoreArray< CDCHit > m_CDCHits
Input store array of CDCHit.
Class containing the result of the unpacker in raw data and the result of the digitizer in simulation.
unsigned short getICLayer() const
Getter for iCLayer (0-55).
unsigned short getIWire() const
Getter for iWire.
short getTDCCount() const
Getter for TDC count.
unsigned short getADCCount() const
Getter for integrated charge.
unsigned short getISuperLayer() const
Getter for iSuperLayer.
unsigned short getILayer() const
Getter for iLayer.
The Class for CDC Geometry Parameters.
const B2Vector3D wireForwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the forward position of the input sense wire.
const B2Vector3D wireBackwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the backward position of the input sense wire.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
Base class for DBObjPtr and DBArray for easier common treatment.
Class for accessing objects in the database.
@ c_RawFile
Just a plain old file; we don't look at its contents, we just provide the filename.
Sorts CDC hits spatially using KD-tree nearest neighbor traversal.
static std::vector< int > orderHits(const double startingX, const double startingY, std::vector< KDTHit > kdtHits)
Sort hits spatially based on proximity to a starting position.
Represents an input or output tensor for an ONNX model.
auto & at(size_t index)
Accesses the element at the specified flat index.
This is the Reconstruction Event-Data Model Track.
void setChargeSeed(const short int chargeSeed)
Set the charge seed stored in the reco track. ATTENTION: This is not the fitted charge.
bool addCDCHit(const UsedCDCHit *cdcHit, const unsigned int sortingParameter, RightLeftInformation rightLeftInformation=RightLeftInformation::c_undefinedRightLeftInformation, OriginTrackFinder foundByTrackFinder=OriginTrackFinder::c_undefinedTrackFinder)
Adds a cdc hit with the given information to the reco track.
void setPositionAndMomentum(const ROOT::Math::XYZVector &positionSeed, const ROOT::Math::XYZVector &momentumSeed)
Set the position and momentum seed of the reco track. ATTENTION: This is not the fitted position or momentum.
void setSeedCovariance(const TMatrixDSym &seedCovariance)
Set the covariance of the seed. ATTENTION: This is not the fitted covariance.
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
B2Vector3< double > B2Vector3D
typedef for common usage with double
DataType at(unsigned i) const
safe member access (with boundary check!)
Abstract base class for different kinds of events.