Belle II Software development
CDCTriggerRecoMatcherModule.h
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/
8#pragma once
9
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <Eigen/Dense>

#include "framework/core/Module.h"
#include "framework/datastore/StoreArray.h"
#include "trg/cdc/dataobjects/CDCTriggerSegmentHit.h"
#include "tracking/dataobjects/RecoTrack.h"
#include "cdc/dataobjects/CDCHit.h"
#include "trg/cdc/dataobjects/CDCTriggerTrack.h"
#include "trg/cdc/dataobjects/CDCTrigger3DHTrack.h"
21
22namespace Belle2 {
27
28 namespace {
29 template <typename Iter>
30 struct IteratorRange {
31 Iter first, last;
32 IteratorRange(Iter f, Iter l) : first(f), last(l) {}
33 Iter begin() const { return first; }
34 Iter end() const { return last; }
35 };
36
37 template <typename Iter>
38 inline IteratorRange<Iter> as_range(std::pair<Iter, Iter> const& p)
39 {
40 return IteratorRange<Iter>(p.first, p.second);
41 }
42
43 typedef int HitId;
44 typedef int TrackId;
45 typedef float Purity;
46 typedef float Efficiency;
47 }
48
49 // Templated module
50 template <typename TrgTrackType>
51 class CDCTriggerRecoMatcherModuleT : public Module {
52 public:
53 CDCTriggerRecoMatcherModuleT() : Module()
54 {
55 setDescription("A module to match trigger tracks to RecoTracks.\n"
56 "It then makes relations from RecoTracks to trigger tracks.");
57
59
60 addParam("RecoTrackCollectionName", m_RecoTrackCollectionName,
61 "Name of the RecoTrack StoreArray to be matched.",
62 std::string("RecoTracks"));
63 addParam("TrgTrackCollectionName", m_TrgTrackCollectionName,
64 "Name of the trigger track StoreArray to be matched.",
65 std::string("TRGCDC2DFinderTracks"));
66 addParam("hitCollectionName", m_hitCollectionName,
67 "Name of the StoreArray containing hits used for matching",
68 std::string(""));
69 addParam("axialOnly", m_axialOnly,
70 "Switch for 2D matching (axial only).",
71 false);
72 addParam("minPurity", m_minPurity,
73 "Minimum purity for matching.",
74 0.1);
75 addParam("minEfficiency", m_minEfficiency,
76 "Minimum efficiency for matching.",
77 0.1);
78 addParam("relateClonesAndMerged", m_relateClonesAndMerged,
79 "Switch for creating relations for clones and merged tracks.",
80 true);
81 addParam("relateHitsByID", m_relateHitsByID,
82 "Switch for creating hit relations based on wire ID.",
83 true);
84 }
85
86 virtual ~CDCTriggerRecoMatcherModuleT() {}
87
88 void initialize() override
89 {
90 m_segmentHits.isRequired(m_hitCollectionName);
91 m_trgTracks.isRequired(m_TrgTrackCollectionName);
92 m_recoTracks.isRequired(m_RecoTrackCollectionName);
93
94 m_trgTracks.requireRelationTo(m_segmentHits);
95 m_recoTracks.registerRelationTo(m_segmentHits);
96 m_recoTracks.registerRelationTo(m_trgTracks);
97 m_trgTracks.registerRelationTo(m_recoTracks);
98 }
99
100 void event() override
101 {
102 for (int ireco = 0; ireco < m_recoTracks.getEntries(); ++ireco) {
103 RecoTrack* recoTrack = m_recoTracks[ireco];
104 // Skip if relations already exist
105 if (recoTrack->getRelationsTo<CDCTriggerSegmentHit>(m_hitCollectionName).size() > 0) continue;
106
107 for (CDCHit* cdcHit : recoTrack->getCDCHitList()) {
108 if (m_relateHitsByID) {
109 for (CDCTriggerSegmentHit& tsHit : m_segmentHits) {
110 if (tsHit.getID() == cdcHit->getID()) {
111 recoTrack->addRelationTo(&tsHit);
112 }
113 }
114 } else {
115 // Look for relations between CDC hits and TS hits
116 auto relHits = cdcHit->template getRelationsFrom<CDCTriggerSegmentHit>(m_hitCollectionName);
117 for (size_t i = 0; i < relHits.size(); ++i) {
118 // Create relations only for priority hits (relation weight 2)
119 if (relHits.weight(i) > 1)
120 recoTrack->addRelationTo(relHits[i]);
121 }
122 }
123 }
124 }
125
126 // Early exit if no tracks are present
127 int nRecoTracks = m_recoTracks.getEntries();
128 int nTrgTracks = m_trgTracks.getEntries();
129 if (nRecoTracks == 0 || nTrgTracks == 0) return;
130
131 // Helper to build multimap of hitId -> trackId
132 auto buildHitToTrackMap = [&](auto & tracks) {
133 std::multimap<HitId, TrackId> result;
134 for (TrackId id = 0; id < tracks.getEntries(); ++id) {
135 auto relHits = tracks[id]->template getRelationsTo<CDCTriggerSegmentHit>(m_hitCollectionName);
136 for (auto& hit : relHits) {
137 result.insert(std::make_pair(hit.getArrayIndex(), id));
138 }
139 }
140 return result;
141 };
142
143 auto recoTrackIdByHitId = buildHitToTrackMap(m_recoTracks);
144 auto trgTrackIdByHitId = buildHitToTrackMap(m_trgTracks);
145
146 Eigen::MatrixXi confusionMatrix = Eigen::MatrixXi::Zero(nTrgTracks, nRecoTracks);
147 Eigen::RowVectorXi totalHitsByRecoTrackId = Eigen::RowVectorXi::Zero(nRecoTracks);
148 Eigen::VectorXi totalHitsByTrgTrackId = Eigen::VectorXi::Zero(nTrgTracks);
149
150 // Fill the confusion matrix (add a weight if a hit is shared)
151 for (HitId hitId = 0; hitId < m_segmentHits.getEntries(); ++hitId) {
152 if (m_axialOnly && m_segmentHits[hitId]->getISuperLayer() % 2) continue;
153
154 auto trgRange = trgTrackIdByHitId.equal_range(hitId);
155 auto recoRange = recoTrackIdByHitId.equal_range(hitId);
156
157 for (auto& [_, tId] : as_range(trgRange)) totalHitsByTrgTrackId(tId) += 1;
158 for (auto& [_, rId] : as_range(recoRange)) totalHitsByRecoTrackId(rId) += 1;
159
160 for (auto& [_, rId] : as_range(recoRange)) {
161 for (auto& [_, tId] : as_range(trgRange)) {
162 confusionMatrix(tId, rId) += 1;
163 }
164 }
165 }
166
167 // Helper to build map of bestRecoId -> purity
168 auto bestRecoForTrg = [&](TrackId trgId) {
169 Eigen::RowVectorXi row = confusionMatrix.row(trgId);
170 Eigen::RowVectorXi::Index bestRecoId;
171 int hits = row.maxCoeff(&bestRecoId);
172 Purity purity = Purity(hits) / totalHitsByTrgTrackId(trgId);
173 return std::pair<TrackId, Purity>(bestRecoId, purity);
174 };
175
176 // Helper to build map of bestTrgId -> eff
177 auto bestTrgForReco = [&](TrackId recoId) {
178 Eigen::VectorXi col = confusionMatrix.col(recoId);
179 Eigen::VectorXi::Index bestTrgId;
180 int hits = col.maxCoeff(&bestTrgId);
181 Efficiency eff = Efficiency(hits) / totalHitsByRecoTrackId(recoId);
182 return std::pair<TrackId, Efficiency>(bestTrgId, eff);
183 };
184
185 // Precompute relations
186 std::vector<std::pair<TrackId, Purity>> purestRecoByTrg(nTrgTracks);
187 for (TrackId t = 0; t < nTrgTracks; ++t) purestRecoByTrg[t] = bestRecoForTrg(t);
188
189 std::vector<std::pair<TrackId, Efficiency>> mostEffTrgByReco(nRecoTracks);
190 for (TrackId r = 0; r < nRecoTracks; ++r) mostEffTrgByReco[r] = bestTrgForReco(r);
191
192 // Classification of trigger tracks
193 for (TrackId t = 0; t < nTrgTracks; ++t) {
194 auto [r, purity] = purestRecoByTrg[t];
195 TrgTrackType* trg = m_trgTracks[t];
196
197 if (purity < m_minPurity) continue; // Ghost track
198
199 auto [bestTrg, eff] = mostEffTrgByReco[r];
200 RecoTrack* reco = m_recoTracks[r];
201
202 if (t == bestTrg) {
203 trg->addRelationTo(reco, purity); // Matched track
204 } else if (m_relateClonesAndMerged) {
205 trg->addRelationTo(reco, -purity); // Clone track
206 }
207 }
208
209 // Classification of reco tracks
210 for (TrackId r = 0; r < nRecoTracks; ++r) {
211 auto [t, eff] = mostEffTrgByReco[r];
212 RecoTrack* reco = m_recoTracks[r];
213
214 if (eff < m_minEfficiency) continue; // Missing track
215
216 auto [bestReco, purity] = purestRecoByTrg[t];
217 TrgTrackType* trg = m_trgTracks[t];
218
219 if (r == bestReco) {
220 reco->addRelationTo(trg, eff); // Matched track
221 } else if (m_relateClonesAndMerged) {
222 reco->addRelationTo(trg, -eff); // Merged track
223 }
224 }
225 }
226
227 private:
228 std::string m_RecoTrackCollectionName;
229 std::string m_TrgTrackCollectionName;
230 std::string m_hitCollectionName;
231 bool m_axialOnly{false};
232 double m_minPurity{0.1};
233 double m_minEfficiency{0.1};
234 bool m_relateClonesAndMerged{true};
235 bool m_relateHitsByID{true};
236
237 StoreArray<CDCTriggerSegmentHit> m_segmentHits;
238 StoreArray<TrgTrackType> m_trgTracks;
239 StoreArray<RecoTrack> m_recoTracks;
240 };
241
242 // Aliases for the the two modules
243 class CDCTriggerRecoMatcherModule : public CDCTriggerRecoMatcherModuleT<CDCTriggerTrack> {
244 public:
245 CDCTriggerRecoMatcherModule() : CDCTriggerRecoMatcherModuleT<CDCTriggerTrack>() {}
246 };
247
248 class CDCTrigger3DHRecoMatcherModule : public CDCTriggerRecoMatcherModuleT<CDCTrigger3DHTrack> {
249 public:
250 CDCTrigger3DHRecoMatcherModule() : CDCTriggerRecoMatcherModuleT<CDCTrigger3DHTrack>() {}
251 };
252
253}
Class containing the result of the unpacker in raw data and the result of the digitizer in simulation...
Definition CDCHit.h:40
void initialize() override
Initialize the Module.
void event() override
This method is the core of the module.
Combination of several CDCHits to a track segment hit for the trigger.
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
This is the Reconstruction Event-Data Model Track.
Definition RecoTrack.h:79
std::vector< Belle2::RecoTrack::UsedCDCHit * > getCDCHitList() const
Return an unsorted list of cdc hits.
Definition RecoTrack.h:455
void addRelationTo(const RelationsInterface< BASE > *object, float weight=1.0, const std::string &namedRelation="") const
Add a relation from this object to another object (with caching).
RelationVector< TO > getRelationsTo(const std::string &name="", const std::string &namedRelation="") const
Get the relations that point from this object to another store array.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
Abstract base class for different kinds of events.