9#include <tracking/modules/vxdtfRedesign/FastBDTClassifierTrainingModule.h>
10#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
11#include <tracking/trackFindingVXD/segmentNetwork/DirectedNodeNetwork.h>
12#include <tracking/trackFindingVXD/segmentNetwork/TrackNode.h>
14#include <tracking/spacePointCreation/PurityCalculatorTools.h>
15#include <tracking/spacePointCreation/MCVXDPurityInfo.h>
17#include <framework/logging/Logger.h>
32 "Output file name to which the trained FBDTClassifier will be stored",
33 std::string(
"FBDTClassifier.dat"));
36 "Name of the StoreObjPtr where the network container used in this module is stored", std::string(
""));
38 addParam(
"train",
m_PARAMdoTrain,
"Set if the module should train a classifier after collecting or not",
true);
45 std::string(
"FBDTClassifier_samples.dat"));
47 "use samples for training that have been collected previously and bypass the collection of samples",
false);
54 B2ERROR(
"storeSamples and useSamples are both set to true. However, only one option can be set at a time");
58 B2WARNING(
"nTrees was set to " <<
m_PARAMnTrees <<
". Has to be at least 1. Setting to 1.");
63 B2WARNING(
"Trees have to be at least a stump, but treeDepth was set to " <<
m_PARAMtreeDepth <<
". Setting to 3 (default).");
67 if (m_PARAMshrinkage < 0 || m_PARAMshrinkage > 1) {
68 B2WARNING(
"shrinkage has to be in [0,1] but was " <<
m_PARAMrandRatio <<
". Setting to 0.15 (default).");
72 if (m_PARAMrandRatio < 0 || m_PARAMrandRatio > 1) {
73 B2WARNING(
"randRatio has to be in [0,1] but was " <<
m_PARAMrandRatio <<
". Setting to 0.5 (default).");
79 if (!sampFile.is_open() || !sampFile.good()) {
97 size_t samplePriorEvent =
m_samples.size();
101 for (
const auto& centerHit :
outerHit->getInnerNodes()) {
102 for (
const auto& innerHit : centerHit->getInnerNodes()) {
104 centerHit->getEntry().m_spacePoint,
105 innerHit->getEntry().m_spacePoint));
110 B2DEBUG(21,
"collected " <<
m_samples.size() - samplePriorEvent <<
" training samples in this event");
119 sampStream.precision(16);
125 B2DEBUG(20,
"Training a FBDTClassifier with " <<
m_samples.size() <<
" input samples. Training Parameters: \n" <<
133 classifier.writeToStream(ofs);
143 auto mcId = purityInfos[0].getPurity();
144 bool signal = (mcId.first >= 0 && mcId.second == 1);
146 std::array<double, 9> coords {{
147 inner->
X(), inner->
Y(), inner->
Z(),
148 center->
X(), center->
Y(), center->
Z(),
149 outer->
X(), outer->
Y(), outer->
Z()
155 std::stringstream coordOutput;
156 for (
double d : sample.hits) coordOutput << d <<
" ";
158 B2DEBUG(29,
"Created TrainingsSample with coordinates: ( " << coordOutput.str() <<
" ) " << sample.signal);
DirectedNodeNetwork< Belle2::TrackNode, Belle2::VoidMetaInfo > & accessHitNetwork()
Returns reference to the HitNetwork stored in this container, intended for read and write access.
Network of directed nodes of the type EntryType.
std::vector< Node * > & getNodes()
Returns all nodes of the network.
FastBDT as RelationsObject to make it storable and accessible on/via the DataStore.
double m_PARAMrandRatio
ratio of samples to be used for training one tree in the FastBDT.
FastBDTClassifierTrainingModule()
module constructor.
std::string m_PARAMsamplesFileName
filename to be used to store / read collected samples
void initialize() override
initialize the module
Belle2::StoreObjPtr< Belle2::DirectedNodeNetworkContainer > m_network
StoreObjPtr to access the DNNs that are used in this module.
void event() override
collect all possible combinations and store them
int m_PARAMtreeDepth
tree depth in FastBDT.
std::vector< TrainSample > m_samples
vector in which all samples are collected on the fly in event.
const TrainSample makeTrainSample(const Belle2::SpacePoint *outerHit, const Belle2::SpacePoint *centerHit, const Belle2::SpacePoint *innerHit)
create a training sample from the three hit combination
void terminate() override
take the collected data and train a FBDTClassifier and store it in the given output file
bool m_PARAMdoTrain
actually train a classifier or only do collection
std::string m_PARAMfbdtOutFileName
output file name into which the FBDTClassifier is stored.
double m_PARAMshrinkage
shrinkage parameter of FastBDT.
bool m_PARAMstoreSamples
store the collected samples into a file
std::string m_PARAMnetworkInputName
name of the StoreObjPtr in which the network container is stored which contains the network that is u...
bool m_PARAMuseSamples
use pre-collected samples for training and bypass the collection step
int m_PARAMnTrees
number of trees in the FastBDT.
@ c_Debug
Debug: for code development.
static LogSystem & Instance()
Static method to get a reference to the LogSystem instance.
void setDescription(const std::string &description)
Sets the description of the module.
A SpacePoint is typically built from 1 PXDCluster or 1-2 SVDClusters.
double Z() const
return the z-value of the global position of the SpacePoint
double X() const
return the x-value of the global position of the SpacePoint
double Y() const
return the y-value of the global position of the SpacePoint
bool isRequired(const std::string &name="")
Ensure this array/object has been registered previously.
void addParam(const std::string &name, T &paramVariable, const std::string &description, const T &defaultValue)
Adds a new parameter to the module.
#define REG_MODULE(moduleName)
Register the given module (without 'Module' suffix) with the framework.
static void writeSamplesToStream(std::ostream &os, const std::vector< FBDTTrainSample< Ndims > > &samples)
write all samples to stream
static std::vector< Belle2::MCVXDPurityInfo > createPurityInfosVec(const std::vector< const Belle2::SpacePoint * > &spacePoints)
create a vector of MCVXDPurityInfos objects for a std::vector<Belle2::SpacePoints>.
B2Vector3D outerHit(0, 0, 0)
testing out of range behavior
static void readSamplesFromStream(std::istream &is, std::vector< FBDTTrainSample< Ndims > > &samples)
read samples from stream and append them to samples
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.