10 #include <tracking/modules/mcMatcher/MCRecoTracksMatcherModule.h>
13 #include <framework/datastore/StoreArray.h>
15 #include <framework/gearbox/Const.h>
17 #include <tracking/dataobjects/RecoTrack.h>
18 #include <mdst/dataobjects/MCParticle.h>
21 #include <pxd/dataobjects/PXDCluster.h>
22 #include <svd/dataobjects/SVDCluster.h>
23 #include <cdc/dataobjects/CDCHit.h>
27 #include <Eigen/Dense>
35 #ifdef __INTEL_COMPILER
36 #pragma warning disable 177
// Thin wrapper around a pair of iterators; it derives from std::pair<Iter, Iter>
// so that an existing iterator pair can be copied straight into it.
// NOTE(review): only part of this struct is visible here - the template header
// and the begin()/end() accessors used below are presumably in lines not shown.
47 struct iter_pair_range : std::pair<Iter, Iter> {
// Explicit constructor: copies the given iterator pair into the std::pair base.
48 explicit iter_pair_range(std::pair<Iter, Iter>
const& x)
49 : std::pair<Iter, Iter>(x)
// True when begin() and end() coincide, i.e. the range holds no elements.
65 return begin() == end();
// Helper that wraps an iterator pair into an iter_pair_range<Iter>.
// @param x pair of iterators to wrap
// @return an iter_pair_range<Iter> constructed from x
70 inline iter_pair_range<Iter> as_range(std::pair<Iter, Iter>
const& x)
72 return iter_pair_range<Iter>(x);
// Integer type used to identify a reco track.
77 using RecoTrackId = int;
// Track id that can be used wherever a plain int id is expected.
// NOTE(review): the data members are declared in lines not shown here; the
// conversion below reads a member named `id`.
78 struct WeightedRecoTrackId {
// Implicit conversion to int yielding the stored id.
79 operator int()
const {
return id; }
// Key that identifies a hit by the pair (DetId, HitId).
84 using DetHitIdPair = std::pair<DetId, HitId>;
// Comparison functor between a bare (DetId, HitId) key and a
// ((DetId, HitId), WeightedRecoTrackId) entry. Only the key part of the entry
// takes part in the comparison; both argument orders are provided.
86 struct CompDetHitIdPair {
// key < entry : compares the bare key against the entry's key (entry.first).
87 bool operator()(
const std::pair<DetId, HitId>& lhs,
88 const std::pair<std::pair<DetId, HitId>, WeightedRecoTrackId>& rhs)
90 return lhs < rhs.first;
// entry < key : compares the entry's key (entry.first) against the bare key.
93 bool operator()(
const std::pair<std::pair<DetId, HitId>, WeightedRecoTrackId>& lhs,
94 const std::pair<DetId, HitId>& rhs)
96 return lhs.first < rhs;
// Fills recoTrackID_by_hitID with one ((detector, hit index), weighted track id)
// entry per hit of every track in storedRecoTracks.
// @param recoTrackID_by_hitID container receiving the entries (inserted with a hint)
// NOTE(review): the remaining parameter(s) and the per-detector hit loops are in
// lines not shown here.
101 template <
class AMapOrSet>
102 void fillIDsFromStoreArray(AMapOrSet& recoTrackID_by_hitID,
// Starts at -1; presumably incremented at the top of each loop iteration in a
// line not visible here - TODO confirm.
105 RecoTrackId recoTrackId = -1;
106 for (
const RecoTrack& recoTrack : storedRecoTracks) {
// Entries of the current track are collected first so that their weights can
// still be adjusted before they are inserted into the output container.
108 std::vector<std::pair<DetHitIdPair, WeightedRecoTrackId> > hitIDsInTrack;
109 double totalWeight = 0;
// Local shorthand for the enum value the hit origin is compared against.
111 const OriginTrackFinder c_MCTrackFinderAuxiliaryHit =
112 OriginTrackFinder::c_MCTrackFinderAuxiliaryHit;
// CDC hits: weight 0 if the hit origin is c_MCTrackFinderAuxiliaryHit, else 1.
115 OriginTrackFinder originFinder = recoTrack.getFoundByTrackFinder(cdcHit);
116 double weight = originFinder == c_MCTrackFinderAuxiliaryHit ? 0 : 1;
117 totalWeight += weight;
118 hitIDsInTrack.push_back({{Const::CDC, cdcHit->getArrayIndex()}, {recoTrackId, weight}});
// SVD hits: same weighting rule as for the CDC hits.
121 OriginTrackFinder originFinder = recoTrack.getFoundByTrackFinder(svdHit);
122 double weight = originFinder == c_MCTrackFinderAuxiliaryHit ? 0 : 1;
123 totalWeight += weight;
124 hitIDsInTrack.push_back({{Const::SVD, svdHit->getArrayIndex()}, {recoTrackId, weight}});
// PXD hits: same weighting rule as for the CDC hits.
127 OriginTrackFinder originFinder = recoTrack.getFoundByTrackFinder(pxdHit);
128 double weight = originFinder == c_MCTrackFinderAuxiliaryHit ? 0 : 1;
129 totalWeight += weight;
130 hitIDsInTrack.push_back({{Const::PXD, pxdHit->getArrayIndex()}, {recoTrackId, weight}});
// If every hit of the track got weight 0, reset all weights of the track to 1
// so that the track does not end up with a vanishing total weight.
134 if (totalWeight == 0) {
135 for (std::pair<DetHitIdPair, WeightedRecoTrackId>& recoTrack_for_hitID : hitIDsInTrack) {
136 recoTrack_for_hitID.second.weight = 1;
// Insert the collected entries, feeding the iterator returned by each insert
// back in as the position hint for the next one.
141 typename AMapOrSet::iterator itInsertHint = recoTrackID_by_hitID.end();
142 for (std::pair<DetHitIdPair, WeightedRecoTrackId>& recoTrack_for_hitID : hitIDsInTrack) {
143 itInsertHint = recoTrackID_by_hitID.insert(itInsertHint, recoTrack_for_hitID);
// Module description and parameter registration.
// NOTE(review): the enclosing function header, the member variables bound to the
// parameters and most default values are in lines not shown here.
152 setDescription(
"This module compares reconstructed tracks generated by some pattern recognition "
153 "algorithm for PXD, SVD and/or CDC to ideal Monte Carlo tracks and performs a "
154 "matching from the former to the underlying MCParticles.");
// Name of the store array holding the pattern recognition tracks.
// NOTE(review): the help text below has typos ("as generate a patter
// recognition"); left untouched here because it is a runtime string.
160 addParam(
"prRecoTracksStoreArrayName",
162 "Name of the collection containing the tracks as generate a patter recognition algorithm to be evaluated ",
// Name of the store array holding the Monte Carlo reference tracks; the visible
// default value is "MCGFTrackCands".
165 addParam(
"mcRecoTracksStoreArrayName",
167 "Name of the collection containing the reference tracks as generate by a Monte-Carlo-Tracker (e.g. MCTrackFinder)",
168 std::string(
"MCGFTrackCands"));
// Help texts of the per-detector switches (parameter names are not visible here).
173 "Set true if PXDHits or PXDClusters should be used in the matching in case they are present",
178 "Set true if SVDHits or SVDClusters should be used in the matching in case they are present",
183 "Set true if CDCHits should be used in the matching in case they are present",
188 "Set true if only the axial CDCHits should be used",
// Help text of the minimal purity threshold.
193 "Minimal purity of a PRTrack to be considered matchable to a MCTrack. "
194 "This number encodes how many correct hits are minimally need to compensate for a false hits. "
195 "The default 2.0 / 3.0 suggests that for each background hit can be compensated by two correct hits.",
// Help text of the minimal efficiency threshold.
200 "Minimal efficiency of a MCTrack to be considered matchable to a PRTrack. "
201 "This number encodes which fraction of the true hits must at least be in the reconstructed track. "
202 "The default 0.05 suggests that at least 5% of the true hits should have been picked up.",
// Data store registration.
// NOTE(review): the enclosing function header and the declarations of the store
// arrays are in lines not shown here; this looks like the module's
// initialization step - TODO confirm.
// The MCParticles array is first queried as optional ...
211 if (storeMCParticles.isOptional()) {
// ... and then the track arrays and the MCParticles are declared as required.
219 storePRRecoTracks.isRequired();
220 storeMCRecoTracks.isRequired();
221 storeMCParticles.isRequired();
// The hit/cluster arrays of the individual detectors are only optional inputs.
244 pxdClusters.isOptional();
250 svdClusters.isOptional();
256 cdcHits.isOptional();
// Start of the per-event matching.
// NOTE(review): the enclosing function header and the guarding conditions are in
// lines not shown here.
// Message logged when the matcher is skipped for lack of MC particles.
265 B2DEBUG(100,
"Skipping MC Track Matcher as there are no MC Particles registered in the DataStore.");
269 B2DEBUG(100,
"########## MCRecoTracksMatcherModule ############");
// Number of reference (MC) tracks and of pattern recognition (PR) tracks.
276 int nMCRecoTracks = mcRecoTracks.
getEntries();
277 int nPRRecoTracks = prRecoTracks.
getEntries();
279 B2DEBUG(100,
"Number patter recognition tracks is " << nPRRecoTracks);
280 B2DEBUG(100,
"Number Monte-Carlo tracks is " << nMCRecoTracks);
// Special case: at least one of the two track collections is empty. The body of
// this branch is not visible here.
282 if (not nMCRecoTracks or not nPRRecoTracks) {
// Lookup from (detector, hit index) to the MC track ids using that hit. A
// multimap is used, so one hit may be associated with several MC tracks.
290 std::multimap<DetHitIdPair, WeightedRecoTrackId > mcId_by_hitId;
291 fillIDsFromStoreArray(mcId_by_hitId, mcRecoTracks);
// Same association for the PR tracks, stored as an ordered set of
// ((detector, hit index), track id) pairs.
298 std::set<std::pair<DetHitIdPair, WeightedRecoTrackId>> prId_by_hitId;
299 fillIDsFromStoreArray(prId_by_hitId, prRecoTracks);
// Number of hits to iterate over, per detector.
304 std::map<DetId, int> nHits_by_detId;
// NOTE(review): the computation of nHits for PXD and SVD is in lines not shown.
309 if (pxdClusters.isOptional()) {
311 nHits_by_detId[Const::PXD] = nHits;
318 if (svdClusters.isOptional()) {
320 nHits_by_detId[Const::SVD] = nHits;
327 if (cdcHits.isOptional()) {
328 nHits_by_detId[Const::CDC] = cdcHits.
getEntries();
// Confusion matrices with one row per PR track and one column per MC track plus
// one extra column; all entries start at zero.
336 Eigen::MatrixXd confusionMatrix = Eigen::MatrixXd::Zero(nPRRecoTracks, nMCRecoTracks + 1);
337 Eigen::MatrixXd weightedConfusionMatrix = Eigen::MatrixXd::Zero(nPRRecoTracks, nMCRecoTracks + 1);
// Per-MC-track totals (row vectors, including the extra column).
341 Eigen::RowVectorXd totalNDF_by_mcId = Eigen::RowVectorXd::Zero(nMCRecoTracks + 1);
342 Eigen::RowVectorXd totalWeight_by_mcId = Eigen::RowVectorXd::Zero(nMCRecoTracks + 1);
// Per-PR-track totals (column vector).
346 Eigen::VectorXd totalNDF_by_prId = Eigen::VectorXd::Zero(nPRRecoTracks);
// Index of the extra column; it is used below for hits without any MC track.
349 const int mcBkgId = nMCRecoTracks;
// Loop over all detectors and, within each, over all hit indices.
353 for (
const std::pair<const DetId, NDF>& detId_nHits_pair : nHits_by_detId) {
355 DetId detId = detId_nHits_pair.first;
356 int nHits = detId_nHits_pair.second;
359 for (HitId hitId = 0; hitId < nHits; ++hitId) {
360 DetHitIdPair detId_hitId_pair(detId, hitId);
// NOTE(review): the condition around this lookup and its use are in lines not
// shown here (possibly the axial-only CDC selection - TODO confirm).
364 const CDCHit* cdcHit = cdcHits[hitId];
// All MC tracks that contain this hit.
372 const auto mcIds_for_detId_hitId_pair =
373 as_range(mcId_by_hitId.equal_range(detId_hitId_pair));
// All PR tracks that contain this hit; CompDetHitIdPair compares on the
// (detector, hit index) key only.
376 const auto prIds_for_detId_hitId_pair =
377 as_range(std::equal_range(prId_by_hitId.begin(),
380 CompDetHitIdPair()));
// Accumulate the MC totals. A hit without MC track is booked on mcBkgId.
// NOTE(review): ndfForOneHit and the mcWeight used in the background branch are
// defined in lines not shown here.
384 if (mcIds_for_detId_hitId_pair.empty()) {
387 RecoTrackId mcId = mcBkgId;
389 totalNDF_by_mcId(mcId) += ndfForOneHit;
390 totalWeight_by_mcId(mcId) += ndfForOneHit * mcWeight;
392 for (
const auto& detId_hitId_pair_and_mcId : mcIds_for_detId_hitId_pair) {
393 WeightedRecoTrackId mcId = detId_hitId_pair_and_mcId.second;
394 double mcWeight = mcId.weight;
395 totalNDF_by_mcId(mcId) += ndfForOneHit;
396 totalWeight_by_mcId(mcId) += ndfForOneHit * mcWeight;
// Accumulate the PR totals.
402 for (
const auto& detId_hitId_pair_and_prId : prIds_for_detId_hitId_pair) {
403 RecoTrackId prId = detId_hitId_pair_and_prId.second;
404 totalNDF_by_prId(prId) += ndfForOneHit;
// Fill the confusion matrices: every (PR track, MC track) combination sharing
// this hit gets the hit's ndf; without MC track the mcBkgId column is used.
408 for (
const auto& detId_hitId_pair_and_prId : prIds_for_detId_hitId_pair) {
409 RecoTrackId prId = detId_hitId_pair_and_prId.second;
410 if (mcIds_for_detId_hitId_pair.empty()) {
411 RecoTrackId mcId = mcBkgId;
413 confusionMatrix(prId, mcId) += ndfForOneHit;
414 weightedConfusionMatrix(prId, mcId) += ndfForOneHit * mcWeight;
416 for (
const auto& detId_hitId_pair_and_mcId : mcIds_for_detId_hitId_pair) {
417 WeightedRecoTrackId mcId = detId_hitId_pair_and_mcId.second;
418 double mcWeight = mcId.weight;
419 confusionMatrix(prId, mcId) += ndfForOneHit;
420 weightedConfusionMatrix(prId, mcId) += ndfForOneHit * mcWeight;
// Debug dump of the accumulated matrices and totals.
427 B2DEBUG(200,
"Confusion matrix of the event : " << std::endl << confusionMatrix);
428 B2DEBUG(200,
"Weighted confusion matrix of the event : " << std::endl << weightedConfusionMatrix);
430 B2DEBUG(200,
"totalNDF_by_mcId : " << std::endl << totalNDF_by_mcId);
431 B2DEBUG(200,
"totalWeight_by_mcId : " << std::endl << totalWeight_by_mcId);
433 B2DEBUG(200,
"totalNDF_by_prId : " << std::endl << totalNDF_by_prId);
// Purity: each column of the confusion matrix divided element-wise by the PR
// totals, i.e. every row prId is normalized by totalNDF_by_prId(prId).
// Efficiency: each row divided element-wise by the MC totals, i.e. every column
// mcId is normalized by totalNDF_by_mcId(mcId) (totalWeight_by_mcId(mcId) for
// the weighted variant).
// NOTE(review): a total of zero yields a non-finite entry here; whether zero
// totals can occur at this point is not visible from this excerpt.
435 Eigen::MatrixXd purityMatrix = confusionMatrix.array().colwise() / totalNDF_by_prId.array();
436 Eigen::MatrixXd efficiencyMatrix = confusionMatrix.array().rowwise() / totalNDF_by_mcId.array();
437 Eigen::MatrixXd weightedEfficiencyMatrix = weightedConfusionMatrix.array().rowwise() / totalWeight_by_mcId.array();
439 B2DEBUG(100,
"Purities");
440 B2DEBUG(100, purityMatrix);
442 B2DEBUG(100,
"Efficiencies");
443 B2DEBUG(100, efficiencyMatrix);
445 B2DEBUG(100,
"Weighted efficiencies");
446 B2DEBUG(100, weightedEfficiencyMatrix);
// NOTE(review): both aliases are float while the matrices hold double, so the
// values read from the matrices below are narrowed on assignment.
450 using Efficiency = float;
451 using Purity = float;
// For every MC track: the PR track with the highest weighted efficiency,
// together with its weighted and unweighted efficiency.
453 struct MostWeightEfficientPRId {
455 Efficiency weightedEfficiency;
456 Efficiency efficiency;
458 std::vector<MostWeightEfficientPRId> mostWeightEfficientPRId_by_mcId(nMCRecoTracks);
459 for (RecoTrackId mcId = 0; mcId < nMCRecoTracks; ++mcId) {
460 Eigen::VectorXd efficiencyCol = efficiencyMatrix.col(mcId);
461 Eigen::VectorXd weightedEfficiencyCol = weightedEfficiencyMatrix.col(mcId);
// PR track 0 is the initial candidate; this requires at least one PR track.
463 RecoTrackId bestPrId = 0;
464 Efficiency bestWeightedEfficiency = weightedEfficiencyCol(0);
465 Efficiency bestEfficiency = efficiencyCol(0);
466 Purity bestPurity = purityMatrix.row(0)(mcId);
// NOTE(review): the condition guarding this reset is in a line not shown here.
470 bestWeightedEfficiency = 0;
// Compare the remaining PR tracks against the current best candidate.
474 for (RecoTrackId prId = 1; prId < nPRRecoTracks; ++prId) {
475 Eigen::RowVectorXd purityRow = purityMatrix.row(prId);
477 Efficiency currentWeightedEfficiency = weightedEfficiencyCol(prId);
478 Efficiency currentEfficiency = efficiencyCol(prId);
479 Purity currentPurity = purityRow(mcId);
// NOTE(review): the condition guarding this reset is in a line not shown here.
483 currentWeightedEfficiency = 0;
// Lexicographic comparison: weighted efficiency first, then efficiency, then
// purity as tie breakers.
486 if (std::tie(currentWeightedEfficiency, currentEfficiency, currentPurity) >
487 std::tie(bestWeightedEfficiency, bestEfficiency, bestPurity)) {
// NOTE(review): the update of bestPrId is presumably in a line not shown here.
489 bestEfficiency = currentEfficiency;
490 bestWeightedEfficiency = currentWeightedEfficiency;
491 bestPurity = currentPurity;
// Re-read the efficiencies of the winner from the matrices, so the stored values
// are the matrix entries rather than the comparison values used in the loop.
495 bestWeightedEfficiency = weightedEfficiencyCol(bestPrId);
496 bestEfficiency = efficiencyCol(bestPrId);
497 mostWeightEfficientPRId_by_mcId[mcId] = {bestPrId, bestWeightedEfficiency, bestEfficiency};
// For every PR track: the MC track (column) with the highest purity.
502 struct MostPureMCId {
507 std::vector<MostPureMCId> mostPureMCId_by_prId(nPRRecoTracks);
508 for (
int prId = 0; prId < nPRRecoTracks; ++prId) {
509 Eigen::RowVectorXd purityRow = purityMatrix.row(prId);
// maxCoeff returns the largest purity of the row and writes the index of that
// coefficient into mcId (declared in a line not shown here).
512 Purity highestPurity = purityRow.maxCoeff(&mcId);
514 mostPureMCId_by_prId[prId] = {mcId, highestPurity};
// Debug dump: best PR track per MC track.
// The running index starts at -1; presumably incremented inside the loop in a
// line not shown here - TODO confirm.
520 RecoTrackId mcId = -1;
521 B2DEBUG(200,
"MCTrack to highest weighted efficiency PRTrack relation");
522 for (
const auto& mostWeightEfficientPRId_for_mcId : mostWeightEfficientPRId_by_mcId) {
524 const Efficiency& weightedEfficiency = mostWeightEfficientPRId_for_mcId.weightedEfficiency;
525 const RecoTrackId& prId = mostWeightEfficientPRId_for_mcId.id;
527 "mcId : " << mcId <<
" -> prId : " << prId <<
" with weighted efficiency "
528 << weightedEfficiency);
// Debug dump: most pure MC track per PR track.
// The running index starts at -1; presumably incremented inside the loop in a
// line not shown here - TODO confirm.
535 RecoTrackId prId = -1;
536 B2DEBUG(200,
"PRTrack to highest purity MCTrack relation");
537 for (
const auto& mostPureMCId_for_prId : mostPureMCId_by_prId) {
539 const RecoTrackId& mcId = mostPureMCId_for_prId.id;
540 const Purity& purity = mostPureMCId_for_prId.purity;
541 B2DEBUG(200,
"prId : " << prId <<
" -> mcId : " << mcId <<
" with purity " << purity);
// Counters for the classification summary logged at the end of this section;
// all start at zero.
546 int nMatched{}, nBackground{}, nClones{}, nGhost{};
// Classify every PR track based on its most pure MC track.
// NOTE(review): the conditions selecting the individual branches and the code
// storing the result on the track are mostly in lines not shown here.
552 for (RecoTrackId prId = 0; prId < nPRRecoTracks; ++prId) {
553 RecoTrack* prRecoTrack = prRecoTracks[prId];
555 const MostPureMCId& mostPureMCId = mostPureMCId_by_prId[prId];
557 const RecoTrackId& mcId = mostPureMCId.id;
558 const Purity& purity = mostPureMCId.purity;
// Ghost: purity too low.
564 B2DEBUG(100,
"Stored PRTrack " << prId <<
" as ghost because of too low purity");
// Background: the most pure "MC track" is the extra background column.
// NOTE(review): the message below mentions "too low purity" although this branch
// tests for the background id - wording may be misleading, TODO confirm.
569 if (mcId == mcBkgId) {
572 B2DEBUG(100,
"Stored PRTrack " << prId <<
" as background because of too low purity.");
// From here on mcId refers to a real MC track.
582 RecoTrack* mcRecoTrack = mcRecoTracks[mcId];
584 B2ASSERT(
"No relation from MCRecoTrack to MCParticle.", mcParticle);
586 const MostWeightEfficientPRId& mostWeightEfficientPRId_for_mcId =
587 mostWeightEfficientPRId_by_mcId[mcId];
589 const RecoTrackId& mostWeightEfficientPRId = mostWeightEfficientPRId_for_mcId.id;
590 const Efficiency& weightedEfficiency = mostWeightEfficientPRId_for_mcId.weightedEfficiency;
// Matched: this PR track is also the most weight-efficient one for the MC track.
597 if (prId == mostWeightEfficientPRId) {
605 B2DEBUG(100,
"Stored PRTrack " << prId <<
" as matched.");
606 B2DEBUG(100,
"MC Match prId " << prId <<
" to mcPartId " << mcParticle->
getArrayIndex());
607 B2DEBUG(100,
"Purity rel: prId " << prId <<
" -> mcId " << mcId <<
" : " << purity);
// Ghost: efficiency too low.
618 B2DEBUG(100,
"Stored PRTrack " << prId <<
" as ghost because of too low efficiency.");
// Clone: logged with a negated purity.
632 B2DEBUG(100,
"Stored PRTrack " << prId <<
" as clone.");
633 B2DEBUG(100,
"MC Match prId " << prId <<
" to mcPartId " << mcParticle->
getArrayIndex());
634 B2DEBUG(100,
"Purity rel: prId " << prId <<
" -> mcId " << mcId <<
" : " << -purity);
// Summary of the PR track classification.
638 B2DEBUG(100,
"Number of matches " << nMatched);
639 B2DEBUG(100,
"Number of clones " << nClones);
640 B2DEBUG(100,
"Number of bkg " << nBackground);
641 B2DEBUG(100,
"Number of ghost " << nGhost);
// Classify every MC track based on its most weight-efficient PR track.
// NOTE(review): several conditions and the code storing the result on the track
// are in lines not shown here.
645 for (RecoTrackId mcId = 0; mcId < nMCRecoTracks; ++mcId) {
646 RecoTrack* mcRecoTrack = mcRecoTracks[mcId];
649 const MostWeightEfficientPRId& mostWeightEfficiencyPRId = mostWeightEfficientPRId_by_mcId[mcId];
651 const RecoTrackId& prId = mostWeightEfficiencyPRId.id;
652 const Efficiency& weightedEfficiency = mostWeightEfficiencyPRId.weightedEfficiency;
// Guard the index before it is used to access prRecoTracks.
655 B2ASSERT(
"Index of pattern recognition tracks out of range.", prId < nPRRecoTracks and prId >= 0);
657 RecoTrack* prRecoTrack = prRecoTracks[prId];
659 const MostPureMCId& mostPureMCId_for_prId = mostPureMCId_by_prId[prId];
660 const RecoTrackId& mostPureMCId = mostPureMCId_for_prId.id;
// First case: the best PR track in turn has this MC track as its most pure one
// (the remaining part of the condition is not visible here).
663 if (mcId == mostPureMCId and
669 B2DEBUG(100,
"Efficiency rel: mcId " << mcId <<
" -> prId " << prId <<
" : " << weightedEfficiency);
// Merged case: the relation is stored with a negated weighted efficiency.
677 bool isMergedMCRecoTrack =
682 if (isMergedMCRecoTrack) {
683 mcRecoTrack->
addRelationTo(prRecoTrack, -weightedEfficiency);
685 B2DEBUG(100,
"Efficiency rel: mcId " << mcId <<
" -> prId " << prId <<
" : " << -weightedEfficiency);
// Missing case: no relation is created; diagnostics are logged instead.
692 B2DEBUG(100,
"mcId " << mcId <<
" is missing. No relation created.");
// NOTE(review): the result of getRelatedTo<MCParticle>() is dereferenced without
// a visible null check - TODO confirm it cannot be null here.
693 B2DEBUG(100,
"is Primary" << mcRecoTracks[mcId]->getRelatedTo<MCParticle>()->isPrimaryParticle());
694 B2DEBUG(100,
"best prId " << prId <<
" with purity " << mostPureMCId_for_prId.purity <<
" -> " << mostPureMCId);
695 B2DEBUG(100,
"MC Total ndf " << totalNDF_by_mcId[mcId]);
696 B2DEBUG(100,
"MC Total weight" << totalWeight_by_mcId[mcId]);
697 B2DEBUG(100,
"MC Overlap ndf\n " << confusionMatrix.col(mcId).transpose());
698 B2DEBUG(100,
"MC Overlap weight\n " << weightedConfusionMatrix.col(mcId).transpose());
699 B2DEBUG(100,
"MC Efficiencies for the track\n" << efficiencyMatrix.col(mcId).transpose());
700 B2DEBUG(100,
"MC Weighted efficiencies for the track\n" << weightedEfficiencyMatrix.col(mcId).transpose());
703 B2DEBUG(100,
"########## End MCRecoTracksMatcherModule ############");