Belle II Software prerelease-11-00-00a
TrackQualityEstimatorMVAModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/trackQualityEstimator/TrackQualityEstimatorMVAModule.h>
10#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorTripletFit.h>
11#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorCircleFit.h>
12#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorRiemannHelixFit.h>
13#include <algorithm>
14
15using namespace Belle2;
16using namespace TrackingUtilities;
17
18
19REG_MODULE(TrackQualityEstimatorMVA);
20
// Body that sets the module description and registers the module parameters via addParam
// (presumably the module constructor; its signature line is not visible in this listing).
// NOTE(review): every addParam call below is shown only partially. The lines carrying the
// bound member variable and the default value are missing from this listing, so only the
// parameter name and its description string can be read here.
22{
23 //Set module properties
24 setDescription("The quality estimator module for a fully reconstructed track");
26
27 addParam("recoTracksStoreArrayName",
29 "Name of the recoTrack StoreArray.",
31
32 addParam("CDCRecoTracksStoreArrayName",
34 "Name of the CDC StoreArray.",
36
37 addParam("SVDRecoTracksStoreArrayName",
39 "Name of the SVD StoreArray.",
41
42 addParam("PXDRecoTracksStoreArrayName",
44 "Name of the PXD StoreArray.",
46
47 addParam("TracksStoreArrayName",
49 "Name of the fitted mdst Tracks StoreArray.",
51
52 addParam("WeightFileIdentifier",
54 "Identifier of weightfile in Database or local root/xml file.",
56
57 addParam("collectEventFeatures",
59 "Whether to use eventwise features.",
61
// NOTE(review): the parameter is named "VXDEstimationMethod" while its description text
// speaks of SVD; confirm that the naming difference is intentional.
62 addParam("VXDEstimationMethod",
64 "Identifier which estimation method to use for SVD. Valid identifiers are: [tripletFit, circleFit, helixFit]",
66}
67
// Setup body (signature not visible in this listing; presumably initialize()):
// constructs the feature extractors, handing each of them m_variableSet, selects the
// quality estimator named by m_SVDEstimationMethod, and finally creates and initializes
// the MVA expert from m_weightFileIdentifier and m_variableSet.
69{
71
// NOTE(review): the closing brace two lines below has no visible opening statement; the
// guard line is missing from this listing (presumably a check of m_collectEventFeatures,
// to be confirmed against the real file).
73 m_eventInfoExtractor = std::make_unique<EventInfoExtractor>(m_variableSet);
74 }
75 m_recoTrackExtractor = std::make_unique<RecoTrackExtractor>(m_variableSet);
76 m_subRecoTrackExtractor = std::make_unique<SubRecoTrackExtractor>(m_variableSet);
77 m_hitInfoExtractor = std::make_unique<HitInfoExtractor>(m_variableSet);
78
79 // create pointer to chosen estimator
80 if (m_SVDEstimationMethod == "tripletFit") {
81 m_estimator = std::make_unique<QualityEstimatorTripletFit>();
82 } else if (m_SVDEstimationMethod == "circleFit") {
83 m_estimator = std::make_unique<QualityEstimatorCircleFit>();
84 } else if (m_SVDEstimationMethod == "helixFit") {
85 m_estimator = std::make_unique<QualityEstimatorRiemannHelixFit>();
86 }
// m_estimator is assigned only for the three identifiers above; for any other value the
// assertion condition (m_estimator) is a null pointer.
87 B2ASSERT("QualityEstimator could not be initialized with method: " << m_SVDEstimationMethod, m_estimator);
88
// Two parallel groups of SVD features are appended to m_variableSet: one with prefix
// "SVD_" and one with prefix "SVDbefore_". The same m_SVDEstimationMethod is used for both.
// NOTE(review): the statement order here presumably fixes the order of entries in
// m_variableSet; confirm before reordering anything.
89 m_qeResultsExtractor = std::make_unique<QEResultsExtractor>(m_SVDEstimationMethod, m_variableSet, "SVD_");
90 m_variableSet.emplace_back("SVD_NSpacePoints", &m_nSpacePoints);
91 m_clusterInfoExtractor = std::make_unique<ClusterInfoExtractor>(m_variableSet, false, "SVD_");
92
93 m_qeResultsExtractorBefore = std::make_unique<QEResultsExtractor>(m_SVDEstimationMethod, m_variableSet, "SVDbefore_");
94 m_variableSet.emplace_back("SVDbefore_NSpacePoints", &m_nSpacePointsBefore);
95 m_clusterInfoExtractorBefore = std::make_unique<ClusterInfoExtractor>(m_variableSet, false, "SVDbefore_");
96
// The MVA expert is created last, after all variables above have been added to m_variableSet.
97 m_mvaExpert = std::make_unique<MVAExpert>(m_weightFileIdentifier, m_variableSet);
98 m_mvaExpert->initialize();
99}
100
105
// Per-track processing body (signature not visible in this listing; presumably event()).
// For every RecoTrack in m_recoTracks it:
//   1. looks up the related PXD RecoTrack,
//   2. identifies the CDC sub-track and one or two SVD sub-tracks by comparing hit lists,
//   3. fills all feature extractors (SVD part, SVD "before" part, event, track, sub-tracks, hits),
//   4. asks the MVA expert for a prediction and stores it as the track's quality indicator.
107{
108 for (RecoTrack& recoTrack : m_recoTracks) {
109 const RecoTrack* pxdRecoTrack = recoTrack.getRelatedTo<RecoTrack>(m_pxdRecoTracksStoreArrayName);
110
111 // Try to find all CDC tracks that are related to some hits in the RecoTrack.
112 std::vector<RecoTrack*> allCDCTracks;
// getCDCHitList() returns the vector by value; binding it to a const reference keeps
// that temporary alive for the rest of this loop iteration.
113 const auto& cdcHitList = recoTrack.getCDCHitList();
114 for (auto* cdcHit : cdcHitList) {
115 const RelationVector<RecoTrack>& relatedCDCTracks =
116 cdcHit->getRelationsWith<RecoTrack>(m_cdcRecoTracksStoreArrayName);
117 for (unsigned int index = 0; index < relatedCDCTracks.size(); ++index) {
118 RecoTrack* relatedCDCTrack = relatedCDCTracks[index];
// Linear search so that each related track is stored only once.
119 if (std::find(allCDCTracks.begin(), allCDCTracks.end(), relatedCDCTrack) == allCDCTracks.end()) {
120 allCDCTracks.push_back(relatedCDCTrack);
121 }
122 }
123 }
124 // The reconstructed track contains at most one CDC part.
125 // Try to match the hit list to find the right CDC track.
126 // If no matching CDC tracks are found, then cdcRecoTrackPtr will still be nullptr.
127 RecoTrack* cdcRecoTrackPtr = nullptr;
128 for (RecoTrack* foundCDCTrack : allCDCTracks) {
129 const auto& foundCDCTrackHitList = foundCDCTrack->getCDCHitList();
// A match requires equal length and element-wise equal hit pointers in the same order.
130 if (foundCDCTrackHitList.size() == cdcHitList.size() and
131 std::equal(foundCDCTrackHitList.begin(), foundCDCTrackHitList.end(), cdcHitList.begin())) {
132 cdcRecoTrackPtr = foundCDCTrack;
133 break;
134 }
135 }
136
137 // Try to find all SVD tracks that are related to some hits in the RecoTrack.
138 std::vector<RecoTrack*> allSVDTracks;
139 const auto& svdHitList = recoTrack.getSVDHitList();
140 for (auto* svdHit : svdHitList) {
141 const RelationVector<RecoTrack>& relatedSVDTracks = svdHit->getRelationsWith<RecoTrack>(m_svdRecoTracksStoreArrayName);
142 for (unsigned int index = 0; index < relatedSVDTracks.size(); ++index) {
143 RecoTrack* relatedSVDTrack = relatedSVDTracks[index];
144 if (std::find(allSVDTracks.begin(), allSVDTracks.end(), relatedSVDTrack) == allSVDTracks.end()) {
145 allSVDTracks.push_back(relatedSVDTrack);
146 }
147 }
148 }
149 // The reconstructed track contains at most two SVD parts.
150 // Try to match the hit list to find the right SVD tracks.
151 // If no matching SVD tracks are found, then svdRecoTrackPtr will still be nullptr.
152 RecoTrack* svdRecoTrackPtr = nullptr;
153 RecoTrack* svdRecoTrackBeforePtr = nullptr;
154 // First try to match the whole SVD track, which means it only contains one SVD part.
155 for (RecoTrack* foundSVDTrack : allSVDTracks) {
156 const auto& foundSVDTrackHitList = foundSVDTrack->getSVDHitList();
157 if (foundSVDTrackHitList.size() != svdHitList.size())
158 continue;
159 if (std::equal(foundSVDTrackHitList.begin(), foundSVDTrackHitList.end(), svdHitList.begin())) {
160 svdRecoTrackPtr = foundSVDTrack;
161 break;
162 }
163 }
164 if (svdRecoTrackPtr == nullptr) {
165 // Next try to match two SVD tracks.
// Outer loop: candidate "before" track whose hit list is a strictly shorter prefix of
// svdHitList. Inner loop: a different "after" track whose hit list equals the remaining
// suffix exactly. The first pair found wins.
166 for (RecoTrack* foundSVDTrackBefore : allSVDTracks) {
167 const auto& foundSVDTrackBeforeHitList = foundSVDTrackBefore->getSVDHitList();
168 auto sizeBefore = foundSVDTrackBeforeHitList.size();
// The size check also guarantees that the std::equal below stays within svdHitList.
169 if (sizeBefore >= svdHitList.size())
170 continue;
171 if (not std::equal(foundSVDTrackBeforeHitList.begin(), foundSVDTrackBeforeHitList.end(), svdHitList.begin()))
172 continue;
173 auto rest = svdHitList.size() - sizeBefore;
174 for (RecoTrack* foundSVDTrackAfter : allSVDTracks) {
175 const auto& foundSVDTrackAfterHitList = foundSVDTrackAfter->getSVDHitList();
176 if (foundSVDTrackBefore == foundSVDTrackAfter)
177 continue;
178 if (foundSVDTrackAfterHitList.size() != rest)
179 continue;
180 if (std::equal(foundSVDTrackAfterHitList.begin(), foundSVDTrackAfterHitList.end(), svdHitList.begin() + sizeBefore)) {
// The "after" part becomes the main SVD track; the prefix part is kept as "before".
181 svdRecoTrackPtr = foundSVDTrackAfter;
182 svdRecoTrackBeforePtr = foundSVDTrackBefore;
183 break;
184 }
185 }
// Leave the outer loop as soon as the inner loop found a pair.
186 if (svdRecoTrackBeforePtr)
187 break;
188 }
189 }
190
// Features of the main SVD part. The related store array name "SPTrackCands" is hard-coded.
191 const SpacePointTrackCand* spacePointTrackCand = nullptr;
192 if (svdRecoTrackPtr) {
193 spacePointTrackCand = svdRecoTrackPtr->getRelatedTo<SpacePointTrackCand>("SPTrackCands");
194 }
195 std::vector<SpacePoint const*> sortedHits;
// Hits are taken only from a candidate carrying the c_isActive status; otherwise
// sortedHits stays empty.
196 if (spacePointTrackCand and spacePointTrackCand->hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
197 sortedHits = spacePointTrackCand->getSortedHits();
198 }
199
// NOTE(review): the extractors and the estimator are called even when sortedHits is
// empty; confirm that they handle an empty list.
200 m_clusterInfoExtractor->extractVariables(sortedHits);
201 m_nSpacePoints = sortedHits.size();
202 m_qeResultsExtractor->extractVariables(m_estimator->estimateQualityAndProperties(sortedHits));
203
// Same procedure for the "before" SVD part, written into the "Before" extractors.
204 const SpacePointTrackCand* spacePointTrackCandBefore = nullptr;
205 if (svdRecoTrackBeforePtr) {
206 spacePointTrackCandBefore = svdRecoTrackBeforePtr->getRelatedTo<SpacePointTrackCand>("SPTrackCands");
207 }
208 std::vector<SpacePoint const*> sortedHitsBefore;
209 if (spacePointTrackCandBefore and spacePointTrackCandBefore->hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
210 sortedHitsBefore = spacePointTrackCandBefore->getSortedHits();
211 }
212
213 m_clusterInfoExtractorBefore->extractVariables(sortedHitsBefore);
214 m_nSpacePointsBefore = sortedHitsBefore.size();
215 m_qeResultsExtractorBefore->extractVariables(m_estimator->estimateQualityAndProperties(sortedHitsBefore));
216
// NOTE(review): the closing brace after the next statement has no visible opening
// statement; the guard line is missing from this listing (presumably a check of
// m_collectEventFeatures, to be confirmed against the real file).
218 m_eventInfoExtractor->extractVariables(m_recoTracks, recoTrack);
219 }
220 m_recoTrackExtractor->extractVariables(recoTrack);
// Only the main SVD part is handed over here; svdRecoTrackBeforePtr is not passed.
221 m_subRecoTrackExtractor->extractVariables(cdcRecoTrackPtr, svdRecoTrackPtr, pxdRecoTrack);
222 m_hitInfoExtractor->extractVariables(recoTrack);
223 // get quality indicator from classifier
224 const float qualityIndicator = m_mvaExpert->predict();
225 // set quality indicator property in the RecoTrack
226 recoTrack.setQualityIndicator(qualityIndicator);
227 }
228}
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
This is the Reconstruction Event-Data Model Track.
Definition RecoTrack.h:79
std::vector< Belle2::RecoTrack::UsedSVDHit * > getSVDHitList() const
Return an unsorted list of svd hits.
Definition RecoTrack.h:452
std::vector< Belle2::RecoTrack::UsedCDCHit * > getCDCHitList() const
Return an unsorted list of cdc hits.
Definition RecoTrack.h:455
Class for type safe access to objects that are referred to in relations.
size_t size() const
Get number of relations.
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
Storage for (VXD) SpacePoint-based track candidates.
@ c_isActive
bit 11: SPTC is active (i.e.
bool hasRefereeStatus(unsigned int short bitmask) const
Check if the SpacePointTrackCand has the status characterized by the bitmask.
const std::vector< const Belle2::SpacePoint * > getSortedHits() const
get hits (space points) sorted by their respective sorting parameter
std::unique_ptr< HitInfoExtractor > m_hitInfoExtractor
pointer to object that extracts info from the hits within the RecoTrack
std::unique_ptr< EventInfoExtractor > m_eventInfoExtractor
pointer to object that extracts info from the whole event
void initialize() override
Initializes the Module.
void event() override
Applies the selected quality estimation method for a given set of TCs.
std::unique_ptr< TrackingUtilities::MVAExpert > m_mvaExpert
pointer to the object to interact with the MVA package
std::vector< TrackingUtilities::Named< float * > > m_variableSet
set of named variables to be used in MVA
std::string m_svdRecoTracksStoreArrayName
Name of the SVD StoreArray.
std::string m_tracksStoreArrayName
Name of the StoreArray with mdst Tracks from track fit.
float m_nSpacePoints
number of SpacePoints in SPTC as additional info for MVA, type is float to be consistent with m_varia...
bool m_collectEventFeatures
Parameter to enable event-wise features.
std::string m_pxdRecoTracksStoreArrayName
Name of the PXD StoreArray.
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractor
pointer to object that extracts the results from the estimation method (including QI,...
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractorBefore
pointer to object that extracts info from the clusters of a SPTC, for the SVD track before CDC.
std::string m_weightFileIdentifier
identifier of weightfile in Database or local root/xml file
void beginRun() override
Launches mvaExpert and sets the magnetic field strength.
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractor
pointer to object that extracts info from the clusters of a SPTC
std::string m_cdcRecoTracksStoreArrayName
Name of the CDC StoreArray.
std::unique_ptr< QualityEstimatorBase > m_estimator
pointer to the selected QualityEstimator
StoreArray< RecoTrack > m_recoTracks
Store Array of the recoTracks.
std::unique_ptr< RecoTrackExtractor > m_recoTrackExtractor
pointer to object that extracts info from the root RecoTrack
std::string m_SVDEstimationMethod
Identifier which estimation method to use for SVD.
std::string m_recoTracksStoreArrayName
Name of the recoTrack StoreArray.
std::unique_ptr< SubRecoTrackExtractor > m_subRecoTrackExtractor
pointer to object that extracts info from the related sub RecoTracks
float m_nSpacePointsBefore
number of SpacePoints in the SPTC of the SVD track before CDC, as additional info for MVA.
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractorBefore
pointer to object that extracts the results from the estimation method (including QI,...
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.