9 #include <tracking/modules/vxdtfRedesign/FastBDTClassifierTrainingModule.h>
10 #include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
11 #include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetwork.h>
12 #include <tracking/trackFindingVXD/segmentNetwork/TrackNode.h>
14 #include <tracking/spacePointCreation/PurityCalculatorTools.h>
15 #include <tracking/spacePointCreation/MCVXDPurityInfo.h>
17 #include <framework/logging/Logger.h>
33 "Output file name to which the trained FBDTClassifier will be stored",
34 std::string(
"FBDTClassifier.dat"));
37 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(
""));
39 addParam(
"train",
m_PARAMdoTrain,
"Set if the module should train a classifier after collecting or not",
true);
46 std::string(
"FBDTClassifier_samples.dat"));
48 "use samples for training that have been collected previously and bypass the collection of samples",
false);
55 B2ERROR(
"storeSamples and useSamples are both set to true. However, only one option can be set at a time");
59 B2WARNING(
"nTrees was set to " <<
m_PARAMnTrees <<
". Has to be at least 1. Setting to 1.");
64 B2WARNING(
"Trees have to be at least a stump, but treeDepth was set to " <<
m_PARAMtreeDepth <<
". Setting to 3 (default).");
68 if (m_PARAMshrinkage < 0 || m_PARAMshrinkage > 1) {
69 B2WARNING(
"shrinkage has to be in [0,1] but was " <<
m_PARAMrandRatio <<
". Setting to 0.15 (default).");
73 if (m_PARAMrandRatio < 0 || m_PARAMrandRatio > 1) {
74 B2WARNING(
"randRatio has to be in [0,1] but was " <<
m_PARAMrandRatio <<
". Setting to 0.5 (default).");
80 if (!sampFile.is_open() || !sampFile.good()) {
98 size_t samplePriorEvent =
m_samples.size();
101 for (
const auto& outerHit : hitNetwork.
getNodes()) {
102 for (
const auto& centerHit : outerHit->getInnerNodes()) {
103 for (
const auto& innerHit : centerHit->getInnerNodes()) {
105 centerHit->getEntry().m_spacePoint,
106 innerHit->getEntry().m_spacePoint));
111 B2DEBUG(10,
"collected " <<
m_samples.size() - samplePriorEvent <<
" training samples in this event");
120 sampStream.precision(16);
126 B2DEBUG(1,
"Training a FBDTClassifier with " <<
m_samples.size() <<
" input samples. Training Parameters: \n" <<
134 classifier.writeToStream(ofs);
144 auto mcId = purityInfos[0].getPurity();
145 bool signal = (mcId.first >= 0 && mcId.second == 1);
147 std::array<double, 9> coords {{
148 inner->
X(), inner->
Y(), inner->
Z(),
149 center->
X(), center->
Y(), center->
Z(),
150 outer->
X(), outer->
Y(), outer->
Z()
156 std::stringstream coordOutput;
157 for (
double d : sample.hits) coordOutput << d <<
" ";
159 B2DEBUG(499,
"Created TrainingsSample with coordinates: ( " << coordOutput.str() <<
" ) " << sample.signal);
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
FastBDTClassifierTrainingModule()
module constructor.
std::string m_PARAMsamplesFileName
filename to be used to store / read collected samples
void initialize() override
initialize the module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
collect all possible combinations and store them
int m_PARAMtreeDepth
tree depth in FastBDT.
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three-hit combination
void terminate() override
take the collected data and train a FBDTClassifier and store it in the given output file
bool m_PARAMdoTrain
actually train a classifier or only do collection
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is u...
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
int m_PARAMnTrees
number of trees in the FastBDT.
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
void setDescription(const std::string &description)
Sets the description of the module.
A SpacePoint is typically built from 1 PXDCluster or 1-2 SVDClusters.
double Z() const
return the z-value of the global position of the SpacePoint
double X() const
return the x-value of the global position of the SpacePoint
double Y() const
return the y-value of the global position of the SpacePoint
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
static std::vector< Belle2::MCVXDPurityInfo > createPurityInfosVec(const std::vector< const Belle2::SpacePoint * > &spacePoints)
create a vector of MCVXDPurityInfos objects for a std::vector<Belle2::SpacePoints>.
static void writeSamplesToStream(std::ostream &os, const std::vector< FBDTTrainSample< Ndims > > &samples)
write all samples to stream
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.