Belle II Software  release-08-01-10
MLSegmentNetworkProducerModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
10 
11 #include <framework/logging/Logger.h>
12 
13 #include <fstream>
14 
15 using namespace Belle2;
16 
17 REG_MODULE(MLSegmentNetworkProducer);
18 
20 {
21  setDescription("SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
22 
23  addParam("FBDTFileName", m_PARAMfbdtFileName, "file where the FastBDT classifier is stored");
24  // addParam("collectMode", m_PARAMcollectMode, "set to true for collecting training data, false for applying the filter", false);
25 
26  addParam("networkInputName",
28  "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
29 
30  addParam("sectorMapName", m_PARAMsecMapName,
31  "The name of the SectorMap used for this instance.", std::string("testMap"));
32 
33  addParam("cutValue", m_PARAMcutVal,
34  "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
35  0.5);
36 
37 }
38 
40 {
42 
43  if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
44  B2ERROR("cutValue set to " << m_PARAMcutVal << " but has to be in [0,1]!");
45  }
46 
48  setupFilter();
49 }
50 
52 {
55  std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
56 
57  unsigned nAccepted{}, nRejected{}, nLinked{};
58 
59  for (const auto& outerHit : hitNetwork.getNodes()) {
60  for (const auto& centerHit : outerHit->getInnerNodes()) {
61  bool alreadyAdded = false; // skip adding Nodes twice into the network
62  for (const auto& innerHit : centerHit->getInnerNodes()) {
63  bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
64  *(centerHit->getEntry().m_spacePoint),
65  *(outerHit->getEntry().m_spacePoint));
66 
67  if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
68  const auto& sp1 = innerHit->getEntry().m_spacePoint;
69  const auto& sp2 = centerHit->getEntry().m_spacePoint;
70  const auto& sp3 = outerHit->getEntry().m_spacePoint;
71  std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
72  double classOut = m_classifier->analyze(coords);
73  B2DEBUG(25, "Classifier output: " << classOut << ", cutValue: " << m_PARAMcutVal);
74  }
75 
76  if (!accepted) { nRejected++; continue; } // don't store combinations which have not been accepted
77  nAccepted++;
78 
79  segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
80  innerHit->getEntry().m_sector->getFullSecID(),
81  &centerHit->getEntry(),
82  &innerHit->getEntry());
83  Segment<TrackNode>* innerSegmentPointer = &segments.back();
84 
85  B2DEBUG(29, "buildSegmentNetwork: innerSegment: " << innerSegmentPointer->getName());
86  DirectedNode<Segment<TrackNode>, CACell>* tempInnerSegmentnode = segmentNetwork.getNode(innerSegmentPointer->getID());
87  if (tempInnerSegmentnode == nullptr) {
88  segmentNetwork.addNode(innerSegmentPointer->getID(), segments.back());
89  } else {
90  innerSegmentPointer = &(tempInnerSegmentnode->getEntry());
91  segments.pop_back();
92  }
93 
94  if (!alreadyAdded) {
95  // create outerSector
96  segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
97  centerHit->getEntry().m_sector->getFullSecID(),
98  &outerHit->getEntry(),
99  &centerHit->getEntry());
100  Segment<TrackNode>* outerSegmentPointer = &segments.back();
101  B2DEBUG(29, "buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->getName() <<
102  " to be linked with inner segment: " << innerSegmentPointer->getName());
103 
104  DirectedNode<Segment<TrackNode>, CACell>* tempOuterSegmentnode = segmentNetwork.getNode(outerSegmentPointer->getID());
105  if (tempOuterSegmentnode == nullptr) {
106  segmentNetwork.addNode(outerSegmentPointer->getID(), segments.back());
107  } else {
108  outerSegmentPointer = &(tempOuterSegmentnode->getEntry());
109  segments.pop_back();
110  }
111 
112  B2DEBUG(29, "buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->getName() <<
113  " to be linked with inner segment: " << innerSegmentPointer->getName());
114  segmentNetwork.linkNodes(outerSegmentPointer->getID(), innerSegmentPointer->getID());
115  nLinked++;
116  alreadyAdded = true;
117  continue;
118  }
119  if (segmentNetwork.addInnerToLastOuterNode(innerSegmentPointer->getID())) {
120  nLinked++;
121  }
122  } // end inner loop
123  } // end center loop
124  } // end outer loop
125 
126 
127  B2DEBUG(20, "MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted << "/" << nRejected <<
128  ", size of nLinked/hitNetwork: " << nLinked << "/" << segmentNetwork.size());
129 
130 }
131 
133 {
134 
135 }
136 
137 
138 void MLSegmentNetworkProducerModule::setupClassifier(const std::string& filename)
139 {
140  std::ifstream filestr(filename);
141  if (!filestr.is_open()) {
142  B2FATAL("Could not open file: " << filename << " for reading in a FBDTClassifier");
143  }
144 
145  auto classifier = std::unique_ptr<FBDTClassifier<9> >(new FBDTClassifier<9>());
146  classifier->readFromStream(filestr);
147  filestr.close();
148 
149  m_classifier = std::move(classifier);
150 }
151 
153 {
154  using RangeT = MLRange<FBDTClassifier<9>, 9, double>;
155 
156  auto filter = std::unique_ptr<MLFilter>(new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
157  m_filter = std::move(filter);
158 }
The CACell class. This class stores all relevant information one wants to have stored in a cell for a ...
Definition: CACell.h:20
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
std::deque< Belle2::Segment< Belle2::TrackNode > > & accessSegments()
Returns reference to the actual segments stored in this container, intended for read and write access...
DirectedNodeNetwork< Belle2::Segment< Belle2::TrackNode >, Belle2::CACell > & accessSegmentNetwork()
Returns reference to the SegmentNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
bool addInnerToLastOuterNode(NodeID innerNodeID)
Attaches another innerNode to the last outerNode that was added.
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
unsigned int size() const
Returns number of nodes to be found in the network.
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
Takes two entry IDs and weaves them into the network.
bool addNode(NodeID nodeID, EntryType &newEntry)
Adds a new node with the given ID and entry to the network.
The Node-Class.
Definition: DirectedNode.h:31
EntryType & getEntry()
Allows access to stored entry.
Definition: DirectedNode.h:92
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
Range used for the Machine Learning assisted TrackFinding approach.
Definition: MLRange.h:29
std::unique_ptr< Belle2::FBDTClassifier< 9 > > m_classifier
classifier used throughout this module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
Belle2::Filter< Belle2::MLHandover< Belle2::SpacePoint, 9 >, Belle2::MLRange< Belle2::FBDTClassifier< 9 >, 9, double >, VoidObserver > MLFilter
typedef with complete definition
void setupClassifier(const std::string &filename)
construct the classifier from file
std::string m_PARAMsecMapName
the name of the used SectorMap.
double m_PARAMcutVal
cut value to be used with classifier
std::string m_PARAMnetworkInputName
name of the StoreObjPtr pointing to the network container used in this module.
std::string m_PARAMfbdtFileName
file where the FastBDT classifier is stored.
std::unique_ptr< MLFilter > m_filter
internal three hit filter
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
The Segment class. This class represents segments of track candidates needed for TrackFinderVXD-Module...
Definition: Segment.h:25
std::int64_t getID() const
Returns the ID of this segment.
Definition: Segment.h:84
std::string getName() const
returns longer debugging name of this segment
Definition: Segment.h:87
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
REG_MODULE(arichBtest)
Register the Module.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double >> &runs, double cut, std::map< ExpRun, std::pair< double, double >> &runsRemoved)
Filters events to remove runs shorter than cut; the removed runs are stored in runsRemoved.
Definition: Splitter.cc:38
Abstract base class for different kinds of events.