Belle II Software development
TrackQETrainingDataCollectorModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/trackQualityEstimator/TrackQETrainingDataCollectorModule.h>
10
11using namespace Belle2;
12
13
14REG_MODULE(TrackQETrainingDataCollector);
15
17{
18 //Set module properties
19 setDescription("Module to collect training data for a specified qualityEstimator and store it in a root file.");
21
22 addParam("recoTracksStoreArrayName",
24 "Name of the recoTrack StoreArray.",
26
27 addParam("SVDCDCRecoTracksStoreArrayName",
29 "Name of the SVD-CDC StoreArray.",
31
32 addParam("SVDPlusCDCStandaloneRecoTracksStoreArrayName",
34 "Name of the combined CDC-SVD StoreArray with tracks added from the CDC to SVD CKF.",
36
37 addParam("CDCRecoTracksStoreArrayName",
39 "Name of the CDC StoreArray.",
41
42 addParam("SVDRecoTracksStoreArrayName",
44 "Name of the SVD StoreArray.",
46
47 addParam("PXDRecoTracksStoreArrayName",
49 "Name of the PXD StoreArray.",
51
52 addParam("TrainingDataOutputName",
54 "Name of the output rootfile.",
56
57 addParam("collectEventFeatures",
59 "Whether to use eventwise features.",
61}
62
64{
67 m_eventInfoExtractor = std::make_unique<EventInfoExtractor>(m_variableSet);
68 }
69 m_recoTrackExtractor = std::make_unique<RecoTrackExtractor>(m_variableSet);
70 m_subRecoTrackExtractor = std::make_unique<SubRecoTrackExtractor>(m_variableSet);
71 m_hitInfoExtractor = std::make_unique<HitInfoExtractor>(m_variableSet);
72
73 m_variableSet.emplace_back("truth", &m_matched);
74 m_variableSet.emplace_back("background", &m_background);
75 m_variableSet.emplace_back("ghost", &m_ghost);
76 m_variableSet.emplace_back("fake", &m_fake);
77 m_variableSet.emplace_back("clone", &m_clone);
78
79 m_recorder = std::make_unique<SimpleVariableRecorder>(m_variableSet, m_TrainingDataOutputName, "tree");
80}
81
83{
84
85}
86
88{
89 for (const RecoTrack& recoTrack : m_recoTracks) {
90 m_matched = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_matched);
91 m_background = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_background);
92 m_ghost = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_ghost);
93 m_fake = float((recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_background)
94 || (recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_ghost));
95 m_clone = float(recoTrack.getMatchingStatus() == RecoTrack::MatchingStatus::c_clone);
96
98 // combined CDC and SVD tracks after both CDC-to-SVD and also SVD-to-CDC CKF
99 RecoTrack* svdCDCRecoTrackPtr = recoTrack.getRelatedTo<RecoTrack>(m_svdCDCRecoTracksStoreArrayName);
100 // combined SVD and CDC-standalone tracks after CDC-to-SVD CKF
101 RecoTrack* svdPlusCDCStandaloneRecoTrackPtr = nullptr;
102 // CDC tracks from CDC-standalone tracking
103 RecoTrack* cdcRecoTrackPtr = nullptr;
104 // SVD tracks from VXDTF2 (SVD-standalone) tracking
105 RecoTrack* svdRecoTrackPtr = nullptr;
106
107 if (svdCDCRecoTrackPtr) {
108 // Relation graph when SVD-to-CDC CKF is enabled:
109 // SVDCDCRecoTracks -> SVDPlusCDCStandaloneRecoTracks -> CDCRecoTracks & SVDRecoTracks
110 svdPlusCDCStandaloneRecoTrackPtr = svdCDCRecoTrackPtr->getRelatedTo<RecoTrack>(
112 if (not svdPlusCDCStandaloneRecoTrackPtr) {
113 // Relation graph when SVD-to-CDC CKF is disabled:
114 // SVDCDCRecoTracks -> CDCRecoTracks & SVDRecoTracks
115 svdPlusCDCStandaloneRecoTrackPtr = svdCDCRecoTrackPtr;
116 }
117 cdcRecoTrackPtr = svdPlusCDCStandaloneRecoTrackPtr->getRelatedTo<RecoTrack>(m_cdcRecoTracksStoreArrayName);
118 svdRecoTrackPtr = svdPlusCDCStandaloneRecoTrackPtr->getRelatedTo<RecoTrack>(m_svdRecoTracksStoreArrayName);
119 }
120
122 m_eventInfoExtractor->extractVariables(m_recoTracks, recoTrack);
123 }
124 m_recoTrackExtractor->extractVariables(recoTrack);
125 // TODO: also use `CKFCDCRecoTracks` and its features in quality estimation
126 m_subRecoTrackExtractor->extractVariables(cdcRecoTrackPtr, svdRecoTrackPtr, pxdRecoTrackPtr);
127 m_hitInfoExtractor->extractVariables(recoTrack);
128
129 // record variables
130 m_recorder->record();
131 }
132}
133
135{
136 m_recorder->write();
137 m_recorder.reset();
138}
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
@ c_TerminateInAllProcesses
When using parallel processing, call this module's terminate() function in all processes().
Definition: Module.h:83
This is the Reconstruction Event-Data Model Track.
Definition: RecoTrack.h:79
TO * getRelatedTo(const std::string &name="", const std::string &namedRelation="") const
Get the object to which this object has a relation.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
std::unique_ptr< SimpleVariableRecorder > m_recorder
pointer to the object that writes out the collected data into a root file
std::unique_ptr< HitInfoExtractor > m_hitInfoExtractor
pointer to object that extracts info from the hits within the RecoTrack
std::unique_ptr< EventInfoExtractor > m_eventInfoExtractor
pointer to object that extracts info from the whole event
std::string m_svdPlusCDCStandaloneRecoTracksStoreArrayName
Name of the StoreArray of SVD tracks combined with CDC-tracks from standalone CDC tracking.
std::string m_svdCDCRecoTracksStoreArrayName
Name of the SVD-CDC StoreArray.
void event() override
collects the truth labels and the extracted features of each RecoTrack in the event and records them
std::string m_svdRecoTracksStoreArrayName
Name of the SVD StoreArray.
bool m_collectEventFeatures
Parameter to enable event-wise features.
std::string m_pxdRecoTracksStoreArrayName
Name of the PXD StoreArray.
void terminate() override
write out data from m_recorder
void beginRun() override
empty implementation in this file; no per-run setup is performed
std::vector< Named< float * > > m_variableSet
set of named variables to be collected
std::string m_cdcRecoTracksStoreArrayName
Name of the CDC StoreArray.
float m_matched
truth information collected with m_estimatorMC; type is float to be consistent with m_variableSet (and...
float m_background
1 if track is background track, 0 else
StoreArray< RecoTrack > m_recoTracks
Store Array of the recoTracks.
std::unique_ptr< RecoTrackExtractor > m_recoTrackExtractor
pointer to object that extracts info from the root RecoTrack
std::string m_recoTracksStoreArrayName
Name of the recoTrack StoreArray.
std::unique_ptr< SubRecoTrackExtractor > m_subRecoTrackExtractor
pointer to object that extracts info from the related sub RecoTracks
std::string m_TrainingDataOutputName
name of the output rootfile
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
Abstract base class for different kinds of events.