Belle II Software development
CDCTrigger3DHNeuroModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include "trg/cdc/modules/neurotrigger/CDCTrigger3DHNeuroModule.h"
10
11#include <cmath>
12#include <vector>
13#include <array>
14
15#include "trg/cdc/dataobjects/CDCTriggerHoughMLP.h"
16#include "framework/database/DBObjPtr.h"
17
18using namespace Belle2;
19
20REG_MODULE(CDCTrigger3DHNeuro);
21
22CDCTrigger3DHNeuroModule::CDCTrigger3DHNeuroModule() : Module()
23{
24 setDescription(
25 "The NeuroTrigger3DH module of the CDC trigger.\n"
26 "Takes track segments and 3D track estimates as input and estimates\n"
27 "the z-vertex and classification for each track using a neural network.\n"
28 "Requires a trained network stored in a file.\n"
29 );
30 setPropertyFlags(c_ParallelProcessingCertified);
31 addParam("fileName", m_fileName,
32 "Name of the network to load. Contains the configuration parameters."
33 "When left blank, the network is loaded from the ConditionsDB.",
34 std::string(""));
35 addParam("arrayName", m_arrayName,
36 "Name of the TObjArray holding the NeuroTrigger3DH parameters.",
37 std::string("MLP"));
38 addParam("hitCollectionName", m_hitCollectionName,
39 "Name of the input StoreArray of CDCTriggerSegmentHits.",
40 std::string(""));
41 addParam("inputCollectionName", m_inputCollectionName,
42 "Name of the StoreArray holding the 3DHough Finder input tracks.",
43 std::string("TRGCDCNDFinderTracks"));
44 addParam("outputCollectionName", m_outputCollectionName,
45 "Name of the StoreArray holding the output Neuro tracks.",
46 std::string("TRGCDC3DHNeuroTracks"));
47 addParam("fixedPoint", m_fixedPoint,
48 "Switch to turn on fixed point arithmetic for FPGA simulation.",
49 false);
50 addParam("classificationCutNNT", m_classificationCutNNT,
51 "The across the board classification cut for the nnt (y)-bit (between -1 and 1).",
52 static_cast<double>(0.0));
53 addParam("classificationCutSTT", m_classificationCutSTT,
54 "The across the board classification cut for the stt-bit (between -1 and 1, for p > 0.7GeV).",
55 static_cast<double>(0.0));
56}
57
59{
60 m_neuroTrigger3DH.initialize();
61 const CDCTrigger3DHMLP& mlp = m_fileName.empty()
62 ? *m_CDCTrigger3DHMLPConditionsDB // Load from conditions database
63 : CDCTrigger3DHMLP::loadMLPFromFile<CDCTrigger3DHMLP>(m_fileName, m_arrayName); // Load from custom file
64 m_neuroTrigger3DH.setMLP(mlp);
65
66 if (m_fixedPoint) {
67 m_neuroTrigger3DH.createIntWeights();
68 }
69 m_neuroTrigger3DH.initializeCollections(m_hitCollectionName);
70
71 m_ndFinderTracks.isRequired(m_inputCollectionName);
72 m_trackSegmentHits.isRequired(m_hitCollectionName);
73 m_neuro3DHTracks.registerInDataStore(m_outputCollectionName);
74
75 m_ndFinderTracks.registerRelationTo(m_neuro3DHTracks);
76 m_ndFinderTracks.requireRelationTo(m_trackSegmentHits);
77 m_neuro3DHTracks.registerRelationTo(m_trackSegmentHits);
78}
79
81{
82 for (int trackIdx = 0; trackIdx < m_ndFinderTracks.getEntries(); ++trackIdx) {
83 // Setup of the MLP input
84 if (m_fixedPoint) {
85 m_neuroTrigger3DH.calculateTrackParametersFixedPrecision(*m_ndFinderTracks[trackIdx]);
86 } else {
87 m_neuroTrigger3DH.calculateTrackParameters(*m_ndFinderTracks[trackIdx]);
88 }
89 m_neuroTrigger3DH.setEventTime(*m_ndFinderTracks[trackIdx]);
90 std::vector<size_t> hitIds = m_neuroTrigger3DH.load3DHits(*m_ndFinderTracks[trackIdx]);
91 std::vector<float> networkInput = m_neuroTrigger3DH.getInputVector(hitIds);
92 // Run the MLP
93 std::vector<float> networkPrediction;
94 if (m_fixedPoint) {
95 networkPrediction = m_neuroTrigger3DH.runMLPFixedPrecision(networkInput);
96 } else {
97 networkPrediction = m_neuroTrigger3DH.runMLP(networkInput);
98 }
99 // Create a new track with the MLP output values
100 double trackPhi = m_ndFinderTracks[trackIdx]->getPhi0();
101 double trackOmega = m_ndFinderTracks[trackIdx]->getOmega();
102 double zPrediction = networkPrediction[0];
103 double cotPrediction = std::cos(networkPrediction[1]) / std::sin(networkPrediction[1]);
104 double classificationNNT = networkPrediction[2];
105 double classificationSTT = networkPrediction[3];
106
107
108 double trackTheta = std::atan2(1.0, cotPrediction);
109 if (trackTheta < 0) trackTheta += M_PI;
110 double totalMomentumPrediction = m_ndFinderTracks[trackIdx]->getPt() / std::sin(trackTheta);
111
112 // Set the Helix
113 CDCTrigger3DHTrack* neuroTrack =
114 m_neuro3DHTracks.appendNew(trackPhi, trackOmega, zPrediction, cotPrediction);
115
116 // Set further info
117 neuroTrack->setTime(m_neuroTrigger3DH.getEventTime());
118 neuroTrack->setFloatInput(networkInput);
119 neuroTrack->setTotalMomentum(totalMomentumPrediction);
120 neuroTrack->setNNTClassification(classificationNNT);
121 neuroTrack->setSTTClassification(classificationSTT);
122 neuroTrack->setQuadrant(m_ndFinderTracks[trackIdx]->getQuadrant());
123 neuroTrack->setValidTrackBit(m_ndFinderTracks[trackIdx]->getValidTrackBit());
124
125 // Get the trigger bits
126 bool nntBit = (classificationNNT > m_classificationCutNNT);
127 bool sttBit = (classificationSTT > m_classificationCutSTT);
128 neuroTrack->setNNTBit(nntBit);
129 neuroTrack->setSTTBit(sttBit);
130
131 // Set relation to 3DFinder track
132 m_ndFinderTracks[trackIdx]->addRelationTo(neuroTrack);
133
134 // Set relation to Track Segments and create TSVector
135 std::array<unsigned short, 9> tsVector{0};
136 for (unsigned int hitIdx = 0; hitIdx < hitIds.size(); ++hitIdx) {
137 neuroTrack->addRelationTo(m_trackSegmentHits[hitIds[hitIdx]]);
138 unsigned short superLayer = m_trackSegmentHits[hitIds[hitIdx]]->getISuperLayer();
139 tsVector[superLayer] = m_trackSegmentHits[hitIds[hitIdx]]->getLeftRight();
140 }
141 neuroTrack->setTSVector(tsVector);
142 }
143}
virtual void initialize() override
Initialize the Module.
virtual void event() override
This method is the core of the module.
Base class for Modules.
Definition Module.h:72
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
Abstract base class for different kinds of events.