Belle II Software development
VXDQETrainingDataCollectorModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/vxdtfQualityEstimator/VXDQETrainingDataCollectorModule.h>
10#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorTripletFit.h>
11#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorRiemannHelixFit.h>
12#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorMC.h>
13#include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorCircleFit.h>
14#include <framework/geometry/BFieldManager.h>
15
16using namespace Belle2;
17
18
19REG_MODULE(VXDQETrainingDataCollector);
20
// Constructor body: sets the module description and registers the steering parameters
// EstimationMethod, MCInfo, SpacePointTrackCandsStoreArrayName, MCRecoTracksStoreArrayName,
// MCStrictQualityEstimator, mva_target, TrainingDataOutputName, ClusterInformation and UseTimingInfo.
// NOTE(review): this listing is missing lines (note the gaps in the embedded line numbers),
// including the signature and most member-variable/default-value arguments of addParam,
// so parameter defaults cannot be documented from here.
22{
23 //Set module properties
24 setDescription("Module to collect training data for the VXDQualityEstimatorMVA and store it in a root file.");
26
27 addParam("EstimationMethod",
29 "Identifier which estimation method to use. Valid identifiers are: [circleFit, "
30 "tripletFit, helixFit]",
32
33 addParam("MCInfo",
35 "If true, MC information is used. Thus, to run over data, this needs to be set to false.",
36 m_MCInfo);
37
38 addParam("SpacePointTrackCandsStoreArrayName",
40 "Name of StoreArray containing the SpacePointTrackCandidates to be estimated.",
42
43 addParam("MCRecoTracksStoreArrayName",
45 "Name of StoreArray containing MCRecoTracks. Only required for MCInfo method",
47
48 addParam("MCStrictQualityEstimator",
50 "Only required for MCInfo method. If false combining several MCTracks is allowed.",
52
// NOTE(review): the description string below contains the typo "CLusters" (runtime text, left unchanged here).
53 addParam("mva_target",
55 "Whether to write out MVA target which requires complete agreement between SVD CLusters "
56 "of pattern "
57 "recognition track and MC track to yield 1, else 0, and thus provides maximal hit "
58 "purity and hit efficiency.",
60
61 addParam("TrainingDataOutputName",
63 "Name of the output rootfile.",
65
// Only "Average" is acted upon in initialize()/event(); other values are silently ignored there.
66 addParam("ClusterInformation",
68 "How to compile information from clusters ['Average']",
70
71 addParam("UseTimingInfo",
73 "Whether to collect timing information",
75}
76
// initialize(): builds the list of recorded variables, the helper extractors, the output
// recorder and the quality estimator(s).
// NOTE(review): the signature and original lines 79 and 104 are missing from this listing
// (104 presumably creates m_estimatorMC, which is asserted on below — confirm).
78{
80
// The results extractor is handed m_variableSet; presumably it registers its own output
// variables there — verify in QEResultsExtractor.
81 m_qeResultsExtractor = std::make_unique<QEResultsExtractor>(m_EstimationMethod, m_variableSet);
82
// Variables are stored as pointers to member floats; event() updates the members before record().
83 m_variableSet.emplace_back("NSpacePoints", &m_nSpacePoints);
84
// "truth" is always part of the variable set, even when MCInfo is false.
85 m_variableSet.emplace_back("truth", &m_truth);
86
// Only "Average" creates an extractor; any other value leaves m_clusterInfoExtractor empty
// (event() guards its use with the same comparison).
87 if (m_ClusterInformation == "Average") {
88 m_clusterInfoExtractor = std::make_unique<ClusterInfoExtractor>(m_variableSet, m_UseTimingInfo);
89 }
90
// The recorder is created after all variables were added to m_variableSet and writes to tree "tree".
91 m_recorder = std::make_unique<SimpleVariableRecorder>(m_variableSet, m_TrainingDataOutputName, "tree");
92
93 // create pointer to chosen estimator
94 if (m_EstimationMethod == "tripletFit") {
95 m_estimator = std::make_unique<QualityEstimatorTripletFit>();
96 } else if (m_EstimationMethod == "circleFit") {
97 m_estimator = std::make_unique<QualityEstimatorCircleFit>();
98 } else if (m_EstimationMethod == "helixFit") {
99 m_estimator = std::make_unique<QualityEstimatorRiemannHelixFit>();
100 }
// An unrecognised EstimationMethod leaves m_estimator null and trips this assertion.
101 B2ASSERT("Not all QualityEstimators could be initialized!", m_estimator);
102
103 if (m_MCInfo) {
// NOTE(review): this assertion fires when m_estimatorMC is null, yet its message reads
// "could be initialized" — presumably "could not be initialized" was intended.
105 B2ASSERT("QualityEstimatorMC could be initialized!", m_estimatorMC);
106 }
107}
108
// beginRun(): passes the z-component of the magnetic field at the origin, converted to tesla,
// to the selected estimator and (with MCInfo) to the MC estimator.
// NOTE(review): the signature line is missing from this listing.
110{
111 // BField is required by all QualityEstimators
112 const double bFieldZ = BFieldManager::getField(0, 0, 0).Z() / Unit::T;
113 m_estimator->setMagneticFieldStrength(bFieldZ);
114 if (m_MCInfo) {
115 m_estimatorMC->setMagneticFieldStrength(bFieldZ);
// The static_cast assumes m_estimatorMC holds a QualityEstimatorMC; its construction is not
// visible in this listing — confirm it is always created as that type.
116 QualityEstimatorMC* MCestimator = static_cast<QualityEstimatorMC*>(m_estimatorMC.get());
// Forces the MC estimator to refresh its cluster names for the new run.
117 MCestimator->forceUpdateClusterNames();
118 }
119}
120
// event(): for every active SpacePointTrackCand, fills the recorded variables and writes one entry.
// NOTE(review): the signature and the loop header over the track candidates (original lines
// 121 and 123) are missing from this listing; aTC is the loop variable.
122{
124
// Skip candidates that are not flagged as active.
125 if (not aTC.hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
126 continue;
127 }
128
129 std::vector<SpacePoint const*> const sortedHits = aTC.getSortedHits();
// Same condition as in initialize(), so m_clusterInfoExtractor is only used when it was created.
130 if (m_ClusterInformation == "Average") {
131 m_clusterInfoExtractor->extractVariables(sortedHits);
132 }
133 m_nSpacePoints = sortedHits.size();
// truth is 1 when the MC quality estimate is positive, otherwise 0.
// NOTE(review): with MCInfo false, m_truth is not assigned here but is still recorded.
134 if (m_MCInfo) {
135 const double mc_quality = m_estimatorMC->estimateQuality(sortedHits);
136 m_truth = float(mc_quality > 0);
137 }
138 m_qeResultsExtractor->extractVariables(m_estimator->estimateQualityAndProperties(sortedHits));
139
// Writes the current values of all variables in m_variableSet as one entry.
140 m_recorder->record();
141 }
142}
143
// terminate(): writes out the data collected by m_recorder, then destroys the recorder.
// NOTE(review): the signature line is missing from this listing.
145{
146 m_recorder->write();
147 m_recorder.reset();
148}
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
@ c_TerminateInAllProcesses
When using parallel processing, call this module's terminate() function in all processes().
Definition: Module.h:83
Class implementing the algorithm used for the MC based quality estimation.
void forceUpdateClusterNames()
Setter to force the class to update its cluster names.
Storage for (VXD) SpacePoint-based track candidates.
@ c_isActive
bit 11: SPTC is active (i.e.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
static const double T
[tesla]
Definition: Unit.h:120
std::unique_ptr< SimpleVariableRecorder > m_recorder
pointer to the object that writes the collected data out to a ROOT file
std::string m_SpacePointTrackCandsStoreArrayName
sets the name of the expected StoreArray containing SpacePointTrackCands
bool m_MCStrictQualityEstimator
Required for the MCInfo method; activates its strict version.
void initialize() override
Initializes the Module.
void event() override
applies the selected quality estimation method for a given set of TCs
float m_truth
truth information collected with m_estimatorMC; type is float to be consistent with m_variableSet (and...
std::string m_EstimationMethod
Identifier which estimation method to use.
float m_nSpacePoints
number of SpacePoints in the SPTC, collected as additional info; type is float to be consistent with...
void terminate() override
write out data from m_recorder
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractor
pointer to object that extracts the results from the estimation method (including QI,...
bool m_UseTimingInfo
whether to collect timing information
bool m_mva_target
Bool to indicate whether the MVA target requiring complete agreement in SVD clusters between MC and PR track t...
void beginRun() override
sets magnetic field strength
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractor
pointer to object that extracts info from the clusters of a SPTC
std::string m_MCRecoTracksStoreArrayName
sets the name of the expected StoreArray containing MCRecoTracks.
std::vector< Named< float * > > m_variableSet
set of named variables to be collected
bool m_MCInfo
whether to collect MC truth information
std::unique_ptr< QualityEstimatorBase > m_estimatorMC
QualityEstimatorMC as target for training.
std::unique_ptr< QualityEstimatorBase > m_estimator
pointer to the selected QualityEstimator
std::string m_TrainingDataOutputName
name of the output ROOT file
StoreArray< SpacePointTrackCand > m_spacePointTrackCands
the StoreArray for SpacePointTrackCands, kept as a member because this is faster than recreating the link for each event
std::string m_ClusterInformation
how to compile information from clusters ['Average']
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
static void getField(const double *pos, double *field)
return the magnetic field at a given position.
Definition: BFieldManager.h:91
Abstract base class for different kinds of events.