Belle II Software  release-08-01-10
VXDQETrainingDataCollectorModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <tracking/modules/vxdtfQualityEstimator/VXDQETrainingDataCollectorModule.h>
10 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorTripletFit.h>
11 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorRiemannHelixFit.h>
12 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorMC.h>
13 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorCircleFit.h>
14 #include <framework/geometry/BFieldManager.h>
15 
16 using namespace Belle2;
17 
18 
// Register this module (VXDQETrainingDataCollector) with the framework.
19 REG_MODULE(VXDQETrainingDataCollector);
20 
// Constructor: sets the module description and registers the module parameters
// (EstimationMethod, MCInfo, store array names, MC options, output file name,
// cluster-information mode and timing flag).
// NOTE(review): this listing omits the signature line and several addParam argument
// lines (the member references and default values); consult the full source for them.
22 {
23  //Set module properties
24  setDescription("Module to collect training data for the VXDQualityEstimatorMVA and store it in a root file.");
26 
27  addParam("EstimationMethod",
29  "Identifier which estimation method to use. Valid identifiers are: [circleFit, "
30  "tripletFit, helixFit]",
32 
// For MCInfo the member itself is passed as the default, so the parameter defaults to
// whatever value m_MCInfo holds at this point (presumably its in-class initializer).
33  addParam("MCInfo",
34  m_MCInfo,
35  "If true, MC information is used. Thus, to run over data, this needs to be set to false.",
36  m_MCInfo);
37 
38  addParam("SpacePointTrackCandsStoreArrayName",
40  "Name of StoreArray containing the SpacePointTrackCandidates to be estimated.",
42 
43  addParam("MCRecoTracksStoreArrayName",
45  "Name of StoreArray containing MCRecoTracks. Only required for MCInfo method",
47 
48  addParam("MCStrictQualityEstimator",
50  "Only required for MCInfo method. If false combining several MCTracks is allowed.",
52 
// As for MCInfo, m_mva_target's current value serves as the default.
53  addParam("mva_target",
55  "Whether to write out MVA target which requires complete agreement between SVD CLusters "
56  "of pattern "
57  "recognition track and MC track to yield 1, else 0, and thus provides maximal hit "
58  "purity and hit efficiency.",
59  m_mva_target);
60 
61  addParam("TrainingDataOutputName",
63  "Name of the output rootfile.",
65 
66  addParam("ClusterInformation",
68  "How to compile information from clusters ['Average']",
70 
71  addParam("UseTimingInfo",
73  "Whether to collect timing information",
75 }
76 
// initialize(): builds the set of variables to record, creates the recorder that writes
// them to the tree "tree" in m_TrainingDataOutputName, and instantiates the estimator
// selected by m_EstimationMethod (plus QualityEstimatorMC when m_MCInfo is set).
// NOTE(review): the signature line and source line 79 are not part of this listing.
78 {
80 
81  // m_variableSet is handed to the results extractor, presumably so it can append its
81  // own variables (TODO confirm against QEResultsExtractor).
81  m_qeResultsExtractor = std::make_unique<QEResultsExtractor>(m_EstimationMethod, m_variableSet);
82 
83  m_variableSet.emplace_back("NSpacePoints", &m_nSpacePoints);
84 
85  m_variableSet.emplace_back("truth", &m_truth);
86 
87  // Only "Average" is handled; any other value leaves m_clusterInfoExtractor empty,
87  // and event() then skips cluster-variable extraction.
87  if (m_ClusterInformation == "Average") {
88  m_clusterInfoExtractor = std::make_unique<ClusterInfoExtractor>(m_variableSet, m_UseTimingInfo);
89  }
90 
91  // The recorder is created after all variables have been added to m_variableSet.
91  m_recorder = std::make_unique<SimpleVariableRecorder>(m_variableSet, m_TrainingDataOutputName, "tree");
92 
93  // create pointer to chosen estimator
94  if (m_EstimationMethod == "tripletFit") {
95  m_estimator = std::make_unique<QualityEstimatorTripletFit>();
96  } else if (m_EstimationMethod == "circleFit") {
97  m_estimator = std::make_unique<QualityEstimatorCircleFit>();
98  } else if (m_EstimationMethod == "helixFit") {
99  m_estimator = std::make_unique<QualityEstimatorRiemannHelixFit>();
100  }
101  // Fires when m_EstimationMethod matched none of the identifiers above.
101  B2ASSERT("Not all QualityEstimators could be initialized!", m_estimator);
102 
103  if (m_MCInfo) {
104  m_estimatorMC = std::make_unique<QualityEstimatorMC>(m_MCRecoTracksStoreArrayName, m_MCStrictQualityEstimator, m_mva_target);
105  // The message is printed when the condition fails, i.e. when initialization did NOT succeed.
105  B2ASSERT("QualityEstimatorMC could not be initialized!", m_estimatorMC);
106  }
107 }
108 
// beginRun(): hands the z component of the magnetic field at the origin, converted to
// tesla, to the selected estimator. When MCInfo is set, it also passes the field to
// the MC estimator and forces it to update its cluster names.
// NOTE(review): the signature line is not part of this listing.
110 {
111  // BField is required by all QualityEstimators
112  const double bFieldZ = BFieldManager::getField(0, 0, 0).Z() / Unit::T;
113  m_estimator->setMagneticFieldStrength(bFieldZ);
114  if (m_MCInfo) {
115  m_estimatorMC->setMagneticFieldStrength(bFieldZ);
// The downcast relies on m_estimatorMC having been created as a QualityEstimatorMC in initialize().
116  QualityEstimatorMC* MCestimator = static_cast<QualityEstimatorMC*>(m_estimatorMC.get());
117  MCestimator->forceUpdateClusterNames();
118  }
119 }
120 
// event(): for every active SpacePointTrackCand, fills the recorded variables and
// writes one entry via m_recorder. The variables are the cluster information (if
// enabled), the number of SpacePoints, the MC truth (if MCInfo) and the estimator results.
// NOTE(review): the signature line and the loop header (source line 123) are missing from
// this listing; presumably the loop iterates m_spacePointTrackCands, binding each to aTC.
122 {
124 
// Candidates without the c_isActive referee status are skipped.
125  if (not aTC.hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
126  continue;
127  }
128 
129  std::vector<SpacePoint const*> const sortedHits = aTC.getSortedHits();
130  if (m_ClusterInformation == "Average") {
131  m_clusterInfoExtractor->extractVariables(sortedHits);
132  }
133  m_nSpacePoints = sortedHits.size();
134  if (m_MCInfo) {
135  const double mc_quality = m_estimatorMC->estimateQuality(sortedHits);
// Binary training target: 1 if the MC quality is positive, otherwise 0.
136  m_truth = float(mc_quality > 0);
137  }
// NOTE(review): without MCInfo, m_truth is not assigned here and keeps its previous value.
138  m_qeResultsExtractor->extractVariables(m_estimator->estimateQualityAndProperties(sortedHits));
139 
140  m_recorder->record();
141  }
142 }
143 
// terminate(): writes the collected data via m_recorder, then destroys the recorder.
// NOTE(review): the signature line is not part of this listing.
145 {
146  m_recorder->write();
147  m_recorder.reset();
148 }
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
void setPropertyFlags(unsigned int propertyFlags)
Sets the flags for the module properties.
Definition: Module.cc:208
@ c_ParallelProcessingCertified
This module can be run in parallel processing mode safely (All I/O must be done through the data stor...
Definition: Module.h:80
@ c_TerminateInAllProcesses
When using parallel processing, call this module's terminate() function in all processes().
Definition: Module.h:83
Class implementing the algorithm used for the MC based quality estimation.
void forceUpdateClusterNames()
Setter to force the class to update its cluster names.
Storage for (VXD) SpacePoint-based track candidates.
@ c_isActive
bit 11: SPTC is active (i.e.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
static const double T
[tesla]
Definition: Unit.h:120
std::unique_ptr< SimpleVariableRecorder > m_recorder
pointer to the object that writes out the collected data into a root file
std::string m_SpacePointTrackCandsStoreArrayName
sets the name of the expected StoreArray containing SpacePointTrackCands
bool m_MCStrictQualityEstimator
Required for MCInfo method, activates its strict version.
void initialize() override
Initializes the Module.
void event() override
applies the selected quality estimation method for a given set of TCs
float m_truth
truth information collected with m_estimatorMC; type is float to be consistent with m_variableSet (and...
std::string m_EstimationMethod
Identifier which estimation method to use.
float m_nSpacePoints
number of SpacePoints in SPTC as additional info to be collected, type is float to be consistent with...
void terminate() override
write out data from m_recorder
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractor
pointer to object that extracts the results from the estimation method (including QI,...
bool m_UseTimingInfo
whether to collect timing information
bool m_mva_target
Bool to indicate if mva target requiring complete agreement in SVD Clusters between MC and PR track t...
void beginRun() override
sets magnetic field strength
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractor
pointer to object that extracts info from the clusters of a SPTC
std::string m_MCRecoTracksStoreArrayName
sets the name of the expected StoreArray containing MCRecoTracks.
std::vector< Named< float * > > m_variableSet
set of named variables to be collected
bool m_MCInfo
whether to collect MC truth information
std::unique_ptr< QualityEstimatorBase > m_estimatorMC
QualityEstimatorMC as target for training.
std::unique_ptr< QualityEstimatorBase > m_estimator
pointer to the selected QualityEstimator
std::string m_TrainingDataOutputName
name of the output rootfile
StoreArray< SpacePointTrackCand > m_spacePointTrackCands
the storeArray for SpacePointTrackCands as member, is faster than recreating link for each event
std::string m_ClusterInformation
how to compile information from clusters ['Average']
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
static void getField(const double *pos, double *field)
return the magnetic field at a given position.
Definition: BFieldManager.h:91
Abstract base class for different kinds of events.