Belle II Software  release-05-02-19
FastBDTClassifierTrainingModule.cc
1 /*************************************************************************
2 * BASF2 (Belle Analysis Framework 2) *
3 * Copyright(C) 2016 - Belle II Collaboration *
4 * *
5 * Author: The Belle II Collaboration *
6 * Contributors: Thomas Madlener *
7 * *
8 * This software is provided "as is" without any warranty. *
9 **************************************************************************/
10 
11 #include <tracking/modules/vxdtfRedesign/FastBDTClassifierTrainingModule.h>
12 #include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
13 #include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetwork.h>
14 #include <tracking/trackFindingVXD/segmentNetwork/TrackNode.h>
15 
16 #include <tracking/spacePointCreation/PurityCalculatorTools.h>
17 #include <tracking/spacePointCreation/MCVXDPurityInfo.h>
18 
19 #include <framework/logging/Logger.h>
20 
21 #include <fstream>
22 #include <sstream>
23 
24 using namespace Belle2;
25 using namespace std;
26 
27 REG_MODULE(FastBDTClassifierTraining);
28 
30 {
31  setDescription("TODO");
32 
33  addParam("outputFileName",
35  "Output file name to which the trained FBDTClassifier will be stored",
36  std::string("FBDTClassifier.dat"));
37 
38  addParam("networkInputName", m_PARAMnetworkInputName,
39  "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
40 
41  addParam("train", m_PARAMdoTrain, "Set if the module should train a classifier after collecting or not", true);
42  addParam("nTrees", m_PARAMnTrees, "Number of Trees used in the FastBDT", 100);
43  addParam("treeDepth", m_PARAMtreeDepth, "Tree depth of the trees used in the FastBDT", 3);
44  addParam("shrinkage", m_PARAMshrinkage, "Shrinkage parameter used in the FastBDT", 0.15);
45  addParam("randRatio", m_PARAMrandRatio, "ratio of randomly chosen samples for training of one tree", 0.5);
46  addParam("storeSamples", m_PARAMstoreSamples, "store the collected samples into a file", false);
47  addParam("samplesFileName", m_PARAMsamplesFileName, "the file name into which/from whicht the collected samples are stored/read",
48  std::string("FBDTClassifier_samples.dat"));
49  addParam("useSamples", m_PARAMuseSamples,
50  "use samples for training that have been collected previously and bypass the collection of samples", false);
51 }
52 
54 {
55 
57  B2ERROR("storeSamples and useSamples are both set to true. However, only one option can be set at a time");
58  }
59 
60  if (m_PARAMnTrees < 1) {
61  B2WARNING("nTrees was set to " << m_PARAMnTrees << ". Has to be at least 1. Setting to 1.");
62  m_PARAMnTrees = 1;
63  }
64 
65  if (m_PARAMtreeDepth < 0) {
66  B2WARNING("Trees have to be at least a stump, but treeDepth was set to " << m_PARAMtreeDepth << ". Setting to 3 (default).");
67  m_PARAMtreeDepth = 3;
68  }
69 
70  if (m_PARAMshrinkage < 0 || m_PARAMshrinkage > 1) { // shrinkage outside [0,1] is reset to the default; TODO: check this
71  B2WARNING("shrinkage has to be in [0,1] but was " << m_PARAMshrinkage << ". Setting to 0.15 (default).");
72  m_PARAMshrinkage = .15;
73  }
74 
75  if (m_PARAMrandRatio < 0 || m_PARAMrandRatio > 1) {
76  B2WARNING("randRatio has to be in [0,1] but was " << m_PARAMrandRatio << ". Setting to 0.5 (default).");
77  m_PARAMrandRatio = 0.5;
78  }
79 
80  if (m_PARAMuseSamples) {
81  ifstream sampFile(m_PARAMsamplesFileName);
82  if (!sampFile.is_open() || !sampFile.good()) {
83  B2ERROR("Was not able to open the samples file: " << m_PARAMsamplesFileName);
84  }
85 
87  } else { // only if no samples are provided the collection from the DNN is necessary
89  }
90 }
91 
93 {
94  if (m_PARAMuseSamples) return; // don't collect anything during event if samples are provided
95 
96  DirectedNodeNetwork<TrackNode, VoidMetaInfo>& hitNetwork = m_network->accessHitNetwork();
97 
98  // B2DEBUG(1, "size of hitNetwork " << hitNetwork.getNodes().size());
99 
100  size_t samplePriorEvent = m_samples.size();
101 
102  // XXXHit is of type DirectedNode<TrackNode, VoidMetaInfo>
103  for (const auto& outerHit : hitNetwork.getNodes()) { // loop over all outer nodes
104  for (const auto& centerHit : outerHit->getInnerNodes()) { // loop over all center nodes attached to outer node
105  for (const auto& innerHit : centerHit->getInnerNodes()) { // loop over all inner nodes attached to center node
106  // cppcheck-suppress useStlAlgorithm
107  m_samples.push_back(makeTrainSample(outerHit->getEntry().m_spacePoint,
108  centerHit->getEntry().m_spacePoint,
109  innerHit->getEntry().m_spacePoint));
110  } // inner node loop
111  } // center node loop
112  } // outer node loop
113 
114  B2DEBUG(10, "collected " << m_samples.size() - samplePriorEvent << " training samples in this event");
115 
116 }
117 
119 {
120  if (m_PARAMstoreSamples) {
121  B2DEBUG(1, "Storing the collected samples to file: " << m_PARAMsamplesFileName);
122  ofstream sampStream(m_PARAMsamplesFileName);
123  sampStream.precision(16); // increase precision for sample writeout
124  writeSamplesToStream(sampStream, m_samples);
125  sampStream.close();
126  }
127  if (m_PARAMdoTrain) {
128  FBDTClassifier<9> classifier{};
129  B2DEBUG(1, "Training a FBDTClassifier with " << m_samples.size() << " input samples. Training Parameters: \n" <<
130  "nTrees: " << m_PARAMnTrees << "\n" <<
131  "treeDetph: " << m_PARAMtreeDepth << "\n" <<
132  "shrinkage: " << m_PARAMshrinkage << "\n" <<
133  "randRatio: " << m_PARAMrandRatio << "\n");
135 
136  ofstream ofs(m_PARAMfbdtOutFileName);
137  classifier.writeToStream(ofs);
138  ofs.close();
139  }
140 }
141 
144  const Belle2::SpacePoint* inner)
145 {
146  vector<MCVXDPurityInfo> purityInfos = createPurityInfosVec({outer, center, inner});
147  auto mcId = purityInfos[0].getPurity(); // there is at least one entry in this vector!
148  bool signal = (mcId.first >= 0 && mcId.second == 1); // only assign true for actual MCParticle and purity 1
149 
150  std::array<double, 9> coords {{
151  inner->X(), inner->Y(), inner->Z(),
152  center->X(), center->Y(), center->Z(),
153  outer->X(), outer->Y(), outer->Z()
154  }};
155 
156  TrainSample sample(coords, signal);
157 
158  if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
159  std::stringstream coordOutput;
160  for (double d : sample.hits) coordOutput << d << " ";
161 
162  B2DEBUG(499, "Created TrainingsSample with coordinates: ( " << coordOutput.str() << " ) " << sample.signal);
163  }
164 
165  return sample;
166 }
167 
168 // void FastBDTClassifierTrainingModule::readSamplesFromStream(std::istream& is)
169 // {
170 // std::string line;
171 // while(!is.eof()) {
172 // getline(is, line);
173 // if(line.empty()) break;
174 // stringstream ss(line);
175 // std::array<double, 9> coords;
176 // for(double& c : coords) ss >> c;
177 // bool sig; ss >> sig;
178 
179 // m_samples.push_back(FBDTTrainSample<9>(coords, sig));
180 // }
181 
182 // B2INFO("Read in " << m_samples.size() << " training samples.");
183 // }
184 
185 // void FastBDTClassifierTrainingModule::writeSamplesToStream(std::ostream& os) const
186 // {
187 // for (const auto& event : m_samples) {
188 // for (const auto& val : event.hits) {
189 // os << val << " ";
190 // }
191 // os << event.signal << std::endl;
192 // }
193 // B2INFO("Wrote out " << m_samples.size() << " training samples.");
194 // }
Belle2::Module::setDescription
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:216
Belle2::FastBDTClassifierTrainingModule::m_PARAMfbdtOutFileName
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
Definition: FastBDTClassifierTrainingModule.h:70
Belle2::FastBDTClassifierTrainingModule::m_PARAMdoTrain
bool m_PARAMdoTrain
actually train a classifier or only do collection
Definition: FastBDTClassifierTrainingModule.h:79
Belle2::FBDTTrainSample
bundle together the classifier input and the target value into one struct for easier passing around.
Definition: FBDTClassifierHelper.h:36
Belle2::SpacePoint::X
double X() const
return the x-value of the global position of the SpacePoint
Definition: SpacePoint.h:133
Belle2::FastBDTClassifierTrainingModule::m_PARAMstoreSamples
bool m_PARAMstoreSamples
store the collected samples into a file
Definition: FastBDTClassifierTrainingModule.h:82
Belle2::readSamplesFromStream
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
Definition: FBDTClassifierHelper.h:52
REG_MODULE
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:652
Belle2::SpacePoint::Z
double Z() const
return the z-value of the global position of the SpacePoint
Definition: SpacePoint.h:139
Belle2::FastBDTClassifierTrainingModule::m_samples
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
Definition: FastBDTClassifierTrainingModule.h:100
Belle2::FastBDTClassifierTrainingModule::m_PARAMshrinkage
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
Definition: FastBDTClassifierTrainingModule.h:94
Belle2::DirectedNodeNetwork
Network of directed nodes of the type EntryType.
Definition: DirectedNodeNetwork.h:38
Belle2::FastBDTClassifierTrainingModule::FastBDTClassifierTrainingModule
FastBDTClassifierTrainingModule()
module constructor.
Definition: FastBDTClassifierTrainingModule.cc:29
Belle2::FastBDTClassifierTrainingModule::m_PARAMsamplesFileName
std::string m_PARAMsamplesFileName
filename to be used to store / read collected samples
Definition: FastBDTClassifierTrainingModule.h:76
Belle2::FastBDTClassifierTrainingModule::m_PARAMtreeDepth
int m_PARAMtreeDepth
tree depth in FastBDT.
Definition: FastBDTClassifierTrainingModule.h:91
Belle2::SpacePoint
SpacePoint typically is built from 1 PXDCluster or 1-2 SVDClusters.
Definition: SpacePoint.h:52
Belle2::FastBDTClassifierTrainingModule::event
void event() override
collect all possible combinations and store them
Definition: FastBDTClassifierTrainingModule.cc:92
Belle2::Module
Base class for Modules.
Definition: Module.h:74
Belle2::SpacePoint::Y
double Y() const
return the y-value of the global position of the SpacePoint
Definition: SpacePoint.h:136
Belle2::FastBDTClassifierTrainingModule::terminate
void terminate() override
take the collected data and train a FBDTClassifier and store it in the given output file
Definition: FastBDTClassifierTrainingModule.cc:118
Belle2::FastBDTClassifierTrainingModule::m_network
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
Definition: FastBDTClassifierTrainingModule.h:103
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::DirectedNodeNetwork::getNodes
std::vector< Node * > & getNodes()
Returns all nodes of the network.
Definition: DirectedNodeNetwork.h:201
Belle2::writeSamplesToStream
static void writeSamplesToStream(std::ostream &os, const std::vector< FBDTTrainSample< Ndims > > &samples)
write all samples to stream
Definition: FBDTClassifierHelper.h:72
Belle2::LogSystem::Instance
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:33
Belle2::createPurityInfosVec
static std::vector< Belle2::MCVXDPurityInfo > createPurityInfosVec(const std::vector< const Belle2::SpacePoint * > &spacePoints)
create a vector of MCVXDPurityInfos objects for a std::vector<Belle2::SpacePoints>.
Definition: PurityCalculatorTools.h:159
Belle2::Module::addParam
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:562
Belle2::FastBDTClassifierTrainingModule::m_PARAMnetworkInputName
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is u...
Definition: FastBDTClassifierTrainingModule.h:73
Belle2::FastBDTClassifierTrainingModule::makeTrainSample
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
Definition: FastBDTClassifierTrainingModule.cc:143
Belle2::LogConfig::c_Debug
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:36
Belle2::FastBDTClassifierTrainingModule::m_PARAMrandRatio
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
Definition: FastBDTClassifierTrainingModule.h:97
Belle2::FBDTClassifier< 9 >
Belle2::FastBDTClassifierTrainingModule::initialize
void initialize() override
initialize the module
Definition: FastBDTClassifierTrainingModule.cc:53
Belle2::FastBDTClassifierTrainingModule::m_PARAMnTrees
int m_PARAMnTrees
number of trees in the FastBDT.
Definition: FastBDTClassifierTrainingModule.h:88
Belle2::FastBDTClassifierTrainingModule::m_PARAMuseSamples
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
Definition: FastBDTClassifierTrainingModule.h:85