11#include <framework/datastore/RelationsObject.h>
14#include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h>
15#include <tracking/trackFindingVXD/filterTools/DecorrelationMatrix.h>
20typedef FastBDT::Forest<unsigned int> FBDTForest;
36 template<
size_t Ndims = 9>
41 FBDTClassifier() { ; }
44 FBDTClassifier(
const FBDTForest& forest,
const std::vector<FastBDT::FeatureBinning<double> >& fB,
50 double analyze(
const std::array<double, Ndims>& hits)
const;
57 int nTree,
int depth,
double shrinkage = 0.15,
double ratio = 0.5);
82 std::vector<FastBDT::FeatureBinning<double> >
m_featBins{};
95 template<
size_t Ndims>
98 std::vector<double> positions =
m_decorrMat.decorrelate(hits);
100 std::vector<unsigned> bins(Ndims);
101 for (
size_t i = 0; i < Ndims; ++i) {
102 bins[i] =
m_featBins[i].ValueToBin(positions[i]);
Class holding a Matrix that can be used to decorrelate input data to Machine Learning classifiers.
FBDTForest m_forest
the forest used for classification
void readFromStream(std::istream &is)
read all the necessary data from the stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBDT's IO stuff.
void writeToStream(std::ostream &os) const
write out the data from the Forest and the FeatureBinnings to a stream. NOTE: uses FastBDT's IO stuff.
FBDTForest getForest() const
get the forest
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
Belle2::DecorrelationMatrix< 9 > getDecorrelationMatrix() const
get the decorrelation matrix
ClassDef(FBDTClassifier, 2)
Making this Class a ROOT class.
FBDTClassifier(const FBDTForest &forest, const std::vector< FastBDT::FeatureBinning< double > > &fB, const Belle2::DecorrelationMatrix< 9 > &dM)
constructor from three main parts.
std::vector< FastBDT::FeatureBinning< double > > m_featBins
the feature binnings corresponding to the BDT
std::vector< FastBDT::FeatureBinning< double > > getFeatureBinnings() const
get the feature binnings
~FBDTClassifier()
TODO destructor.
Belle2::DecorrelationMatrix< Ndims > m_decorrMat
the decorrelation matrix used in this classifier
RelationsInterface< TObject > RelationsObject
Provides interface for getting/adding relations to objects in StoreArrays.
double analyze(const std::array< double, Ndims > &hits) const
calculate the output of the FastBDT.
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.