9#include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
11#include <framework/logging/Logger.h>
21 setDescription(
"SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
28 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(
""));
31 "The name of the SectorMap used for this instance.", std::string(
"testMap"));
34 "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
43 if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
44 B2ERROR(
"cutValue set to " <<
m_PARAMcutVal <<
" but has to be in [0,1]!");
57 unsigned nAccepted{}, nRejected{}, nLinked{};
60 for (
const auto& centerHit :
outerHit->getInnerNodes()) {
61 bool alreadyAdded =
false;
62 for (
const auto& innerHit : centerHit->getInnerNodes()) {
63 bool accepted =
m_filter->accept(*(innerHit->getEntry().m_spacePoint),
64 *(centerHit->getEntry().m_spacePoint),
65 *(
outerHit->getEntry().m_spacePoint));
68 const auto& sp1 = innerHit->getEntry().m_spacePoint;
69 const auto& sp2 = centerHit->getEntry().m_spacePoint;
70 const auto& sp3 =
outerHit->getEntry().m_spacePoint;
71 std::array<double, 9> coords{{ sp1->
X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
73 B2DEBUG(25,
"Classifier output: " << classOut <<
", cutValue: " <<
m_PARAMcutVal);
76 if (!accepted) { nRejected++;
continue; }
79 segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
80 innerHit->getEntry().m_sector->getFullSecID(),
81 ¢erHit->getEntry(),
82 &innerHit->getEntry());
85 B2DEBUG(29,
"buildSegmentNetwork: innerSegment: " << innerSegmentPointer->
getName());
87 if (tempInnerSegmentnode ==
nullptr) {
88 segmentNetwork.
addNode(innerSegmentPointer->
getID(), segments.back());
90 innerSegmentPointer = &(tempInnerSegmentnode->
getEntry());
96 segments.emplace_back(
outerHit->getEntry().m_sector->getFullSecID(),
97 centerHit->getEntry().m_sector->getFullSecID(),
99 ¢erHit->getEntry());
101 B2DEBUG(29,
"buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->
getName() <<
102 " to be linked with inner segment: " << innerSegmentPointer->
getName());
105 if (tempOuterSegmentnode ==
nullptr) {
106 segmentNetwork.
addNode(outerSegmentPointer->
getID(), segments.back());
108 outerSegmentPointer = &(tempOuterSegmentnode->
getEntry());
112 B2DEBUG(29,
"buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->
getName() <<
113 " to be linked with inner segment: " << innerSegmentPointer->
getName());
127 B2DEBUG(20,
"MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted <<
"/" << nRejected <<
128 ", size of nLinked/hitNetwork: " << nLinked <<
"/" << segmentNetwork.
size());
140 std::ifstream filestr(filename);
141 if (!filestr.is_open()) {
142 B2FATAL(
"Could not open file: " << filename <<
" for reading in a FBDTClassifier");
146 classifier->readFromStream(filestr);
DataType X() const
access variable X (= .at(0) without boundary check)
The CACell class This Class stores all relevant information one wants to have stored in a cell for a ...
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
std::deque< Belle2::Segment< Belle2::TrackNode > > & accessSegments()
Returns reference to the actual segments stored in this container, intended for read and write access...
DirectedNodeNetwork< Belle2::Segment< Belle2::TrackNode >, Belle2::CACell > & accessSegmentNetwork()
Returns reference to the SegmentNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
bool addInnerToLastOuterNode(NodeID innerNodeID)
to the last outerNode added, another innerNode will be attached
unsigned int size() const
Returns number of nodes to be found in the network.
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
takes two entry IDs and weaves them into the network
bool addNode(NodeID nodeID, EntryType &newEntry)
************************* PUBLIC MEMBER FUNCTIONS *************************
EntryType & getEntry()
Allows access to stored entry.
FastBDT as RelationsObject to make it storable and accessible on/via the DataStore.
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Range used for the Machine Learning assisted TrackFinding approach.
std::unique_ptr< Belle2::FBDTClassifier< 9 > > m_classifier
classifier used throughout this module
void initialize() override
initialize module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
event
Belle2::Filter< Belle2::MLHandover< Belle2::SpacePoint, 9 >, Belle2::MLRange< Belle2::FBDTClassifier< 9 >, 9, double >, VoidObserver > MLFilter
typedef with complete definition
void terminate() override
terminate module
void setupClassifier(const std::string &filename)
construct the classifier from file
std::string m_PARAMsecMapName
the name of the used SectorMap.
double m_PARAMcutVal
cut value to be used with classifier
std::string m_PARAMnetworkInputName
name of the StoreObjPtr pointing to the network container used in this module.
std::string m_PARAMfbdtFileName
file where the FastBDT classifier is stored.
void setupFilter()
setup the filter
std::unique_ptr< MLFilter > m_filter
internal three hit filter
MLSegmentNetworkProducerModule()
module constructor
void setDescription(const std::string &description)
Sets the description of the module.
The Segment class This class represents segments of track candidates needed for TrackFinderVXD-Module...
std::int64_t getID() const
************************* PUBLIC MEMBER FUNCTIONS *************************
std::string getName() const
returns longer debugging name of this segment
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double > > &runs, double cut, std::map< ExpRun, std::pair< double, double > > &runsRemoved)
filter events to remove runs shorter than cut, it stores removed runs in runsRemoved
B2Vector3D outerHit(0, 0, 0)
testing out of range behavior
Abstract base class for different kinds of events.