Belle II Software development
FastBDTClassifierTrainingModule.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/modules/vxdtfRedesign/FastBDTClassifierTrainingModule.h>
10#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
11#include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetwork.h>
12#include <tracking/trackFindingVXD/segmentNetwork/TrackNode.h>
13
14#include <tracking/spacePointCreation/PurityCalculatorTools.h>
15#include <tracking/spacePointCreation/MCVXDPurityInfo.h>
16
17#include <framework/logging/Logger.h>
18
19#include <fstream>
20#include <sstream>
21
22using namespace Belle2;
23
24REG_MODULE(FastBDTClassifierTraining);
25
27{
28 setDescription("TODO");
29
30 addParam("outputFileName",
32 "Output file name to which the trained FBDTClassifier will be stored",
33 std::string("FBDTClassifier.dat"));
34
35 addParam("networkInputName", m_PARAMnetworkInputName,
36 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(""));
37
38 addParam("train", m_PARAMdoTrain, "Set if the module should train a classifier after collecting or not", true);
39 addParam("nTrees", m_PARAMnTrees, "Number of Trees used in the FastBDT", 100);
40 addParam("treeDepth", m_PARAMtreeDepth, "Tree depth of the trees used in the FastBDT", 3);
41 addParam("shrinkage", m_PARAMshrinkage, "Shrinkage parameter used in the FastBDT", 0.15);
42 addParam("randRatio", m_PARAMrandRatio, "ratio of randomly chosen samples for training of one tree", 0.5);
43 addParam("storeSamples", m_PARAMstoreSamples, "store the collected samples into a file", false);
44 addParam("samplesFileName", m_PARAMsamplesFileName, "the file name into which/from whicht the collected samples are stored/read",
45 std::string("FBDTClassifier_samples.dat"));
46 addParam("useSamples", m_PARAMuseSamples,
47 "use samples for training that have been collected previously and bypass the collection of samples", false);
48}
49
51{
52
54 B2ERROR("storeSamples and useSamples are both set to true. However, only one option can be set at a time");
55 }
56
57 if (m_PARAMnTrees < 1) {
58 B2WARNING("nTrees was set to " << m_PARAMnTrees << ". Has to be at least 1. Setting to 1.");
59 m_PARAMnTrees = 1;
60 }
61
62 if (m_PARAMtreeDepth < 0) {
63 B2WARNING("Trees have to be at least a stump, but treeDepth was set to " << m_PARAMtreeDepth << ". Setting to 3 (default).");
65 }
66
67 if (m_PARAMshrinkage < 0 || m_PARAMshrinkage > 1) { // TODO: check this
68 B2WARNING("shrinkage has to be in [0,1] but was " << m_PARAMrandRatio << ". Setting to 0.15 (default).");
69 m_PARAMshrinkage = .15;
70 }
71
72 if (m_PARAMrandRatio < 0 || m_PARAMrandRatio > 1) {
73 B2WARNING("randRatio has to be in [0,1] but was " << m_PARAMrandRatio << ". Setting to 0.5 (default).");
74 m_PARAMrandRatio = 0.5;
75 }
76
78 std::ifstream sampFile(m_PARAMsamplesFileName);
79 if (!sampFile.is_open() || !sampFile.good()) {
80 B2ERROR("Was not able to open the samples file: " << m_PARAMsamplesFileName);
81 }
82
84 } else { // only if no samples are provided the collection from the DNN is necessary
86 }
87}
88
90{
91 if (m_PARAMuseSamples) return; // don't collect anything during event if samples are provided
92
94
95 // B2DEBUG(20, "size of hitNetwork " << hitNetwork.getNodes().size());
96
97 size_t samplePriorEvent = m_samples.size();
98
99 // XXXHit is of type DirectedNode<TrackNode, VoidMetaInfo>
100 for (const auto& outerHit : hitNetwork.getNodes()) { // loop over all outer nodes
101 for (const auto& centerHit : outerHit->getInnerNodes()) { // loop over all center nodes attached to outer node
102 for (const auto& innerHit : centerHit->getInnerNodes()) { // loop over all inner nodes attached to center node
103 m_samples.push_back(makeTrainSample(outerHit->getEntry().m_spacePoint,
104 centerHit->getEntry().m_spacePoint,
105 innerHit->getEntry().m_spacePoint));
106 } // inner node loop
107 } // center node loop
108 } // outer node loop
109
110 B2DEBUG(21, "collected " << m_samples.size() - samplePriorEvent << " training samples in this event");
111
112}
113
115{
117 B2DEBUG(20, "Storing the collected samples to file: " << m_PARAMsamplesFileName);
118 std::ofstream sampStream(m_PARAMsamplesFileName);
119 sampStream.precision(16); // increase precision for sample writeout
120 writeSamplesToStream(sampStream, m_samples);
121 sampStream.close();
122 }
123 if (m_PARAMdoTrain) {
124 FBDTClassifier<9> classifier{};
125 B2DEBUG(20, "Training a FBDTClassifier with " << m_samples.size() << " input samples. Training Parameters: \n" <<
126 "nTrees: " << m_PARAMnTrees << "\n" <<
127 "treeDetph: " << m_PARAMtreeDepth << "\n" <<
128 "shrinkage: " << m_PARAMshrinkage << "\n" <<
129 "randRatio: " << m_PARAMrandRatio << "\n");
131
132 std::ofstream ofs(m_PARAMfbdtOutFileName);
133 classifier.writeToStream(ofs);
134 ofs.close();
135 }
136}
137
140 const Belle2::SpacePoint* inner)
141{
142 std::vector<MCVXDPurityInfo> purityInfos = createPurityInfosVec({outer, center, inner});
143 auto mcId = purityInfos[0].getPurity(); // there is at least one entry in this vector!
144 bool signal = (mcId.first >= 0 && mcId.second == 1); // only assign true for actual MCParticle and purity 1
145
146 std::array<double, 9> coords {{
147 inner->X(), inner->Y(), inner->Z(),
148 center->X(), center->Y(), center->Z(),
149 outer->X(), outer->Y(), outer->Z()
150 }};
151
152 TrainSample sample(coords, signal);
153
154 if (LogSystem::Instance().isLevelEnabled(LogConfig::c_Debug, 499, PACKAGENAME())) {
155 std::stringstream coordOutput;
156 for (double d : sample.hits) coordOutput << d << " ";
157
158 B2DEBUG(29, "Created TrainingsSample with coordinates: ( " << coordOutput.str() << " ) " << sample.signal);
159 }
160
161 return sample;
162}
163
164// void FastBDTClassifierTrainingModule::readSamplesFromStream(std::istream& is)
165// {
166// std::string line;
167// while(!is.eof()) {
168// getline(is, line);
169// if(line.empty()) break;
170// stringstream ss(line);
171// std::array<double, 9> coords;
172// for(double& c : coords) ss >> c;
173// bool sig; ss >> sig;
174
175// m_samples.push_back(FBDTTrainSample<9>(coords, sig));
176// }
177
178// B2INFO("Read in " << m_samples.size() << " training samples.");
179// }
180
181// void FastBDTClassifierTrainingModule::writeSamplesToStream(std::ostream& os) const
182// {
183// for (const auto& event : m_samples) {
184// for (const auto& val : event.hits) {
185// os << val << " ";
186// }
187// os << event.signal << std::endl;
188// }
189// B2INFO("Wrote out " << m_samples.size() << " training samples.");
190// }
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
FastBDT as a RelationsObject to make it storable and accessible on/via the DataStore.
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
std::string m_PARAMsamplesFileName
file name used to store / read the collected samples
void initialize() override
initialize the module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
collect all possible combinations and store them
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
void terminate() override
take the collected data and train a FBDTClassifier and store it in the given output file
bool m_PARAMdoTrain
actually train a classifier or only do collection
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is used in this module
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
int m_PARAMnTrees
number of trees in the FastBDT.
@ c_Debug
Debug: for code development.
Definition: LogConfig.h:26
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
Definition: LogSystem.cc:31
Base class for Modules.
Definition: Module.h:72
void setDescription(const std::string &description)
Sets the description of the module.
Definition: Module.cc:214
A SpacePoint is typically built from 1 PXDCluster or 1-2 SVDClusters.
Definition: SpacePoint.h:42
double Z() const
return the z-value of the global position of the SpacePoint
Definition: SpacePoint.h:129
double X() const
return the x-value of the global position of the SpacePoint
Definition: SpacePoint.h:123
double Y() const
return the y-value of the global position of the SpacePoint
Definition: SpacePoint.h:126
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
Definition: Module.h:560
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
Definition: Module.h:650
static void writeSamplesToStream(std::ostream &os, const std::vector< FBDTTrainSample< Ndims > > &samples)
write all samples to stream
static std::vector< Belle2::MCVXDPurityInfo > createPurityInfosVec(const std::vector< const Belle2::SpacePoint * > &spacePoints)
create a vector of MCVXDPurityInfos objects for a std::vector<Belle2::SpacePoints>.
B2Vector3D outerHit(0, 0, 0)
testing out of range behavior
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.