Belle II Software development
FastBDTClassifierTrainingModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/vxdtfRedesign/FastBDTClassifierTrainingModule.h>
10#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
11#include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetwork.h>
12#include <tracking/trackFindingVXD/segmentNetwork/TrackNode.h>
13
14#include <tracking/spacePointCreation/PurityCalculatorTools.h>
15#include <tracking/spacePointCreation/MCVXDPurityInfo.h>
16
17#include <framework/logging/Logger.h>
18
19#include <fstream>
20#include <sstream>
21
22using namespace Belle2;
23
24REG_MODULE(FastBDTClassifierTraining);
25
27{
28 setDescription("TODO");
29
30 addParam("outputFileName",
32 "Output file name to which the trained FBDTClassifier will be stored",
33 std::string("FBDTClassifier.dat"));
34
35 addParam("networkInputName", m_PARAMnetworkInputName,
36 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
37
38 addParam("train", m_PARAMdoTrain, "Set if the module should train a classifier after collecting or not", true);
39 addParam("nTrees", m_PARAMnTrees, "Number of Trees used in the FastBDT", 100);
40 addParam("treeDepth", m_PARAMtreeDepth, "Tree depth of the trees used in the FastBDT", 3);
41 addParam("shrinkage", m_PARAMshrinkage, "Shrinkage parameter used in the FastBDT", 0.15);
42 addParam("randRatio", m_PARAMrandRatio, "ratio of randomly chosen samples for training of one tree", 0.5);
43 addParam("storeSamples", m_PARAMstoreSamples, "store the collected samples into a file", false);
44 addParam("samplesFileName", m_PARAMsamplesFileName, "the file name into which (from which) the collected samples are stored (read)",
45 std::string("FBDTClassifier_samples.dat"));
46 addParam("useSamples", m_PARAMuseSamples,
47 "use samples for training that have been collected previously and bypass the collection of samples", false);
48}
49
51{
52
54 B2ERROR("storeSamples and useSamples are both set to true. However, only one option can be set at a time");
55 }
56
57 if (m_PARAMnTrees < 1) {
58 B2WARNING("nTrees was set to " << m_PARAMnTrees << ". Has to be at least 1. Setting to 1.");
59 m_PARAMnTrees = 1;
60 }
61
62 if (m_PARAMtreeDepth < 0) {
63 B2WARNING("Trees have to be at least a stump, but treeDepth was set to " << m_PARAMtreeDepth << ". Setting to 3 (default).");
65 }
66
67 if (m_PARAMshrinkage < 0 || m_PARAMshrinkage > 1) { // TODO: check this
68 B2WARNING("shrinkage has to be in [0,1] but was " << m_PARAMrandRatio << ". Setting to 0.15 (default).");
69 m_PARAMshrinkage = .15;
70 }
71
73 B2WARNING("randRatio has to be in [0,1] but was " << m_PARAMrandRatio << ". Setting to 0.5 (default).");
74 m_PARAMrandRatio = 0.5;
75 }
76
78 std::ifstream sampFile(m_PARAMsamplesFileName);
79 if (!sampFile.is_open() || !sampFile.good()) {
80 B2ERROR("Was not able to open the samples file: " << m_PARAMsamplesFileName);
81 }
82
84 } else { // only if no samples are provided the collection from the DNN is necessary
86 }
87}
88
90{
91 if (m_PARAMuseSamples) return; // don't collect anything during event if samples are provided
92
93 DirectedNodeNetwork<TrackNode, VoidMetaInfo>& hitNetwork = m_network->accessHitNetwork();
94
95 // B2DEBUG(20, "size of hitNetwork " << hitNetwork.getNodes().size());
96
97 size_t samplePriorEvent = m_samples.size();
98
99 // XXXHit is of type DirectedNode<TrackNode, VoidMetaInfo>
100 for (const auto& outerHit : hitNetwork.getNodes()) { // loop over all outer nodes
101 for (const auto& centerHit : outerHit->getInnerNodes()) { // loop over all center nodes attached to outer node
102 for (const auto& innerHit : centerHit->getInnerNodes()) { // loop over all inner nodes attached to center node
103 m_samples.push_back(makeTrainSample(outerHit->getEntry().m_spacePoint,
104 centerHit->getEntry().m_spacePoint,
105 innerHit->getEntry().m_spacePoint));
106 } // inner node loop
107 } // center node loop
108 } // outer node loop
109
110 B2DEBUG(21, "collected " << m_samples.size() - samplePriorEvent << " training samples in this event");
111
112}
113
115{
117 B2DEBUG(20, "Storing the collected samples to file: " << m_PARAMsamplesFileName);
118 std::ofstream sampStream(m_PARAMsamplesFileName);
119 sampStream.precision(16); // increase precision for sample writeout
120 writeSamplesToStream(sampStream, m_samples);
121 sampStream.close();
122 }
123 if (m_PARAMdoTrain) {
124 FBDTClassifier<9> classifier{};
125 B2DEBUG(20, "Training a FBDTClassifier with " << m_samples.size() << " input samples. Training Parameters: \n" <<
126 "nTrees: " << m_PARAMnTrees << "\n" <<
127 "treeDetph: " << m_PARAMtreeDepth << "\n" <<
128 "shrinkage: " << m_PARAMshrinkage << "\n" <<
129 "randRatio: " << m_PARAMrandRatio << "\n");
131
132 std::ofstream ofs(m_PARAMfbdtOutFileName);
133 classifier.writeToStream(ofs);
134 ofs.close();
135 }
136}
137
140 const Belle2::SpacePoint* inner)
141{
142 std::vector<MCVXDPurityInfo> purityInfos = createPurityInfosVec({outer, center, inner});
143 auto mcId = purityInfos[0].getPurity(); // there is at least one entry in this vector!
144 bool signal = (mcId.first >= 0 && mcId.second == 1); // only assign true for actual MCParticle and purity 1
145
146 std::array<double, 9> coords {{
147 inner->X(), inner->Y(), inner->Z(),
148 center->X(), center->Y(), center->Z(),
149 outer->X(), outer->Y(), outer->Z()
150 }};
151
152 TrainSample sample(coords, signal);
153
154 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
155 std::stringstream coordOutput;
156 for (double d : sample.hits) coordOutput << d << " ";
157
158 B2DEBUG(29, "Created TrainingsSample with coordinates: ( " << coordOutput.str() << " ) " << sample.signal);
159 }
160
161 return sample;
162}
163
164// void FastBDTClassifierTrainingModule::readSamplesFromStream(std::istream& is)
165// {
166// std::string line;
167// while(!is.eof()) {
168// getline(is, line);
169// if(line.empty()) break;
170// stringstream ss(line);
171// std::array<double, 9> coords;
172// for(double& c : coords) ss >> c;
173// bool sig; ss >> sig;
174
175// m_samples.push_back(FBDTTrainSample<9>(coords, sig));
176// }
177
178// B2INFO("Read in " << m_samples.size() << " training samples.");
179// }
180
181// void FastBDTClassifierTrainingModule::writeSamplesToStream(std::ostream& os) const
182// {
183// for (const auto& event : m_samples) {
184// for (const auto& val : event.hits) {
185// os << val << " ";
186// }
187// os << event.signal << std::endl;
188// }
189// B2INFO("Wrote out " << m_samples.size() << " training samples.");
190// }
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
FastBDT as RelationsObject to make it storable and accessible on/via the DataStore.
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
std::string m_PARAMsamplesFileName
filename to be used to store / read the collected samples
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
collect all possible combinations and store them
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
void terminate() override
take the collected data and train a FBDTClassifier and store it in the given output file
Belle2::FBDTTrainSample< 9 > TrainSample
private typedef for shorter notation
bool m_PARAMdoTrain
actually train a classifier or only do collection
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is u...
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
@ c_Debug
Debug: for code development.
Definition LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition LogSystem.cc:28
void setDescription(const std::string &description)
Sets the description of the module.
Definition Module.cc:214
Module()
Constructor.
Definition Module.cc:30
SpacePoint typically is built from 1 PXDCluster or 1-2 SVDClusters.
Definition SpacePoint.h:42
double Z() const
return the z-value of the global position of the SpacePoint
Definition SpacePoint.h:129
double X() const
return the x-value of the global position of the SpacePoint
Definition SpacePoint.h:123
double Y() const
return the y-value of the global position of the SpacePoint
Definition SpacePoint.h:126
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition Module.h:559
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition Module.h:649
static void writeSamplesToStream(std::ostream &os, const std::vector< FBDTTrainSample< Ndims > > &samples)
write all samples to stream
static std::vector< Belle2::MCVXDPurityInfo > createPurityInfosVec(const std::vector< const Belle2::SpacePoint * > &spacePoints)
create a vector of MCVXDPurityInfos objects for a std::vector<Belle2::SpacePoints>.
B2Vector3D outerHit(0, 0, 0)
testing out of range behavior
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
Abstract base class for different kinds of events.