Belle II Software prerelease-11-00-00a
TrackQETrainingDataCollectorModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/trackQualityEstimator/TrackQETrainingDataCollectorModule.h>
10#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorTripletFit.h>
11#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorCircleFit.h>
12#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorRiemannHelixFit.h>
13
14using namespace Belle2;
15
16
17REG_MODULE(TrackQETrainingDataCollector);
18
20{
  // Sets the module description and registers the steering parameters:
  // the names of the combined, CDC, SVD and PXD RecoTrack StoreArrays, the
  // name of the output ROOT file, the switch for event-wise features and
  // the identifier of the SVD estimation method.
21 //Set module properties
22 setDescription("Module to collect training data for a specified qualityEstimator and store it in a root file.");
24
25 addParam("recoTracksStoreArrayName",
27 "Name of the recoTrack StoreArray.",
29
30 addParam("CDCRecoTracksStoreArrayName",
32 "Name of the CDC StoreArray.",
34
35 addParam("SVDRecoTracksStoreArrayName",
37 "Name of the SVD StoreArray.",
39
40 addParam("PXDRecoTracksStoreArrayName",
42 "Name of the PXD StoreArray.",
44
45 addParam("TrainingDataOutputName",
47 "Name of the output rootfile.",
49
50 addParam("collectEventFeatures",
52 "Whether to use eventwise features.",
54
55 addParam("SVDEstimationMethod",
57 "Identifier which estimation method to use for SVD. Valid identifiers are: [tripletFit, circleFit, helixFit]",
59}
60
62{
  // Creates the feature extractors, the selected SVD quality estimator and
  // the recorder. Every extractor is constructed with a reference to
  // m_variableSet; the space point counts and the truth labels are appended
  // to m_variableSet explicitly further down.
65 m_eventInfoExtractor = std::make_unique<EventInfoExtractor>(m_variableSet);
66 }
67 m_recoTrackExtractor = std::make_unique<RecoTrackExtractor>(m_variableSet);
68 m_subRecoTrackExtractor = std::make_unique<SubRecoTrackExtractor>(m_variableSet);
69 m_hitInfoExtractor = std::make_unique<HitInfoExtractor>(m_variableSet);
70
71 // create pointer to chosen estimator
72 if (m_SVDEstimationMethod == "tripletFit") {
73 m_estimator = std::make_unique<QualityEstimatorTripletFit>();
74 } else if (m_SVDEstimationMethod == "circleFit") {
75 m_estimator = std::make_unique<QualityEstimatorCircleFit>();
76 } else if (m_SVDEstimationMethod == "helixFit") {
77 m_estimator = std::make_unique<QualityEstimatorRiemannHelixFit>();
78 }
  // None of the branches above assigns m_estimator for any other identifier;
  // the assertion reports the offending method name in that case.
79 B2ASSERT("QualityEstimator could not be initialized with method: " << m_SVDEstimationMethod, m_estimator);
80
  // Variables for the SVD part of the track ("SVD_" names).
81 m_qeResultsExtractor = std::make_unique<QEResultsExtractor>(m_SVDEstimationMethod, m_variableSet, "SVD_");
82 m_variableSet.emplace_back("SVD_NSpacePoints", &m_nSpacePoints);
83 m_clusterInfoExtractor = std::make_unique<ClusterInfoExtractor>(m_variableSet, false, "SVD_");
84
  // Same set of variables for the second SVD part ("SVDbefore_" names),
  // filled in event() from svdRecoTrackBeforePtr.
85 m_qeResultsExtractorBefore = std::make_unique<QEResultsExtractor>(m_SVDEstimationMethod, m_variableSet, "SVDbefore_");
86 m_variableSet.emplace_back("SVDbefore_NSpacePoints", &m_nSpacePointsBefore);
87 m_clusterInfoExtractorBefore = std::make_unique<ClusterInfoExtractor>(m_variableSet, false, "SVDbefore_");
88
  // Truth labels; event() sets them from the matching status of each RecoTrack.
89 m_variableSet.emplace_back("truth", &m_matched);
90 m_variableSet.emplace_back("background", &m_background);
91 m_variableSet.emplace_back("ghost", &m_ghost);
92 m_variableSet.emplace_back("fake", &m_fake);
93 m_variableSet.emplace_back("clone", &m_clone);
94
  // The recorder is created last, from the complete variable set and the
  // configured output file name ("tree" is presumably the tree name - TODO confirm).
95 m_recorder = std::make_unique<SimpleVariableRecorder>(m_variableSet, m_TrainingDataOutputName, "tree");
96}
97
102
104{
105 for (const RecoTrack& recoTrack : m_recoTracks) {
106 m_matched = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_matched);
107 m_background = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_background);
108 m_ghost = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_ghost);
109 m_fake = float((recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_background)
110 || (recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_ghost));
111 m_clone = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_clone);
112
113 RecoTrack* pxdRecoTrackPtr = recoTrack.getRelatedTo<RecoTrack>(m_pxdRecoTracksStoreArrayName);
114
115
116 // Try to find all CDC tracks that are related to some hits in the RecoTrack.
117 std::vector<RecoTrack*> allCDCTracks;
118 const auto& cdcHitList = recoTrack.getCDCHitList();
119 for (auto* cdcHit : cdcHitList) {
120 const RelationVector<RecoTrack>& relatedCDCTracks =
121 cdcHit->getRelationsWith<RecoTrack>(m_cdcRecoTracksStoreArrayName);
122 for (unsigned int index = 0; index < relatedCDCTracks.size(); ++index) {
123 RecoTrack* relatedCDCTrack = relatedCDCTracks[index];
124 if (std::find(allCDCTracks.begin(), allCDCTracks.end(), relatedCDCTrack) == allCDCTracks.end()) {
125 allCDCTracks.push_back(relatedCDCTrack);
126 }
127 }
128 }
129 // The reconstructed track contains at most one CDC part.
130 // Try to match the hit list to find the right CDC track.
131 // If no matching CDC tracks are found, then cdcRecoTrackPtr will still be nullptr.
132 RecoTrack* cdcRecoTrackPtr = nullptr;
133 for (RecoTrack* foundCDCTrack : allCDCTracks) {
134 const auto& foundCDCTrackHitList = foundCDCTrack->getCDCHitList();
135 if (foundCDCTrackHitList.size() == cdcHitList.size() and
136 std::equal(foundCDCTrackHitList.begin(), foundCDCTrackHitList.end(), cdcHitList.begin())) {
137 cdcRecoTrackPtr = foundCDCTrack;
138 break;
139 }
140 }
141
142 // Try to find all SVD tracks that are related to some hits in the RecoTrack.
143 std::vector<RecoTrack*> allSVDTracks;
144 const auto& svdHitList = recoTrack.getSVDHitList();
145 for (auto* svdHit : svdHitList) {
146 const RelationVector<RecoTrack>& relatedSVDTracks = svdHit->getRelationsWith<RecoTrack>(m_svdRecoTracksStoreArrayName);
147 for (unsigned int index = 0; index < relatedSVDTracks.size(); ++index) {
148 RecoTrack* relatedSVDTrack = relatedSVDTracks[index];
149 if (std::find(allSVDTracks.begin(), allSVDTracks.end(), relatedSVDTrack) == allSVDTracks.end()) {
150 allSVDTracks.push_back(relatedSVDTrack);
151 }
152 }
153 }
154 // The reconstructed track contains at most two SVD parts.
155 // Try to match the hit list to find the right SVD tracks.
156 // If no matching SVD tracks are found, then svdRecoTrackPtr will still be nullptr.
157 RecoTrack* svdRecoTrackPtr = nullptr;
158 RecoTrack* svdRecoTrackBeforePtr = nullptr;
159 // First try to match the whole SVD track, which means it only contains one SVD part.
160 for (RecoTrack* foundSVDTrack : allSVDTracks) {
161 const auto& foundSVDTrackHitList = foundSVDTrack->getSVDHitList();
162 if (foundSVDTrackHitList.size() != svdHitList.size())
163 continue;
164 if (std::equal(foundSVDTrackHitList.begin(), foundSVDTrackHitList.end(), svdHitList.begin())) {
165 svdRecoTrackPtr = foundSVDTrack;
166 break;
167 }
168 }
169 if (svdRecoTrackPtr == nullptr) {
170 // Next try to match two SVD tracks.
171 for (RecoTrack* foundSVDTrackBefore : allSVDTracks) {
172 const auto& foundSVDTrackBeforeHitList = foundSVDTrackBefore->getSVDHitList();
173 auto sizeBefore = foundSVDTrackBeforeHitList.size();
174 if (sizeBefore >= svdHitList.size())
175 continue;
176 if (not std::equal(foundSVDTrackBeforeHitList.begin(), foundSVDTrackBeforeHitList.end(), svdHitList.begin()))
177 continue;
178 auto rest = svdHitList.size() - sizeBefore;
179 for (RecoTrack* foundSVDTrackAfter : allSVDTracks) {
180 const auto& foundSVDTrackAfterHitList = foundSVDTrackAfter->getSVDHitList();
181 if (foundSVDTrackBefore == foundSVDTrackAfter)
182 continue;
183 if (foundSVDTrackAfterHitList.size() != rest)
184 continue;
185 if (std::equal(foundSVDTrackAfterHitList.begin(), foundSVDTrackAfterHitList.end(), svdHitList.begin() + sizeBefore)) {
186 svdRecoTrackPtr = foundSVDTrackAfter;
187 svdRecoTrackBeforePtr = foundSVDTrackBefore;
188 break;
189 }
190 }
191 if (svdRecoTrackBeforePtr)
192 break;
193 }
194 }
195
196 const SpacePointTrackCand* spacePointTrackCand = nullptr;
197 if (svdRecoTrackPtr) {
198 spacePointTrackCand = svdRecoTrackPtr->getRelatedTo<SpacePointTrackCand>("SPTrackCands");
199 }
200 std::vector<SpacePoint const*> sortedHits;
201 if (spacePointTrackCand and spacePointTrackCand->hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
202 sortedHits = spacePointTrackCand->getSortedHits();
203 }
204
205 m_clusterInfoExtractor->extractVariables(sortedHits);
206 m_nSpacePoints = sortedHits.size();
207 m_qeResultsExtractor->extractVariables(m_estimator->estimateQualityAndProperties(sortedHits));
208
209 const SpacePointTrackCand* spacePointTrackCandBefore = nullptr;
210 if (svdRecoTrackBeforePtr) {
211 spacePointTrackCandBefore = svdRecoTrackBeforePtr->getRelatedTo<SpacePointTrackCand>("SPTrackCands");
212 }
213 std::vector<SpacePoint const*> sortedHitsBefore;
214 if (spacePointTrackCandBefore and spacePointTrackCandBefore->hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
215 sortedHitsBefore = spacePointTrackCandBefore->getSortedHits();
216 }
217
218 m_clusterInfoExtractorBefore->extractVariables(sortedHitsBefore);
219 m_nSpacePointsBefore = sortedHits.size();
220 m_qeResultsExtractorBefore->extractVariables(m_estimator->estimateQualityAndProperties(sortedHitsBefore));
221
223 m_eventInfoExtractor->extractVariables(m_recoTracks, recoTrack);
224 }
225 m_recoTrackExtractor->extractVariables(recoTrack);
226 // TODO: also use `CKFCDCRecoTracks` and its features in quality estimation
227 m_subRecoTrackExtractor->extractVariables(cdcRecoTrackPtr, svdRecoTrackPtr, pxdRecoTrackPtr);
228 m_hitInfoExtractor->extractVariables(recoTrack);
229
230 // record variables
231 m_recorder->record();
232 }
233}
234
236{
237 m_recorder->write();
238 m_recorder.reset();
239}
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition Module.cc:208
Module()
Constructor.
Definition Module.cc:30
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition Module.h:80
@ c_TerminateInAllProcesses
When using parallel processing, call this module's terminate() function in all processes().
Definition Module.h:83
This is the Reconstruction Event-Data Model Track.
Definition RecoTrack.h:79
std::vector< Belle2::RecoTrack::UsedSVDHit * > getSVDHitList() const
Return an unsorted list of svd hits.
Definition RecoTrack.h:452
std::vector< Belle2::RecoTrack::UsedCDCHit * > getCDCHitList() const
Return an unsorted list of cdc hits.
Definition RecoTrack.h:455
Class for type safe access to objects that are referred to in relations.
size_t size() const
Get number of relations.
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
Storage for (VXD) SpacePoint-based track candidates.
@ c_isActive
bit 11: SPTC is active (i.e.
bool hasRefereeStatus(unsigned int short bitmask) const
Check if the SpacePointTrackCand has the status characterized by the bitmask.
const std::vector< const Belle2::SpacePoint * > getSortedHits() const
get hits (space points) sorted by their respective sorting parameter
std::unique_ptr< SimpleVariableRecorder > m_recorder
pointer to the object that writes out the collected data into a root file
std::unique_ptr< HitInfoExtractor > m_hitInfoExtractor
pointer to object that extracts info from the hits within the RecoTrack
std::unique_ptr< EventInfoExtractor > m_eventInfoExtractor
pointer to object that extracts info from the whole event
void event() override
applies the selected quality estimation method for a given set of TCs
std::vector< TrackingUtilities::Named< float * > > m_variableSet
set of named variables to be collected
std::string m_svdRecoTracksStoreArrayName
Name of the SVD StoreArray.
float m_nSpacePoints
number of SpacePoints in SPTC as additional info for MVA, type is float to be consistent with m_varia...
bool m_collectEventFeatures
Parameter to enable event-wise features.
std::string m_pxdRecoTracksStoreArrayName
Name of the PXD StoreArray.
void terminate() override
write out data from m_recorder
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractor
pointer to object that extracts the results from the estimation method (including QI,...
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractorBefore
pointer to object that extracts info from the clusters of a SPTC For the SVD track before CDC.
void beginRun() override
sets magnetic field strength
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractor
pointer to object that extracts info from the clusters of a SPTC
std::string m_cdcRecoTracksStoreArrayName
Name of the CDC StoreArray.
float m_matched
truth information collected with m_estimatorMC type is float to be consistent with m_variableSet (and...
std::unique_ptr< QualityEstimatorBase > m_estimator
pointer to the selected QualityEstimator
float m_background
1 if track is background track, 0 else
StoreArray< RecoTrack > m_recoTracks
Store Array of the recoTracks.
std::unique_ptr< RecoTrackExtractor > m_recoTrackExtractor
pointer to object that extracts info from the root RecoTrack
std::string m_SVDEstimationMethod
Identifier which estimation method to use for SVD.
std::string m_recoTracksStoreArrayName
Name of the recoTrack StoreArray.
std::unique_ptr< SubRecoTrackExtractor > m_subRecoTrackExtractor
pointer to object that extracts info from the related sub RecoTracks
std::string m_TrainingDataOutputName
name of the output rootfile
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractorBefore
pointer to object that extracts the results from the estimation method (including QI,...
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.