Belle II Software  release-06-01-15
VXDQETrainingDataCollectorModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <tracking/modules/vxdtfQualityEstimator/VXDQETrainingDataCollectorModule.h>
10 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorTripletFit.h>
11 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorRiemannHelixFit.h>
12 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorMC.h>
13 #include <tracking/trackFindingVXD/trackQualityEstimators/QualityEstimatorCircleFit.h>
14 #include <framework/geometry/BFieldManager.h>
15 
16 using namespace Belle2;
17 
18 
19 REG_MODULE(VXDQETrainingDataCollector)
20 
22 {
23  //Set module properties
24  setDescription("Module to collect training data for the VXDQualityEstimatorMVA and store it in a root file.");
25  setPropertyFlags(c_ParallelProcessingCertified | c_TerminateInAllProcesses);
26 
27  addParam("EstimationMethod",
28  m_EstimationMethod,
29  "Identifier which estimation method to use. Valid identifiers are: [circleFit, "
30  "tripletFit, helixFit]",
31  m_EstimationMethod);
32 
33  addParam("MCInfo",
34  m_MCInfo,
35  "If true, MC information is used. Thus, to run over data, this needs to be set to false.",
36  m_MCInfo);
37 
38  addParam("SpacePointTrackCandsStoreArrayName",
39  m_SpacePointTrackCandsStoreArrayName,
40  "Name of StoreArray containing the SpacePointTrackCandidates to be estimated.",
41  m_SpacePointTrackCandsStoreArrayName);
42 
43  addParam("MCRecoTracksStoreArrayName",
44  m_MCRecoTracksStoreArrayName,
45  "Name of StoreArray containing MCRecoTracks. Only required for MCInfo method",
46  m_MCRecoTracksStoreArrayName);
47 
48  addParam("MCStrictQualityEstimator",
49  m_MCStrictQualityEstimator,
50  "Only required for MCInfo method. If false combining several MCTracks is allowed.",
51  m_MCStrictQualityEstimator);
52 
53  addParam("mva_target",
54  m_mva_target,
55  "Whether to write out MVA target which requires complete agreement between SVD CLusters "
56  "of pattern "
57  "recognition track and MC track to yield 1, else 0, and thus provides maximal hit "
58  "purity and hit efficiency.",
59  m_mva_target);
60 
61  addParam("TrainingDataOutputName",
62  m_TrainingDataOutputName,
63  "Name of the output rootfile.",
64  m_TrainingDataOutputName);
65 
66  addParam("ClusterInformation",
67  m_ClusterInformation,
68  "How to compile information from clusters ['Average']",
69  m_ClusterInformation);
70 
71  addParam("UseTimingInfo",
72  m_UseTimingInfo,
73  "Whether to collect timing information",
74  m_UseTimingInfo);
75 }
76 
78 {
80 
81  m_qeResultsExtractor = std::make_unique<QEResultsExtractor>(m_EstimationMethod, m_variableSet);
82 
83  m_variableSet.emplace_back("NSpacePoints", &m_nSpacePoints);
84 
85  m_variableSet.emplace_back("truth", &m_truth);
86 
87  if (m_ClusterInformation == "Average") {
88  m_clusterInfoExtractor = std::make_unique<ClusterInfoExtractor>(m_variableSet, m_UseTimingInfo);
89  }
90 
91  m_recorder = std::make_unique<SimpleVariableRecorder>(m_variableSet, m_TrainingDataOutputName, "tree");
92 
93  // create pointer to chosen estimator
94  if (m_EstimationMethod == "tripletFit") {
95  m_estimator = std::make_unique<QualityEstimatorTripletFit>();
96  } else if (m_EstimationMethod == "circleFit") {
97  m_estimator = std::make_unique<QualityEstimatorCircleFit>();
98  } else if (m_EstimationMethod == "helixFit") {
99  m_estimator = std::make_unique<QualityEstimatorRiemannHelixFit>();
100  }
101  B2ASSERT("Not all QualityEstimators could be initialized!", m_estimator);
102 
103  if (m_MCInfo) {
104  m_estimatorMC = std::make_unique<QualityEstimatorMC>(m_MCRecoTracksStoreArrayName, m_MCStrictQualityEstimator, m_mva_target);
105  B2ASSERT("QualityEstimatorMC could be initialized!", m_estimatorMC);
106  }
107 }
108 
110 {
111  // BField is required by all QualityEstimators
112  const double bFieldZ = BFieldManager::getField(0, 0, 0).Z() / Unit::T;
113  m_estimator->setMagneticFieldStrength(bFieldZ);
114  if (m_MCInfo) {
115  m_estimatorMC->setMagneticFieldStrength(bFieldZ);
116  }
117 }
118 
120 {
122 
123  if (not aTC.hasRefereeStatus(SpacePointTrackCand::c_isActive)) {
124  continue;
125  }
126 
127  std::vector<SpacePoint const*> const sortedHits = aTC.getSortedHits();
128  if (m_ClusterInformation == "Average") {
129  m_clusterInfoExtractor->extractVariables(sortedHits);
130  }
131  m_nSpacePoints = sortedHits.size();
132  if (m_MCInfo) {
133  const double mc_quality = m_estimatorMC->estimateQuality(sortedHits);
134  m_truth = float(mc_quality > 0);
135  }
136  m_qeResultsExtractor->extractVariables(m_estimator->estimateQualityAndProperties(sortedHits));
137 
138  m_recorder->record();
139  }
140 }
141 
143 {
144  m_recorder->write();
145  m_recorder.reset();
146 }
Base class for Modules.
Definition: Module.h:72
Storage for (VXD) SpacePoint-based track candidates.
@ c_isActive
bit 11: SPTC is active (i.e.
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
static const double T
[tesla]
Definition: Unit.h:120
VXD Quality Estimator Data Collector Module to collect data for a MVA training using VXDQE_teacher....
std::unique_ptr< SimpleVariableRecorder > m_recorder
pointer to the object that writes out the collected data into a root file
std::string m_SpacePointTrackCandsStoreArrayName
sets the name of the expected StoreArray containing SpacePointTrackCands
bool m_MCStrictQualityEstimator
Required for MCInfo method, activates its strict version.
void initialize() override
Initializes the Module.
void event() override
applies the selected quality estimation method for a given set of TCs
float m_truth
truth information collected with m_estimatorMC; type is float to be consistent with m_variableSet (and...
std::string m_EstimationMethod
Identifier which estimation method to use.
float m_nSpacePoints
number of SpacePoints in SPTC as additional info to be collected; type is float to be consistent with...
void terminate() override
write out data from m_recorder
std::unique_ptr< QEResultsExtractor > m_qeResultsExtractor
pointer to object that extracts the results from the estimation method (including QI,...
bool m_UseTimingInfo
whether to collect timing information
bool m_mva_target
Bool to indicate if mva target requiring complete agreement in SVD Clusters between MC and PR track t...
void beginRun() override
sets magnetic field strength
std::unique_ptr< ClusterInfoExtractor > m_clusterInfoExtractor
pointer to object that extracts info from the clusters of a SPTC
std::string m_MCRecoTracksStoreArrayName
sets the name of the expected StoreArray containing MCRecoTracks.
std::vector< Named< float * > > m_variableSet
set of named variables to be collected
bool m_MCInfo
whether to collect MC truth information
std::unique_ptr< QualityEstimatorBase > m_estimatorMC
QualityEstimatorMC as target for training.
std::unique_ptr< QualityEstimatorBase > m_estimator
pointer to the selected QualityEstimator
std::string m_TrainingDataOutputName
name of the output rootfile
StoreArray< SpacePointTrackCand > m_spacePointTrackCands
the storeArray for SpacePointTrackCands as member, is faster than recreating link for each event
std::string m_ClusterInformation
how to compile information from clusters ['Average']
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
static void getField(const double *pos, double *field)
return the magnetic field at a given position.
Abstract base class for different kinds of events.