11 #include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
13 #include <framework/logging/Logger.h>
24 setDescription(
"SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
26 addParam(
"FBDTFileName", m_PARAMfbdtFileName,
"file where the FastBDT classifier is stored");
29 addParam(
"networkInputName",
30 m_PARAMnetworkInputName,
31 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(
""));
33 addParam(
"sectorMapName", m_PARAMsecMapName,
34 "The name of the SectorMap used for this instance.",
string(
"testMap"));
36 addParam(
"cutValue", m_PARAMcutVal,
37 "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
44 m_network.isRequired(m_PARAMnetworkInputName);
46 if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
47 B2ERROR(
"cutValue set to " << m_PARAMcutVal <<
" but has to be in [0,1]!");
50 setupClassifier(m_PARAMfbdtFileName);
58 std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
60 unsigned nAccepted{}, nRejected{}, nLinked{};
62 for (
const auto& outerHit : hitNetwork.
getNodes()) {
63 for (
const auto& centerHit : outerHit->getInnerNodes()) {
64 bool alreadyAdded =
false;
65 for (
const auto& innerHit : centerHit->getInnerNodes()) {
66 bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
67 *(centerHit->getEntry().m_spacePoint),
68 *(outerHit->getEntry().m_spacePoint));
71 const auto& sp1 = innerHit->getEntry().m_spacePoint;
72 const auto& sp2 = centerHit->getEntry().m_spacePoint;
73 const auto& sp3 = outerHit->getEntry().m_spacePoint;
74 std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
75 double classOut = m_classifier->analyze(coords);
76 B2DEBUG(499,
"Classifier output: " << classOut <<
", cutValue: " << m_PARAMcutVal);
79 if (!accepted) { nRejected++;
continue; }
82 segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
83 innerHit->getEntry().m_sector->getFullSecID(),
84 ¢erHit->getEntry(),
85 &innerHit->getEntry());
88 B2DEBUG(999,
"buildSegmentNetwork: innerSegment: " << innerSegmentPointer->
getName());
90 if (tempInnerSegmentnode ==
nullptr) {
91 segmentNetwork.
addNode(innerSegmentPointer->
getID(), segments.back());
93 innerSegmentPointer = &(tempInnerSegmentnode->
getEntry());
99 segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
100 centerHit->getEntry().m_sector->getFullSecID(),
101 &outerHit->getEntry(),
102 ¢erHit->getEntry());
104 B2DEBUG(999,
"buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->
getName() <<
105 " to be linked with inner segment: " << innerSegmentPointer->
getName());
108 if (tempOuterSegmentnode ==
nullptr) {
109 segmentNetwork.
addNode(outerSegmentPointer->
getID(), segments.back());
111 outerSegmentPointer = &(tempOuterSegmentnode->
getEntry());
115 B2DEBUG(999,
"buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->
getName() <<
116 " to be linked with inner segment: " << innerSegmentPointer->
getName());
130 B2DEBUG(10,
"MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted <<
"/" << nRejected <<
131 ", size of nLinked/hitNetwork: " << nLinked <<
"/" << segmentNetwork.
size());
143 ifstream filestr(filename);
144 if (!filestr.is_open()) {
145 B2FATAL(
"Could not open file: " << filename <<
" for reading in a FBDTClassifier");
149 classifier->readFromStream(filestr);
152 m_classifier = std::move(classifier);
159 auto filter = std::unique_ptr<MLFilter>(
new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
160 m_filter = std::move(
filter);