Belle II Software development
CDCTrigger3DHNeuroDataModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include "trg/cdc/modules/neurotrigger/CDCTrigger3DHNeuroDataModule.h"
10
11#include <fstream>
12#include <sstream>
13#include <cmath>
14#include <algorithm>
15#include <vector>
16#include <string>
17
18#include <boost/iostreams/filter/gzip.hpp>
19#include <boost/iostreams/filtering_stream.hpp>
20
21#include "trg/cdc/dataobjects/CDCTriggerSegmentHit.h"
22#include "trg/cdc/dataobjects/CDCTrigger3DHTrack.h"
23#include "tracking/dataobjects/RecoTrack.h"
24#include "framework/datastore/StoreArray.h"
25#include "framework/core/ModuleParam.templateDetails.h"
26#include "framework/geometry/B2Vector3.h"
27
28#define BOOST_MULTI_ARRAY_NO_GENERATORS
29
30using namespace Belle2;
31
32REG_MODULE(CDCTrigger3DHNeuroData);
33
34CDCTrigger3DHNeuroDataModule::CDCTrigger3DHNeuroDataModule() : Module()
35{
36 setDescription(
37 "This module takes 3dtracks, track segments, and target tracks (recotracks)\n"
38 "as input and generates training data for the neurotrigger in a tab separated, gzip compressed file."
39 );
40 addParam("hitCollectionName", m_hitCollectionName,
41 "Name of the input StoreArray of CDCTriggerSegmentHits (relations to input tracks required).",
42 std::string(""));
43 addParam("inputCollectionName", m_inputCollectionName,
44 "Name of the StoreArray holding the 3DHough Finder input tracks.",
45 std::string("CDCTriggerNNInput2DTracks"));
46 addParam("targetCollectionName", m_targetCollectionName,
47 "Name of the RecoTrack collection used as target values.",
48 std::string("RecoTracks"));
49 addParam("configFileName", m_configFileName,
50 "Name of the configuration file.",
51 std::string(""));
52 addParam("gzipFilename", m_filename,
53 "Name of the gzip file, where the training samples will be saved.",
54 std::string("out.gz"));
55 addParam("saveFakeEventTracks", m_saveFakeEventTracks,
56 "Flag to save the 3DFinder tracks from fake events (no reconstructed track present).",
57 false);
58 addParam("saveFakeUnrelatedTracks", m_saveFakeUnrelatedTracks,
59 "Flag to save the 3DFinder tracks that have no relation to a reconstructed track.",
60 false);
61}
62
64{
65 m_ndFinderTracks.isRequired(m_inputCollectionName);
66 m_recoTracks.isRequired(m_targetCollectionName);
67 m_neuroParameters3DH = CDCTrigger3DHMLP::loadConfigFromFile(m_configFileName);
68 m_neuroTrigger3DH.initialize();
69 m_neuroTrigger3DH.setNeuroParameters(m_neuroParameters3DH);
70 m_neuroTrigger3DH.initializeCollections(m_hitCollectionName);
71 writeHeadline();
72}
73
74// Write the headline to the .gz file
75void CDCTrigger3DHNeuroDataModule::writeHeadline() const
76{
77 std::ostringstream oss;
78 const size_t inputPerSL = m_neuroParameters3DH.nInput / m_nSL;
79 for (size_t superLayerIdx = 0; superLayerIdx < m_nSL; ++superLayerIdx) {
80 oss << "SL" << superLayerIdx << "-relID\t";
81 oss << "SL" << superLayerIdx << "-driftT\t";
82 oss << "SL" << superLayerIdx << "-alpha\t";
83 for (size_t i = 0; i < inputPerSL - 3; ++i) {
84 oss << "SL" << superLayerIdx << "-extra_input" << i << "\t";
85 }
86 }
87 oss << "RecoZ\tRecoTheta\tRecoNNTClassification\tRecoSTTClassification\tTrackType\n";
88 // Write the headline (oss) to the .gz file
89 std::ofstream gzipfile(m_filename, std::ios_base::trunc | std::ios_base::binary);
90 boost::iostreams::filtering_ostream outStream;
91 outStream.push(boost::iostreams::gzip_compressor());
92 outStream.push(gzipfile);
93 outStream << oss.str() << std::endl;
94}
95
97{
98 bool isFakeEvent = (m_recoTracks.getEntries() == 0);
99
100 for (int trackIdx = 0; trackIdx < m_ndFinderTracks.getEntries(); ++trackIdx) {
101 const CDCTrigger3DHTrack* ndFinderTrack = m_ndFinderTracks[trackIdx];
102
103 TargetResult targetResult = computeTargetVector(*ndFinderTrack, isFakeEvent);
104 std::vector<float> target = targetResult.targetVector;
105 unsigned short trackType = targetResult.trackType;
106 if (target.empty()) {
107 continue;
108 }
109
110 m_neuroTrigger3DH.calculateTrackParameters(*ndFinderTrack);
111 m_neuroTrigger3DH.setEventTime(*m_ndFinderTracks[trackIdx]);
112 std::vector<size_t> hitIds = m_neuroTrigger3DH.load3DHits(*ndFinderTrack);
113 std::vector<float> inputVector = m_neuroTrigger3DH.getInputVector(hitIds);
114
115 std::ostringstream oss;
116 for (size_t i = 0; i < inputVector.size(); ++i)
117 oss << inputVector[i] << "\t";
118 for (size_t i = 0; i < target.size(); ++i)
119 oss << target[i] << "\t";
120 oss << trackType << "\n";
121
122 std::ofstream gzipfile(m_filename, std::ios_base::app | std::ios_base::binary);
123 boost::iostreams::filtering_ostream outStream;
124 outStream.push(boost::iostreams::gzip_compressor());
125 outStream.push(gzipfile);
126 outStream << oss.str();
127 }
128}
129
130// Compute scaled target vector from reco track
131CDCTrigger3DHNeuroDataModule::TargetResult CDCTrigger3DHNeuroDataModule::computeTargetVector(
132 const CDCTrigger3DHTrack& ndFinderTrack, const bool isFakeEvent) const
133{
134 RecoTrack* recoTrack = ndFinderTrack.getRelatedTo<RecoTrack>(m_targetCollectionName);
135 float z = 0., theta = 0., classificationNNT = -1., classificationSTT = -1.;
136
137 const bool isUnrelatedFake = !isFakeEvent && (recoTrack == nullptr);
138 const bool isFakeTrack = isFakeEvent || isUnrelatedFake;
139
140 if (isFakeEvent && !m_saveFakeEventTracks) { return {}; }
141 if (isUnrelatedFake && !m_saveFakeUnrelatedTracks) { return {}; }
142
143 if (!isFakeTrack) {
144 const auto& reps = recoTrack->getRepresentations();
145 for (auto* rep : reps) {
146 if (!recoTrack->wasFitSuccessful(rep)) continue;
147 try {
148 auto state = recoTrack->getMeasuredStateOnPlaneClosestTo(ROOT::Math::XYZVector(0, 0, 0), rep);
149 rep->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
150 if (state.getMom().Dot(B2Vector3D(ndFinderTrack.getDirection())) < 0) {
151 state.setPosMom(state.getPos(), -state.getMom());
152 state.setChargeSign(-state.getCharge());
153 }
154 z = state.getPos().Z();
155 theta = state.getMom().Theta();
156 const auto& pRaw = state.getMom();
157 float totalMomentum = std::sqrt(pRaw.Px() * pRaw.Px() + pRaw.Py() * pRaw.Py() + pRaw.Pz() * pRaw.Pz());
158 bool fromIP = (std::abs(z) <= 1);
159 if (fromIP) classificationNNT = 1.0f;
160 if (fromIP && totalMomentum >= 0.7) classificationSTT = 1.0f;
161 break;
162 } catch (...) {
163 continue;
164 }
165 }
166 }
167
168 std::vector<float> rawTarget = {
169 isFakeTrack ? 0.0f : z,
170 isFakeTrack ? static_cast<float>(std::atan2(1., ndFinderTrack.getCotTheta())) : theta,
171 isFakeTrack ? -1.0f : classificationNNT,
172 isFakeTrack ? -1.0f : classificationSTT
173 };
174
175 auto scaled = m_neuroTrigger3DH.scaleTarget(rawTarget);
176 for (float& v : scaled) {
177 v = std::clamp(v, -1.0f, 1.0f);
178 }
179
180 TrackType trackType = determineTrackType(classificationNNT, isFakeEvent, isUnrelatedFake);
181 TargetResult targetResult;
182 targetResult.targetVector = scaled;
183 targetResult.trackType = static_cast<unsigned short>(trackType);
184 return targetResult;
185}
186
187// Get the (target) track type
188CDCTrigger3DHNeuroDataModule::TrackType CDCTrigger3DHNeuroDataModule::determineTrackType(
189 const float classificationNNT, const bool isFakeEvent, const bool isUnrelatedFake) const
190{
191 if (classificationNNT == 1.0f) return TrackType::Real;
192 if (isFakeEvent) return TrackType::Fake;
193 if (isUnrelatedFake) return TrackType::UnrelatedFake;
194 if (classificationNNT == -1.0f) return TrackType::Background;
195 return TrackType::Unknown;
196}
virtual void initialize() override
Initialize the Module.
virtual void event() override
This method is the core of the module.
Base class for Modules.
Definition Module.h:72
This is the Reconstruction Event-Data Model Track.
Definition RecoTrack.h:79
const std::vector< genfit::AbsTrackRep * > & getRepresentations() const
Return a list of track representations. You are not allowed to modify or delete them!
Definition RecoTrack.h:638
bool wasFitSuccessful(const genfit::AbsTrackRep *representation=nullptr) const
Returns true if the last fit with the given representation was successful.
Definition RecoTrack.cc:336
const genfit::MeasuredStateOnPlane & getMeasuredStateOnPlaneClosestTo(const ROOT::Math::XYZVector &closestPoint, const genfit::AbsTrackRep *representation=nullptr)
Return genfit's MeasuredStateOnPlane, that is closest to the given point useful for extrapolation of m...
Definition RecoTrack.cc:426
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
B2Vector3< double > B2Vector3D
typedef for common usage with double
Definition B2Vector3.h:516
Abstract base class for different kinds of events.