Belle II Software  release-05-02-19
MLSegmentNetworkProducerModule.cc
1 /*************************************************************************
2 * BASF2 (Belle Analysis Framework 2) *
3 * Copyright(C) 2016 - Belle II Collaboration *
4 * *
5 * Author: The Belle II Collaboration *
6 * Contributors: Thomas Madlener *
7 * *
8 * This software is provided "as is" without any warranty. *
9 **************************************************************************/
10 
11 #include <tracking/modules/vxdtfRedesign/MLSegmentNetworkProducerModule.h>
12 
13 #include <framework/logging/Logger.h>
14 
15 #include <fstream>
16 
17 using namespace Belle2;
18 using namespace std;
19 
20 REG_MODULE(MLSegmentNetworkProducer)
21 
23 {
24  setDescription("SegmentNetwork Producer Module with a machine learning classifier as three hit filter.");
25 
26  addParam("FBDTFileName", m_PARAMfbdtFileName, "file where the FastBDT classifier is stored");
27  // addParam("collectMode", m_PARAMcollectMode, "set to true for collecting training data, false for applying the filter", false);
28 
29  addParam("networkInputName",
30  m_PARAMnetworkInputName,
31  "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
32 
33  addParam("sectorMapName", m_PARAMsecMapName,
34  "The name of the SectorMap used for this instance.", string("testMap"));
35 
36  addParam("cutValue", m_PARAMcutVal,
37  "Cut value to be used for dividing the classifier output into signal (above) and background (below)",
38  0.5);
39 
40 }
41 
43 {
44  m_network.isRequired(m_PARAMnetworkInputName);
45 
46  if (m_PARAMcutVal < 0. || m_PARAMcutVal > 1.) {
47  B2ERROR("cutValue set to " << m_PARAMcutVal << " but has to be in [0,1]!");
48  }
49 
50  setupClassifier(m_PARAMfbdtFileName);
51  setupFilter();
52 }
53 
55 {
56  DirectedNodeNetwork<TrackNode, VoidMetaInfo>& hitNetwork = m_network->accessHitNetwork();
57  DirectedNodeNetwork<Segment<TrackNode>, CACell>& segmentNetwork = m_network->accessSegmentNetwork();
58  std::deque<Segment<TrackNode>>& segments = m_network->accessSegments();
59 
60  unsigned nAccepted{}, nRejected{}, nLinked{};
61 
62  for (const auto& outerHit : hitNetwork.getNodes()) {
63  for (const auto& centerHit : outerHit->getInnerNodes()) {
64  bool alreadyAdded = false; // skip adding Nodes twice into the network
65  for (const auto& innerHit : centerHit->getInnerNodes()) {
66  bool accepted = m_filter->accept(*(innerHit->getEntry().m_spacePoint),
67  *(centerHit->getEntry().m_spacePoint),
68  *(outerHit->getEntry().m_spacePoint));
69 
70  if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
71  const auto& sp1 = innerHit->getEntry().m_spacePoint;
72  const auto& sp2 = centerHit->getEntry().m_spacePoint;
73  const auto& sp3 = outerHit->getEntry().m_spacePoint;
74  std::array<double, 9> coords{{ sp1->X(), sp1->Y(), sp1->Z(), sp2->X(), sp2->Y(), sp2->Z(), sp3->X(), sp3->Y(), sp3->Z() }};
75  double classOut = m_classifier->analyze(coords);
76  B2DEBUG(499, "Classifier output: " << classOut << ", cutValue: " << m_PARAMcutVal);
77  }
78 
79  if (!accepted) { nRejected++; continue; } // don't store combinations which have not been accepted
80  nAccepted++;
81 
82  segments.emplace_back(centerHit->getEntry().m_sector->getFullSecID(),
83  innerHit->getEntry().m_sector->getFullSecID(),
84  &centerHit->getEntry(),
85  &innerHit->getEntry());
86  Segment<TrackNode>* innerSegmentPointer = &segments.back();
87 
88  B2DEBUG(999, "buildSegmentNetwork: innerSegment: " << innerSegmentPointer->getName());
89  DirectedNode<Segment<TrackNode>, CACell>* tempInnerSegmentnode = segmentNetwork.getNode(innerSegmentPointer->getID());
90  if (tempInnerSegmentnode == nullptr) {
91  segmentNetwork.addNode(innerSegmentPointer->getID(), segments.back());
92  } else {
93  innerSegmentPointer = &(tempInnerSegmentnode->getEntry());
94  segments.pop_back();
95  }
96 
97  if (!alreadyAdded) {
98  // create outerSector
99  segments.emplace_back(outerHit->getEntry().m_sector->getFullSecID(),
100  centerHit->getEntry().m_sector->getFullSecID(),
101  &outerHit->getEntry(),
102  &centerHit->getEntry());
103  Segment<TrackNode>* outerSegmentPointer = &segments.back();
104  B2DEBUG(999, "buildSegmentNetwork: outerSegment(freshly created): " << outerSegmentPointer->getName() <<
105  " to be linked with inner segment: " << innerSegmentPointer->getName());
106 
107  DirectedNode<Segment<TrackNode>, CACell>* tempOuterSegmentnode = segmentNetwork.getNode(outerSegmentPointer->getID());
108  if (tempOuterSegmentnode == nullptr) {
109  segmentNetwork.addNode(outerSegmentPointer->getID(), segments.back());
110  } else {
111  outerSegmentPointer = &(tempOuterSegmentnode->getEntry());
112  segments.pop_back();
113  }
114 
115  B2DEBUG(999, "buildSegmentNetwork: outerSegment (after duplicate check): " << outerSegmentPointer->getName() <<
116  " to be linked with inner segment: " << innerSegmentPointer->getName());
117  segmentNetwork.linkNodes(outerSegmentPointer->getID(), innerSegmentPointer->getID());
118  nLinked++;
119  alreadyAdded = true;
120  continue;
121  }
122  if (segmentNetwork.addInnerToLastOuterNode(innerSegmentPointer->getID())) {
123  nLinked++;
124  }
125  } // end inner loop
126  } // end center loop
127  } // end outer loop
128 
129 
130  B2DEBUG(10, "MLSegmentNetworkProducerModule::buildSegmentNetwork(): nAccepted/nRejected: " << nAccepted << "/" << nRejected <<
131  ", size of nLinked/hitNetwork: " << nLinked << "/" << segmentNetwork.size());
132 
133 }
134 
136 {
137 
138 }
139 
140 
141 void MLSegmentNetworkProducerModule::setupClassifier(const std::string& filename)
142 {
143  ifstream filestr(filename);
144  if (!filestr.is_open()) {
145  B2FATAL("Could not open file: " << filename << " for reading in a FBDTClassifier");
146  }
147 
148  auto classifier = std::unique_ptr<FBDTClassifier<9> >(new FBDTClassifier<9>());
149  classifier->readFromStream(filestr);
150  filestr.close();
151 
152  m_classifier = std::move(classifier);
153 }
154 
156 {
157  using RangeT = MLRange<FBDTClassifier<9>, 9, double>;
158 
159  auto filter = std::unique_ptr<MLFilter>(new MLFilter(RangeT(m_classifier.get(), m_PARAMcutVal)));
160  m_filter = std::move(filter);
161 }
Belle2::MLSegmentNetworkProducerModule::setupClassifier
void setupClassifier(const std::string &filename)
construct the classifier from file
Definition: MLSegmentNetworkProducerModule.cc:141
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::filter
std::map< ExpRun, std::pair< double, double > > filter(const std::map< ExpRun, std::pair< double, double >> &runs, double cut, std::map< ExpRun, std::pair< double, double >> &runsRemoved)
filter events to remove runs shorter than cut, it stores removed runs in runsRemoved
Definition: Splitter.cc:43
Belle2::DirectedNodeNetwork::size
unsigned int size() const
Returns number of nodes to be found in the network.
Definition: DirectedNodeNetwork.h:225
Belle2::Segment
The Segment class This class represents segments of track candidates needed for TrackFinderVXD-Module...
Definition: Segment.h:35
Belle2::Segment::getName
std::string getName() const
returns longer debugging name of this segment
Definition: Segment.h:97
Belle2::DirectedNodeNetwork
Network of directed nodes of the type EntryType.
Definition: DirectedNodeNetwork.h:38
Belle2::MLSegmentNetworkProducerModule::event
void event() override
event
Definition: MLSegmentNetworkProducerModule.cc:54
Belle2::MLSegmentNetworkProducerModule::terminate
void terminate() override
terminate module
Definition: MLSegmentNetworkProducerModule.cc:135
Belle2::Module
Base class for Modules.
Definition: Module.h:74
Belle2::Segment::getID
std::int64_t getID() const
Returns the ID of this segment (used as the node ID in the segment network).
Definition: Segment.h:94
Belle2::MLSegmentNetworkProducerModule
Segment network producer module with a Machine Learning classifier.
Definition: MLSegmentNetworkProducerModule.h:51
Belle2::DirectedNodeNetwork::linkNodes
bool linkNodes(NodeID outerNodeID, NodeID innerNodeID)
takes two entry IDs and weaves them into the network
Definition: DirectedNodeNetwork.h:132
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::DirectedNodeNetwork::getNode
Node * getNode(NodeID toBeFound)
Returns pointer to the node carrying the entry which is equal to given parameter.
Definition: DirectedNodeNetwork.h:194
Belle2::CACell
The CACell class This Class stores all relevant information one wants to have stored in a cell for a ...
Definition: CACell.h:30
Belle2::DirectedNodeNetwork::getNodes
std::vector< Node * > & getNodes()
Returns all nodes of the network.
Definition: DirectedNodeNetwork.h:201
Belle2::MLSegmentNetworkProducerModule::initialize
void initialize() override
initialize module
Definition: MLSegmentNetworkProducerModule.cc:42
Belle2::LogSystem::Instance
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:33
Belle2::LogConfig::c_Debug
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:36
Belle2::FBDTClassifier< 9 >
Belle2::DirectedNode
The Node-Class.
Definition: DirectedNode.h:41
Belle2::MLRange
Range used for the Machine Learning assisted TrackFinding approach.
Definition: MLRange.h:39
Belle2::MLSegmentNetworkProducerModule::setupFilter
void setupFilter()
setup the filter
Definition: MLSegmentNetworkProducerModule.cc:155
Belle2::DirectedNodeNetwork::addNode
bool addNode(NodeID nodeID, EntryType &newEntry)
Adds a new node carrying newEntry to the network under the given nodeID.
Definition: DirectedNodeNetwork.h:72
Belle2::Filter
This class is used to select pairs, triplets...
Definition: Filter.h:44
Belle2::DirectedNode::getEntry
EntryType & getEntry()
Allows access to stored entry.
Definition: DirectedNode.h:102
Belle2::DirectedNodeNetwork::addInnerToLastOuterNode
bool addInnerToLastOuterNode(NodeID innerNodeID)
to the last outerNode added, another innerNode will be attached
Definition: DirectedNodeNetwork.h:85