9#include "trg/cdc/modules/neurotrigger/CDCTrigger3DHNeuroDataModule.h"
18#include <boost/iostreams/filter/gzip.hpp>
19#include <boost/iostreams/filtering_stream.hpp>
21#include "trg/cdc/dataobjects/CDCTriggerSegmentHit.h"
22#include "trg/cdc/dataobjects/CDCTrigger3DHTrack.h"
23#include "tracking/dataobjects/RecoTrack.h"
24#include "framework/datastore/StoreArray.h"
25#include "framework/core/ModuleParam.templateDetails.h"
26#include "framework/geometry/B2Vector3.h"
28#define BOOST_MULTI_ARRAY_NO_GENERATORS
// Constructor: provides the module description text and registers the
// steering parameters with addParam (hit/input/target collection names,
// configuration file, gzip output file and the two fake-track flags).
// NOTE(review): this listing appears truncated — several addParam calls end on
// a trailing comma with no default value shown, and the call that receives the
// description string is not visible. Confirm against the complete source.
34CDCTrigger3DHNeuroDataModule::CDCTrigger3DHNeuroDataModule() :
Module()
37 "This module takes 3dtracks, track segments, and target tracks (recotracks)\n"
38 "as input and generates training data for the neurotrigger in a tab separated, gzip compressed file."
// Names of the three collections the module works on.
40 addParam(
"hitCollectionName", m_hitCollectionName,
41 "Name of the input StoreArray of CDCTriggerSegmentHits (relations to input tracks required).",
43 addParam(
"inputCollectionName", m_inputCollectionName,
44 "Name of the StoreArray holding the 3DHough Finder input tracks.",
45 std::string(
"CDCTriggerNNInput2DTracks"));
46 addParam(
"targetCollectionName", m_targetCollectionName,
47 "Name of the RecoTrack collection used as target values.",
48 std::string(
"RecoTracks"));
// Input configuration file and output training-sample file.
49 addParam(
"configFileName", m_configFileName,
50 "Name of the configuration file.",
52 addParam(
"gzipFilename", m_filename,
53 "Name of the gzip file, where the training samples will be saved.",
54 std::string(
"out.gz"));
// Switches that decide whether tracks without a reconstructed counterpart are written.
55 addParam(
"saveFakeEventTracks", m_saveFakeEventTracks,
56 "Flag to save the 3DFinder tracks from fake events (no reconstructed track present).",
58 addParam(
"saveFakeUnrelatedTracks", m_saveFakeUnrelatedTracks,
59 "Flag to save the 3DFinder tracks that have no relation to a reconstructed track.",
// Body of the initialisation step (the enclosing function header is not visible in this listing).
// Mark the input track collection and the target track collection as required,
// using the collection names configured through the module parameters.
65 m_ndFinderTracks.isRequired(m_inputCollectionName);
66 m_recoTracks.isRequired(m_targetCollectionName);
// Load the network parameters from the configured file, initialise the
// neurotrigger helper and pass the loaded parameters on to it.
67 m_neuroParameters3DH = CDCTrigger3DHMLP::loadConfigFromFile(m_configFileName);
68 m_neuroTrigger3DH.initialize();
69 m_neuroTrigger3DH.setNeuroParameters(m_neuroParameters3DH);
// Tell the helper which segment-hit collection to use.
70 m_neuroTrigger3DH.initializeCollections(m_hitCollectionName);
75void CDCTrigger3DHNeuroDataModule::writeHeadline()
const
77 std::ostringstream oss;
78 const size_t inputPerSL = m_neuroParameters3DH.nInput / m_nSL;
79 for (
size_t superLayerIdx = 0; superLayerIdx < m_nSL; ++superLayerIdx) {
80 oss <<
"SL" << superLayerIdx <<
"-relID\t";
81 oss <<
"SL" << superLayerIdx <<
"-driftT\t";
82 oss <<
"SL" << superLayerIdx <<
"-alpha\t";
83 for (
size_t i = 0; i < inputPerSL - 3; ++i) {
84 oss <<
"SL" << superLayerIdx <<
"-extra_input" << i <<
"\t";
87 oss <<
"RecoZ\tRecoTheta\tRecoNNTClassification\tRecoSTTClassification\tTrackType\n";
89 std::ofstream gzipfile(m_filename, std::ios_base::trunc | std::ios_base::binary);
90 boost::iostreams::filtering_ostream outStream;
91 outStream.push(boost::iostreams::gzip_compressor());
92 outStream.push(gzipfile);
93 outStream << oss.str() << std::endl;
// Body of the per-event processing (the enclosing function header and some
// statements, e.g. the declaration of ndFinderTrack, are not visible in this listing).
// An event without any entry in the reco track collection is flagged as a fake event.
98 bool isFakeEvent = (m_recoTracks.getEntries() == 0);
// Visit every track of the input collection by index.
100 for (
int trackIdx = 0; trackIdx < m_ndFinderTracks.getEntries(); ++trackIdx) {
// Compute the target values and the track type for this track.
103 TargetResult targetResult = computeTargetVector(*ndFinderTrack, isFakeEvent);
104 std::vector<float> target = targetResult.targetVector;
105 unsigned short trackType = targetResult.trackType;
// An empty target vector gets special handling; the branch body is not
// visible here — presumably the track is skipped, TODO confirm.
106 if (target.empty()) {
// Prepare the neurotrigger helper for this track, then collect the hit ids
// for the track and build the network input vector from them.
110 m_neuroTrigger3DH.calculateTrackParameters(*ndFinderTrack);
111 m_neuroTrigger3DH.setEventTime(*m_ndFinderTracks[trackIdx]);
112 std::vector<size_t> hitIds = m_neuroTrigger3DH.load3DHits(*ndFinderTrack);
113 std::vector<float> inputVector = m_neuroTrigger3DH.getInputVector(hitIds);
// Format one tab-separated row: all input values, all target values, then the track type.
115 std::ostringstream oss;
116 for (
size_t i = 0; i < inputVector.size(); ++i)
117 oss << inputVector[i] <<
"\t";
118 for (
size_t i = 0; i < target.size(); ++i)
119 oss << target[i] <<
"\t";
120 oss << trackType <<
"\n";
// Append the row to the output file through a gzip compressor.
// NOTE(review): this appears to reopen the file and create a new compressor
// for every row written — confirm that this is intended.
122 std::ofstream gzipfile(m_filename, std::ios_base::app | std::ios_base::binary);
123 boost::iostreams::filtering_ostream outStream;
124 outStream.push(boost::iostreams::gzip_compressor());
125 outStream.push(gzipfile);
126 outStream << oss.str();
// Body of the target computation (the enclosing function header and the
// statements that obtain recoTrack, reps and state are not visible in this listing).
// Start values: z and theta are zero, both classifications are -1 and are only
// raised to 1 further down.
135 float z = 0., theta = 0., classificationNNT = -1., classificationSTT = -1.;
// "Unrelated fake": the event is not a fake event, but recoTrack is null for this track.
137 const bool isUnrelatedFake = !isFakeEvent && (recoTrack ==
nullptr);
138 const bool isFakeTrack = isFakeEvent || isUnrelatedFake;
// Return an empty result for a fake category whose save flag is switched off.
140 if (isFakeEvent && !m_saveFakeEventTracks) {
return {}; }
141 if (isUnrelatedFake && !m_saveFakeUnrelatedTracks) {
return {}; }
// Go through the entries of reps.
145 for (
auto* rep : reps) {
// Extrapolate the state to the line given by the point (0, 0, -1000) and the
// vector (0, 0, 2000) — presumably the z axis; confirm against the genfit signature.
149 rep->extrapolateToLine(state, TVector3(0, 0, -1000), TVector3(0, 0, 2000));
// If the momentum points against the finder track's direction, reverse the
// momentum and flip the charge sign of the state.
150 if (state.getMom().Dot(
B2Vector3D(ndFinderTrack.getDirection())) < 0) {
151 state.setPosMom(state.getPos(), -state.getMom());
152 state.setChargeSign(-state.getCharge());
// Read z and the polar angle from the state and compute the momentum magnitude.
154 z = state.getPos().Z();
155 theta = state.getMom().Theta();
156 const auto& pRaw = state.getMom();
157 float totalMomentum = std::sqrt(pRaw.Px() * pRaw.Px() + pRaw.Py() * pRaw.Py() + pRaw.Pz() * pRaw.Pz());
// A track counts as "from IP" when |z| <= 1. The NNT classification needs only
// that; the STT classification additionally needs a total momentum of at least
// 0.7 (units are not visible here — TODO confirm).
158 bool fromIP = (std::abs(z) <= 1);
159 if (fromIP) classificationNNT = 1.0f;
160 if (fromIP && totalMomentum >= 0.7) classificationSTT = 1.0f;
// Unscaled targets. Fake tracks get z = 0, a theta derived from the finder
// track's cotTheta, and -1 for both classifications.
168 std::vector<float> rawTarget = {
169 isFakeTrack ? 0.0f : z,
170 isFakeTrack ?
static_cast<float>(std::atan2(1., ndFinderTrack.getCotTheta())) : theta,
171 isFakeTrack ? -1.0f : classificationNNT,
172 isFakeTrack ? -1.0f : classificationSTT
// Scale the targets with the neurotrigger helper and clamp every value to [-1, 1].
175 auto scaled = m_neuroTrigger3DH.scaleTarget(rawTarget);
176 for (
float& v : scaled) {
177 v = std::clamp(v, -1.0f, 1.0f);
// Derive the track type and store it, together with the scaled targets, in the result.
180 TrackType trackType = determineTrackType(classificationNNT, isFakeEvent, isUnrelatedFake);
182 targetResult.targetVector = scaled;
183 targetResult.trackType =
static_cast<unsigned short>(trackType);
188CDCTrigger3DHNeuroDataModule::TrackType CDCTrigger3DHNeuroDataModule::determineTrackType(
189 const float classificationNNT,
const bool isFakeEvent,
const bool isUnrelatedFake)
const
191 if (classificationNNT == 1.0f)
return TrackType::Real;
192 if (isFakeEvent)
return TrackType::Fake;
193 if (isUnrelatedFake)
return TrackType::UnrelatedFake;
194 if (classificationNNT == -1.0f)
return TrackType::Background;
195 return TrackType::Unknown;
virtual void initialize() override
Initialize the Module.
virtual void event() override
This method is the core of the module.
This is the Reconstruction Event-Data Model Track.
const std::vector< genfit::AbsTrackRep * > & getRepresentations() const
Return a list of track representations. You are not allowed to modify or delete them!
bool wasFitSuccessful(const genfit::AbsTrackRep *representation=nullptr) const
Returns true if the last fit with the given representation was successful.
const genfit::MeasuredStateOnPlane & getMeasuredStateOnPlaneClosestTo(const ROOT::Math::XYZVector &closestPoint, const genfit::AbsTrackRep *representation=nullptr)
Return genfit's MeasuredStateOnPlane that is closest to the given point; useful for extrapolation of m...
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
B2Vector3< double > B2Vector3D
typedef for common usage with double
Abstract base class for different kinds of events.