Belle II Software  release-06-01-15
MLSegmentNetworkProducerModule.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
10 
11 #include <framework/logging/Logger.h>
12 
13 #include <fstream>
14 
15 using namespace Belle2;
16 using namespace std;
17 
// Register this module with the framework (module name given without the 'Module' suffix).
18 REG_MODULE(MLSegmentNetworkProducer)
19 
21 {
22  setDescription("SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
23 
24  addParam("FBDTFileName", m_PARAMfbdtFileName, "file where the FastBDT classifier is stored");
25  // addParam("collectMode", m_PARAMcollectMode, "set to true for collecting training data, false for applying the filter", false);
26 
27  addParam("networkInputName",
28  m_PARAMnetworkInputName,
29  "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
30 
31  addParam("sectorMapName", m_PARAMsecMapName,
32  "The name of the SectorMap used for this instance.", string("testMap"));
33 
34  addParam("cutValue", m_PARAMcutVal,
35  "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
36  0.5);
37 
38 }
39 
41 {
42  m_network.isRequired(m_PARAMnetworkInputName);
43 
44  if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
45  B2ERROR("cutValue set to " << m_PARAMcutVal << " but has to be in [0,1]!");
46  }
47 
48  setupClassifier(m_PARAMfbdtFileName);
49  setupFilter();
50 }
51 
53 {
54  DirectedNodeNetwork<TrackNode, VoidMetaInfo>& hitNetwork = m_network->accessHitNetwork();
55  DirectedNodeNetwork<Segment<TrackNode>, CACell>& segmentNetwork = m_network->accessSegmentNetwork();
56  std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
57 
58  unsigned nAccepted{}, nRejected{}, nLinked{};
59 
60  for (const auto& outerHit : hitNetwork.getNodes()) {
61  for (const auto& centerHit : outerHit->getInnerNodes()) {
62  bool alreadyAdded = false; // skip adding Nodes twice into the network
63  for (const auto& innerHit : centerHit->getInnerNodes()) {
64  bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
65  *(centerHit->getEntry().m_spacePoint),
66  *(outerHit->getEntry().m_spacePoint));
67 
68  if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
69  const auto& sp1 = innerHit->getEntry().m_spacePoint;
70  const auto& sp2 = centerHit->getEntry().m_spacePoint;
71  const auto& sp3 = outerHit->getEntry().m_spacePoint;
72  std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
73  double classOut = m_classifier->analyze(coords);
74  B2DEBUG(499, "Classifier output: " << classOut << ", cutValue: " << m_PARAMcutVal);
75  }
76 
77  if (!accepted) { nRejected++; continue; } // don't store combinations which have not been accepted
78  nAccepted++;
79 
80  segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
81  innerHit->getEntry().m_sector->getFullSecID(),
82  &centerHit->getEntry(),
83  &innerHit->getEntry());
84  Segment<TrackNode>* innerSegmentPointer = &segments.back();
85 
86  B2DEBUG(999, "buildSegmentNetwork: innerSegment: " << innerSegmentPointer->getName());
87  DirectedNode<Segment<TrackNode>, CACell>* tempInnerSegmentnode = segmentNetwork.getNode(innerSegmentPointer->getID());
88  if (tempInnerSegmentnode == nullptr) {
89  segmentNetwork.addNode(innerSegmentPointer->getID(), segments.back());
90  } else {
91  innerSegmentPointer = &(tempInnerSegmentnode->getEntry());
92  segments.pop_back();
93  }
94 
95  if (!alreadyAdded) {
96  // create outerSector
97  segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
98  centerHit->getEntry().m_sector->getFullSecID(),
99  &outerHit->getEntry(),
100  &centerHit->getEntry());
101  Segment<TrackNode>* outerSegmentPointer = &segments.back();
102  B2DEBUG(999, "buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->getName() <<
103  " to be linked with inner segment: " << innerSegmentPointer->getName());
104 
105  DirectedNode<Segment<TrackNode>, CACell>* tempOuterSegmentnode = segmentNetwork.getNode(outerSegmentPointer->getID());
106  if (tempOuterSegmentnode == nullptr) {
107  segmentNetwork.addNode(outerSegmentPointer->getID(), segments.back());
108  } else {
109  outerSegmentPointer = &(tempOuterSegmentnode->getEntry());
110  segments.pop_back();
111  }
112 
113  B2DEBUG(999, "buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->getName() <<
114  " to be linked with inner segment: " << innerSegmentPointer->getName());
115  segmentNetwork.linkNodes(outerSegmentPointer->getID(), innerSegmentPointer->getID());
116  nLinked++;
117  alreadyAdded = true;
118  continue;
119  }
120  if (segmentNetwork.addInnerToLastOuterNode(innerSegmentPointer->getID())) {
121  nLinked++;
122  }
123  } // end inner loop
124  } // end center loop
125  } // end outer loop
126 
127 
128  B2DEBUG(10, "MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted << "/" << nRejected <<
129  ", size of nLinked/hitNetwork: " << nLinked << "/" << segmentNetwork.size());
130 
131 }
132 
134 {
135 
136 }
137 
138 
139 void MLSegmentNetworkProducerModule::setupClassifier(const std::string& filename)
140 {
141  ifstream filestr(filename);
142  if (!filestr.is_open()) {
143  B2FATAL("Could not open file: " << filename << " for reading in a FBDTClassifier");
144  }
145 
146  auto classifier = std::unique_ptr<FBDTClassifier<9> >(new FBDTClassifier<9>());
147  classifier->readFromStream(filestr);
148  filestr.close();
149 
150  m_classifier = std::move(classifier);
151 }
152 
154 {
155  using RangeT = MLRange<FBDTClassifier<9>, 9, double>;
156 
157  auto filter = std::unique_ptr<MLFilter>(new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
158  m_filter = std::move(filter);
159 }
The CACell class: this class stores all relevant information one wants to have stored in a cell for a ...
Definition: CACell.h:20
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
bool addInnerToLastOuterNode(NodeID innerNodeID)
Attaches another inner node to the most recently added outer node.
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
unsigned int size() const
Returns number of nodes to be found in the network.
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
takes two entry IDs and weaves them into the network
bool addNode(NodeID nodeID, EntryType &newEntry)
Adds a node carrying newEntry to the network under the ID nodeID.
The Node-Class.
Definition: DirectedNode.h:31
EntryType & getEntry()
Allows access to stored entry.
Definition: DirectedNode.h:92
This class is used to select pairs, triplets...
Definition: Filter.h:34
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
Range used for the Machine Learning assisted TrackFinding approach.
Definition: MLRange.h:29
Segment network producer module with a Machine Learning classifier.
void setupClassifier(const std::string &filename)
construct the classifier from file
Base class for Modules.
Definition: Module.h:72
The Segment class: this class represents segments of track candidates needed for TrackFinderVXD-Module...
Definition: Segment.h:25
std::int64_t getID() const
Returns the ID of this segment (used as its node ID in the segment network).
Definition: Segment.h:84
std::string getName() const
returns longer debugging name of this segment
Definition: Segment.h:87
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double >> &runs, double cut, std::map< ExpRun, std::pair< double, double >> &runsRemoved)
Filters out runs shorter than cut; the removed runs are stored in runsRemoved.
Definition: Splitter.cc:40
Abstract base class for different kinds of events.