Belle II Software development
MLSegmentNetworkProducerModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
10
11#include <framework/logging/Logger.h>
12
13#include <fstream>
14
15using namespace Belle2;
16
17REG_MODULE(MLSegmentNetworkProducer);
18
20{
21 setDescription("SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
22
23 addParam("FBDTFileName", m_PARAMfbdtFileName, "file where the FastBDT classifier is stored");
24 // addParam("collectMode", m_PARAMcollectMode, "set to true for collecting training data, false for applying the filter", false);
25
26 addParam("networkInputName",
28 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
29
30 addParam("sectorMapName", m_PARAMsecMapName,
31 "The name of the SectorMap used for this instance.", std::string("testMap"));
32
33 addParam("cutValue", m_PARAMcutVal,
34 "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
35 0.5);
36
37}
38
40{
42
43 if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
44 B2ERROR("cutValue set to " << m_PARAMcutVal << " but has to be in [0,1]!");
45 }
46
49}
50
52{
55 std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
56
57 unsigned nAccepted{}, nRejected{}, nLinked{};
58
59 for (const auto& outerHit : hitNetwork.getNodes()) {
60 for (const auto& centerHit : outerHit->getInnerNodes()) {
61 bool alreadyAdded = false; // skip adding Nodes twice into the network
62 for (const auto& innerHit : centerHit->getInnerNodes()) {
63 bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
64 *(centerHit->getEntry().m_spacePoint),
65 *(outerHit->getEntry().m_spacePoint));
66
67 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
68 const auto& sp1 = innerHit->getEntry().m_spacePoint;
69 const auto& sp2 = centerHit->getEntry().m_spacePoint;
70 const auto& sp3 = outerHit->getEntry().m_spacePoint;
71 std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
72 double classOut = m_classifier->analyze(coords);
73 B2DEBUG(25, "Classifier output: " << classOut << ", cutValue: " << m_PARAMcutVal);
74 }
75
76 if (!accepted) { nRejected++; continue; } // don't store combinations which have not been accepted
77 nAccepted++;
78
79 segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
80 innerHit->getEntry().m_sector->getFullSecID(),
81 &centerHit->getEntry(),
82 &innerHit->getEntry());
83 Segment<TrackNode>* innerSegmentPointer = &segments.back();
84
85 B2DEBUG(29, "buildSegmentNetwork: innerSegment: " << innerSegmentPointer->getName());
86 DirectedNode<Segment<TrackNode>, CACell>* tempInnerSegmentnode = segmentNetwork.getNode(innerSegmentPointer->getID());
87 if (tempInnerSegmentnode == nullptr) {
88 segmentNetwork.addNode(innerSegmentPointer->getID(), segments.back());
89 } else {
90 innerSegmentPointer = &(tempInnerSegmentnode->getEntry());
91 segments.pop_back();
92 }
93
94 if (!alreadyAdded) {
95 // create outerSector
96 segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
97 centerHit->getEntry().m_sector->getFullSecID(),
98 &outerHit->getEntry(),
99 &centerHit->getEntry());
100 Segment<TrackNode>* outerSegmentPointer = &segments.back();
101 B2DEBUG(29, "buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->getName() <<
102 " to be linked with inner segment: " << innerSegmentPointer->getName());
103
104 DirectedNode<Segment<TrackNode>, CACell>* tempOuterSegmentnode = segmentNetwork.getNode(outerSegmentPointer->getID());
105 if (tempOuterSegmentnode == nullptr) {
106 segmentNetwork.addNode(outerSegmentPointer->getID(), segments.back());
107 } else {
108 outerSegmentPointer = &(tempOuterSegmentnode->getEntry());
109 segments.pop_back();
110 }
111
112 B2DEBUG(29, "buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->getName() <<
113 " to be linked with inner segment: " << innerSegmentPointer->getName());
114 segmentNetwork.linkNodes(outerSegmentPointer->getID(), innerSegmentPointer->getID());
115 nLinked++;
116 alreadyAdded = true;
117 continue;
118 }
119 if (segmentNetwork.addInnerToLastOuterNode(innerSegmentPointer->getID())) {
120 nLinked++;
121 }
122 } // end inner loop
123 } // end center loop
124 } // end outer loop
125
126
127 B2DEBUG(20, "MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted << "/" << nRejected <<
128 ", size of nLinked/hitNetwork: " << nLinked << "/" << segmentNetwork.size());
129
130}
131
133{
134
135}
136
137
138void MLSegmentNetworkProducerModule::setupClassifier(const std::string& filename)
139{
140 std::ifstream filestr(filename);
141 if (!filestr.is_open()) {
142 B2FATAL("Could not open file: " << filename << " for reading in a FBDTClassifier");
143 }
144
145 auto classifier = std::unique_ptr<FBDTClassifier<9> >(new FBDTClassifier<9>());
146 classifier->readFromStream(filestr);
147 filestr.close();
148
149 m_classifier = std::move(classifier);
150}
151
153{
154 using RangeT = MLRange<FBDTClassifier<9>, 9, double>;
155
156 auto filter = std::unique_ptr<MLFilter>(new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
157 m_filter = std::move(filter);
158}
DataType X() const
access variable X (= .at(0) without boundary check)
Definition: B2Vector3.h:431
The CACell class. This class stores all relevant information one wants to have stored in a cell for a ...
Definition: CACell.h:20
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
std::deque< Belle2::Segment< Belle2::TrackNode > > & accessSegments()
Returns reference to the actual segments stored in this container, intended for read and write access...
DirectedNodeNetwork< Belle2::Segment< Belle2::TrackNode >, Belle2::CACell > & accessSegmentNetwork()
Returns reference to the SegmentNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
bool addInnerToLastOuterNode(NodeID innerNodeID)
Attaches another innerNode to the last outerNode that was added.
unsigned int size() const
Returns number of nodes to be found in the network.
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
takes two entry IDs and weaves them into the network
bool addNode(NodeID nodeID, EntryType &newEntry)
************************* PUBLIC MEMBER FUNCTIONS *************************
The Node-Class.
Definition: DirectedNode.h:31
EntryType & getEntry()
Allows access to stored entry.
Definition: DirectedNode.h:92
FastBDT as RelationsObject to make it storable and accessible on/via the DataStore.
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
Range used for the Machine Learning assisted TrackFinding approach.
Definition: MLRange.h:29
std::unique_ptr< Belle2::FBDTClassifier< 9 > > m_classifier
classifier used throughout this module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
Belle2::Filter< Belle2::MLHandover< Belle2::SpacePoint, 9 >, Belle2::MLRange< Belle2::FBDTClassifier< 9 >, 9, double >, VoidObserver > MLFilter
typedef with complete definition
void setupClassifier(const std::string &filename)
construct the classifier from file
std::string m_PARAMsecMapName
the name of the used SectorMap.
double m_PARAMcutVal
cut value to be used with classifier
std::string m_PARAMnetworkInputName
name of the StoreObjPtr pointing to the network container used in this module.
std::string m_PARAMfbdtFileName
file where the FastBDT classifier is stored.
std::unique_ptr< MLFilter > m_filter
internal three hit filter
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
The Segment class. This class represents segments of track candidates needed for TrackFinderVXD-Module...
Definition: Segment.h:25
std::int64_t getID() const
************************* PUBLIC MEMBER FUNCTIONS *************************
Definition: Segment.h:84
std::string getName() const
returns longer debugging name of this segment
Definition: Segment.h:87
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double > > &runs, double cut, std::map< ExpRun, std::pair< double, double > > &runsRemoved)
filter events to remove runs shorter than cut, it stores removed runs in runsRemoved
Definition: Splitter.cc:38
B2Vector3D outerHit(0, 0, 0)
testing out of range behavior
Abstract base class for different kinds of events.