9 #include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
11 #include <framework/logging/Logger.h>
22 setDescription(
"SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
24 addParam(
"FBDTFileName", m_PARAMfbdtFileName,
"file where the FastBDT classifier is stored");
27 addParam(
"networkInputName",
28 m_PARAMnetworkInputName,
29 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(
""));
31 addParam(
"sectorMapName", m_PARAMsecMapName,
32 "The name of the SectorMap used for this instance.",
string(
"testMap"));
34 addParam(
"cutValue", m_PARAMcutVal,
35 "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
42 m_network.isRequired(m_PARAMnetworkInputName);
44 if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
45 B2ERROR(
"cutValue set to " << m_PARAMcutVal <<
" but has to be in [0,1]!");
48 setupClassifier(m_PARAMfbdtFileName);
56 std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
58 unsigned nAccepted{}, nRejected{}, nLinked{};
60 for (
const auto& outerHit : hitNetwork.
getNodes()) {
61 for (
const auto& centerHit : outerHit->getInnerNodes()) {
62 bool alreadyAdded =
false;
63 for (
const auto& innerHit : centerHit->getInnerNodes()) {
64 bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
65 *(centerHit->getEntry().m_spacePoint),
66 *(outerHit->getEntry().m_spacePoint));
69 const auto& sp1 = innerHit->getEntry().m_spacePoint;
70 const auto& sp2 = centerHit->getEntry().m_spacePoint;
71 const auto& sp3 = outerHit->getEntry().m_spacePoint;
72 std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
73 double classOut = m_classifier->analyze(coords);
74 B2DEBUG(499,
"Classifier output: " << classOut <<
", cutValue: " << m_PARAMcutVal);
77 if (!accepted) { nRejected++;
continue; }
80 segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
81 innerHit->getEntry().m_sector->getFullSecID(),
82 &centerHit->getEntry(),
83 &innerHit->getEntry());
86 B2DEBUG(999,
"buildSegmentNetwork: innerSegment: " << innerSegmentPointer->
getName());
88 if (tempInnerSegmentnode ==
nullptr) {
89 segmentNetwork.
addNode(innerSegmentPointer->
getID(), segments.back());
91 innerSegmentPointer = &(tempInnerSegmentnode->
getEntry());
97 segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
98 centerHit->getEntry().m_sector->getFullSecID(),
99 &outerHit->getEntry(),
100 &centerHit->getEntry());
102 B2DEBUG(999,
"buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->
getName() <<
103 " to be linked with inner segment: " << innerSegmentPointer->
getName());
106 if (tempOuterSegmentnode ==
nullptr) {
107 segmentNetwork.
addNode(outerSegmentPointer->
getID(), segments.back());
109 outerSegmentPointer = &(tempOuterSegmentnode->
getEntry());
113 B2DEBUG(999,
"buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->
getName() <<
114 " to be linked with inner segment: " << innerSegmentPointer->
getName());
128 B2DEBUG(10,
"MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted <<
"/" << nRejected <<
129 ", size of nLinked/hitNetwork: " << nLinked <<
"/" << segmentNetwork.
size());
141 ifstream filestr(filename);
142 if (!filestr.is_open()) {
143 B2FATAL(
"Could not open file: " << filename <<
" for reading in a FBDTClassifier");
147 classifier->readFromStream(filestr);
150 m_classifier = std::move(classifier);
157 auto filter = std::unique_ptr<MLFilter>(
new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
158 m_filter = std::move(
filter);
The CACell class. This class stores all relevant information one wants to have stored in a cell for a ...
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
bool addInnerToLastOuterNode(NodeID innerNodeID)
Attaches another innerNode to the last outerNode that was added.
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
unsigned int size() const
Returns number of nodes to be found in the network.
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
takes two entry IDs and weaves them into the network
bool addNode(NodeID nodeID, EntryType &newEntry)
************************* PUBLIC MEMBER FUNCTIONS *************************
EntryType & getEntry()
Allows access to stored entry.
This class is used to select pairs, triplets...
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Range used for the Machine Learning assisted TrackFinding approach.
Segment network producer module with a Machine Learning classifier.
void initialize() override
initialize module
void event() override
event
void terminate() override
terminate module
void setupClassifier(const std::string &filename)
construct the classifier from file
void setupFilter()
setup the filter
The Segment class. This class represents segments of track candidates needed for TrackFinderVXD-Module...
std::int64_t getID() const
************************* PUBLIC MEMBER FUNCTIONS *************************
std::string getName() const
returns longer debugging name of this segment
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double >> &runs, double cut, std::map< ExpRun, std::pair< double, double >> &runsRemoved)
filter events to remove runs shorter than cut, it stores removed runs in runsRemoved
Abstract base class for different kinds of events.