9#include <tracking/modules/CATFinder/CATFinderModule.h>
11#include <tracking/dbobjects/CATFinderParameters.h>
12#include <tracking/gnnFinder/Utils.h>
13#include <tracking/dataobjects/RecoHitInformation.h>
14#include <tracking/dataobjects/RecoTrack.h>
16#include <cdc/dataobjects/CDCHit.h>
17#include <cdc/geometry/CDCGeometryPar.h>
18#include <framework/database/DBAccessorBase.h>
19#include <framework/database/DBObjPtr.h>
20#include <framework/logging/Logger.h>
29#include <Math/Vector3D.h>
30#include <TMatrixDSym.h>
// Constructor: sets the module description and declares the two steerable
// module parameters (output RecoTrack store array name, weightfile name).
37CATFinderModule::CATFinderModule() :
Module()
39 setDescription(
"The GNN-based CDC AI Track Finder, also known as CATFinder.");
// Parameter bound to m_CDCRecoTracksName.
// NOTE(review): the end of this addParam call (anything after the description
// string, including the closing parenthesis) is not visible in this listing.
40 addParam(
"recoTracksStoreArrayName", m_CDCRecoTracksName,
"Name of the output store array of CDC RecoTrack.",
// Parameter bound to m_catFinderWeightfileName; the member's current value is
// also passed as the fourth argument (presumably the default value -- confirm
// against the addParam signature).
42 addParam(
"catFinderWeightfileName", m_catFinderWeightfileName,
43 "Name of the CATFinder weightfile as stored in the conditions database.",
44 m_catFinderWeightfileName);
47void CATFinderModule::initialize()
49 m_CDCHits.isRequired();
50 m_wireHitVector.isRequired();
51 m_recoHitInformations.registerInDataStore();
52 m_CDCRecoTracks.registerInDataStore(m_CDCRecoTracksName);
53 m_CDCHits.registerRelationTo(m_CDCRecoTracks);
54 m_recoHitInformations.registerRelationTo(m_CDCHits);
55 m_CDCRecoTracks.registerRelationTo(m_recoHitInformations);
// Run-level setup: checks the CATFinderParameters payload, obtains the local
// file name of the weightfile, copies all parameter values into member
// variables and creates the ONNX session from the weightfile.
// NOTE(review): the declarations of `parameters` and `weightfile` are not
// visible in this listing.
58void CATFinderModule::beginRun()
// Abort if the parameter payload is not valid.
61 if (not parameters.isValid())
62 B2FATAL(
"CATFinderParameters is not valid");
// Local file name of the weightfile as reported by the `weightfile` accessor.
65 const std::string filename = weightfile.getFilename();
// NOTE(review): the condition guarding this B2FATAL is not visible here
// (presumably a check on `filename` -- confirm).
67 B2FATAL(m_catFinderWeightfileName <<
" is not valid");
// Scaling and clipping constants used in event() to build the input features.
70 m_tdcOffset = parameters->getTDCOffset();
71 m_tdcScale = parameters->getTDCScale();
72 m_adcClip = parameters->getADCClip();
73 m_slayerScale = parameters->getSLayerScale();
74 m_clayerScale = parameters->getCLayerScale();
75 m_layerScale = parameters->getLayerScale();
76 m_spatialCoordinatesScale = parameters->getSpatialCoordinatesScale();
// Model dimensions.
77 m_nInputFeatures = parameters->getNInputFeatures();
78 m_latentSpaceNDim = parameters->getLatentSpaceNDim();
// Thresholds used in event(): beta cut, minimum distance between condensation
// points, hit-collection radius and minimum number of hits per track.
79 m_tBeta = parameters->getTBeta();
80 m_tDistance = parameters->getTDistance();
81 m_maxRadius = parameters->getMaxRadius();
82 m_minNumberHits = parameters->getMinNumberHits();
// Names of the model's input and output tensors.
83 m_inputTFeaturesName = parameters->getInputTFeaturesName();
84 m_outputTBetaName = parameters->getOutputTBetaName();
85 m_outputTCoordinatesName = parameters->getOutputTCoordinatesName();
86 m_outputTMomentumName = parameters->getOutputTMomentumName();
87 m_outputTVertexName = parameters->getOutputTVertexName();
88 m_outputTChargeName = parameters->getOutputTChargeName();
// (Re)create the ONNX session from the weightfile path.
92 m_session = std::make_unique<MVA::ONNX::Session>(filename.c_str());
// Event processing, as far as visible in this listing:
//  1. fill the input feature tensor from the CDC wire hits without masked flag,
//  2. pass input/output tensors to the model,
//  3. select condensation points from the beta output,
//  4. for each condensation point, collect hits close to it in the
//     model's coordinate output and build a RecoTrack seeded with the
//     predicted momentum, vertex and charge.
// NOTE(review): this listing is incomplete -- tensor creation, the inference
// call, several loop/if bodies and the end of the function are not visible.
95void CATFinderModule::event()
99 const std::vector<TrackingUtilities::CDCWireHit>& wireHitVector = *m_wireHitVector;
// First pass over the wire hits, testing for hits without the masked flag.
// NOTE(review): the body of the if is not visible (presumably it increments
// nHits -- confirm).
102 unsigned int nHits = 0;
103 for (
const auto& wireHit : wireHitVector) {
104 if (!wireHit.getAutomatonCell().hasMaskedFlag())
// Maps a row index of the tensors back to the index in wireHitVector.
127 std::vector<unsigned int> tensorIndexToHitIndex;
128 tensorIndexToHitIndex.reserve(nHits);
// Second pass: fill one row of input_t per hit, indexed by iHit.
130 unsigned int iHit = 0;
131 for (
unsigned int iWireHit = 0; iWireHit < wireHitVector.size(); ++iWireHit) {
// NOTE(review): body of this if is not visible (presumably skips masked hits).
134 if (wireHitVector[iWireHit].getAutomatonCell().hasMaskedFlag())
138 const CDCHit* cdcHit = wireHitVector[iWireHit].getHit();
139 const unsigned short clayer = cdcHit->
getICLayer();
140 const unsigned short wire = cdcHit->
getIWire();
141 const auto wirePos = cdcGeometryPar.c_Aligned;
// TDC count shifted by m_tdcOffset and divided by m_tdcScale.
143 const double tdc_scaled = (
static_cast<double>(cdcHit->
getTDCCount()) - m_tdcOffset) / m_tdcScale;
// ADC count divided by m_adcClip when it does not exceed m_adcClip.
// NOTE(review): the value used when the count exceeds m_adcClip is on a line
// not visible in this listing.
144 const double adc_clipped = cdcHit->
getADCCount() > m_adcClip
146 :
static_cast<double>(cdcHit->
getADCCount()) / m_adcClip;
// Features 0 and 1: x and y averaged between posForward and posBackward,
// divided by m_spatialCoordinatesScale.
// NOTE(review): the declarations of posForward/posBackward are not visible.
152 input_t->
at({iHit, 0}) = 0.5 * (posForward.
x() + posBackward.
x()) / m_spatialCoordinatesScale;
153 input_t->at({iHit, 1}) = 0.5 * (posForward.
y() + posBackward.
y()) / m_spatialCoordinatesScale;
// Features 2 and 3: scaled TDC and clipped ADC.
154 input_t->at({iHit, 2}) = tdc_scaled;
155 input_t->at({iHit, 3}) = adc_clipped;
// Features 4-6: superlayer, continuous layer and layer index, each divided by
// its scale.
156 input_t->at({iHit, 4}) =
static_cast<double>(cdcHit->
getISuperLayer()) / m_slayerScale;
157 input_t->at({iHit, 5}) =
static_cast<double>(clayer) / m_clayerScale;
158 input_t->at({iHit, 6}) =
static_cast<double>(cdcHit->
getILayer()) / m_layerScale;
160 tensorIndexToHitIndex.push_back(iWireHit);
164 B2DEBUG(29,
"CDCWireHits in the event: " << nHits);
// NOTE(review): the condition guarding this B2ERROR is not visible here.
166 B2ERROR(
"Different number of hits: something went wrong...");
// Input and output tensors keyed by the tensor names loaded in beginRun().
// NOTE(review): the call these two argument lists belong to is not visible
// (presumably the m_session inference call -- confirm).
170 {{m_inputTFeaturesName, input_t}},
171 {{m_outputTBetaName, beta_t}, {m_outputTCoordinatesName, coord_t}, {m_outputTMomentumName, momentum_t}, {m_outputTVertexName, vertex_t}, {m_outputTChargeName, charge_t}}
// Hit indices sorted by descending beta value.
175 std::vector<unsigned int> betaIndices(nHits);
176 std::iota(betaIndices.begin(), betaIndices.end(), 0);
177 std::sort(betaIndices.begin(), betaIndices.end(),
178 [&](
unsigned int i1,
unsigned int i2) { return beta_t->at({i1, 0}) > beta_t->
at({i2, 0}); });
// Per-hit flag: 1 if the hit's beta exceeds m_tBeta, 0 otherwise.
181 std::vector<uint8_t> selectedBetas(nHits);
182 for (
unsigned int i = 0; i < nHits; ++i)
183 selectedBetas[i] =
static_cast<uint8_t
>(beta_t->
at({i, 0}) > m_tBeta);
// Condensation points found so far; distances are compared squared.
187 std::vector<unsigned int> conPointIndices;
188 const float thresholdSquared = m_tDistance * m_tDistance;
// For hit i, accumulates in d the squared distance (coordinates 0..2 of
// coord_t) to each existing condensation point and compares it with
// thresholdSquared.
// NOTE(review): the declaration of d and the return statements are not
// visible (presumably returns false as soon as d <= thresholdSquared).
191 auto isOutOfRadius = [&](
unsigned int i) {
192 for (
unsigned int iConPoint : conPointIndices) {
194 d += (coord_t->
at({i, 0}) - coord_t->
at({iConPoint, 0})) * (coord_t->
at({i, 0}) - coord_t->
at({iConPoint, 0}));
195 d += (coord_t->
at({i, 1}) - coord_t->
at({iConPoint, 1})) * (coord_t->
at({i, 1}) - coord_t->
at({iConPoint, 1}));
196 d += (coord_t->
at({i, 2}) - coord_t->
at({iConPoint, 2})) * (coord_t->
at({i, 2}) - coord_t->
at({iConPoint, 2}));
197 if (d <= thresholdSquared)
// Walk the hits from highest to lowest beta: the first candidate, and any
// later candidate for which isOutOfRadius() is true, becomes a condensation
// point.
// NOTE(review): the body of the `!selectedBetas[iBeta]` check is not visible
// (presumably `continue`).
203 for (
unsigned int i = 0; i < betaIndices.size(); ++i) {
204 unsigned int iBeta = betaIndices[i];
205 if (!selectedBetas[iBeta])
207 if (conPointIndices.empty() or isOutOfRadius(iBeta))
208 conPointIndices.push_back(iBeta);
210 selectedBetas[iBeta] = 0;
212 B2DEBUG(29,
"Condensation points in the event: " << conPointIndices.size());
// Build one track candidate per condensation point.
215 for (
unsigned int iConPoint : conPointIndices) {
217 std::vector<GNNFinder::Utils::KDTHit> kdtHits;
218 kdtHits.reserve(nHits);
// Collect every hit whose coord_t distance to the condensation point is below
// m_maxRadius; each entry stores input features 0 and 1 and the tensor row.
221 for (
unsigned int i = 0; i < nHits; ++i) {
222 const double dx = coord_t->
at({iConPoint, 0}) - coord_t->
at({i, 0});
223 const double dy = coord_t->
at({iConPoint, 1}) - coord_t->
at({i, 1});
224 const double dz = coord_t->
at({iConPoint, 2}) - coord_t->
at({i, 2});
225 if (std::hypot(dx, dy, dz) < m_maxRadius) {
227 kdtHits.push_back({input_t->at({i, 0}), input_t->at({i, 1}),
static_cast<int>(i)});
// NOTE(review): body not visible (presumably skips candidates with fewer than
// m_minNumberHits hits).
232 if (kdtHits.size() < m_minNumberHits)
// Seed momentum taken directly from the momentum output at the condensation
// point.
236 const ROOT::Math::XYZVector momentum(
237 momentum_t->
at({iConPoint, 0}), momentum_t->
at({iConPoint, 1}), momentum_t->
at({iConPoint, 2}));
// Seed position: vertex output multiplied by m_spatialCoordinatesScale.
238 const ROOT::Math::XYZVector position(
239 vertex_t->
at({iConPoint, 0}) * m_spatialCoordinatesScale,
240 vertex_t->
at({iConPoint, 1}) * m_spatialCoordinatesScale,
241 vertex_t->
at({iConPoint, 2}) * m_spatialCoordinatesScale);
// Charge output >= 0.5 is mapped to +1, otherwise to -1.
242 const int charge = (charge_t->
at({iConPoint, 0}) >= 0.5) ? 1 : -1;
244 B2DEBUG(29,
LogVar(
"Condensation point", iConPoint) <<
LogVar(
"Attached hits", kdtHits.size())
245 <<
LogVar(
"Momentum", std::sqrt(momentum.Mag2())) <<
LogVar(
"Vertex", std::sqrt(position.Mag2()))
246 <<
LogVar(
"Charge", charge));
// Only the X components of position and momentum are tested for NaN.
248 if (std::isnan(position.X()) or std::isnan(momentum.X())) {
249 B2WARNING(
"Skipping track with NaN values.");
// Starting point in the xy plane for ordering the hits.
// NOTE(review): the remaining arguments of intersectCylinderXY are not visible.
256 auto [startingX, startingY] = GNNFinder::Utils::intersectCylinderXY(position, momentum,
// Hit order (tensor row indices) as returned by hitOrderer.orderHits for the
// starting point.
258 std::vector<int> sortedIndices =
259 hitOrderer.
orderHits(startingX, startingY, std::move(kdtHits));
262 RecoTrack* cdcRecotrack = m_CDCRecoTracks.appendNew();
// 6x6 seed covariance with every element set to 1e-1.
// NOTE(review): off-diagonal elements are set to 1e-1 as well -- confirm this
// is intended.
267 auto seedCovariance = TMatrixDSym(6);
268 for (
int j = 0; j < 6; ++j)
269 for (
int k = 0; k < 6; ++k)
270 seedCovariance[j][k] = 1e-1;
// Add the hits to the track in sorted order; tensor rows are mapped back to
// wire hits via tensorIndexToHitIndex, and iRecoTrackHit is passed as the
// sorting parameter.
// NOTE(review): the update of iRecoTrackHit and the rest of the function are
// not visible in this listing.
274 int iRecoTrackHit = 0;
275 for (
int tensorIndex : sortedIndices) {
276 cdcRecotrack->
addCDCHit(wireHitVector[tensorIndexToHitIndex[tensorIndex]].getHit(), iRecoTrackHit);
DataType y() const
access variable Y (= .at(1) without boundary check)
DataType x() const
access variable X (= .at(0) without boundary check)
Class containing the result of the unpacker in raw data and the result of the digitizer in simulation.
unsigned short getICLayer() const
Getter for iCLayer (0-55).
unsigned short getIWire() const
Getter for iWire.
short getTDCCount() const
Getter for TDC count.
unsigned short getADCCount() const
Getter for integrated charge.
unsigned short getISuperLayer() const
Getter for iSuperLayer.
unsigned short getILayer() const
Getter for iLayer.
The Class for CDC Geometry Parameters.
const B2Vector3D wireForwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the forward position of the input sense wire.
const B2Vector3D wireBackwardPosition(uint layerId, int cellId, EWirePosition set=c_Base) const
Returns the backward position of the input sense wire.
static CDCGeometryPar & Instance(const CDCGeometry *=nullptr)
Static method to get a reference to the CDCGeometryPar instance.
Base class for DBObjPtr and DBArray for easier common treatment.
Class for accessing objects in the database.
@ c_RawFile
Just a plain old file, we don't look at it just provide the filename.
Sorts CDC hits spatially using KD-tree nearest neighbor traversal.
static std::vector< int > orderHits(const double startingX, const double startingY, std::vector< KDTHit > kdtHits)
Sort hits spatially based on proximity to a starting position.
Represents an input or output tensor for an ONNX model.
auto & at(size_t index)
Accesses the element at the specified flat index.
This is the Reconstruction Event-Data Model Track.
void setChargeSeed(const short int chargeSeed)
Set the charge seed stored in the reco track. ATTENTION: This is not the fitted charge.
bool addCDCHit(const UsedCDCHit *cdcHit, const unsigned int sortingParameter, RightLeftInformation rightLeftInformation=RightLeftInformation::c_undefinedRightLeftInformation, OriginTrackFinder foundByTrackFinder=OriginTrackFinder::c_undefinedTrackFinder)
Adds a cdc hit with the given information to the reco track.
void setPositionAndMomentum(const ROOT::Math::XYZVector &positionSeed, const ROOT::Math::XYZVector &momentumSeed)
Set the position and momentum seed of the reco track. ATTENTION: This is not the fitted position or momentum.
void setSeedCovariance(const TMatrixDSym &seedCovariance)
Set the covariance of the seed. ATTENTION: This is not the fitted covariance.
Class to store variables with their name which were sent to the logging service.
static auto make_shared(std::vector< int64_t > shape)
Convenience method to create a shared pointer to a Tensor from shape.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
B2Vector3< double > B2Vector3D
typedef for common usage with double
DataType at(unsigned i) const
safe member access (with boundary check!)
Abstract base class for different kinds of events.