Belle II Software development
FBDTClassifier.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
11#include <framework/datastore/RelationsObject.h>
12#include <FastBDT.h>
13
14#include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h>
15#include <tracking/trackFindingVXD/filterTools/DecorrelationMatrix.h>
16#include <array>
17#include <vector>
18#include <iostream>
19
20typedef FastBDT::Forest<unsigned int> FBDTForest;
21
22namespace Belle2 {
36 template<size_t Ndims = 9>
38
39 public:
40
41 FBDTClassifier() { ; }// = default; /**< default constructor */
42
44 FBDTClassifier(const FBDTForest& forest, const std::vector<FastBDT::FeatureBinning<double> >& fB,
45 const Belle2::DecorrelationMatrix<9>& dM) : m_forest(forest), m_featBins(fB), m_decorrMat(dM) { ; }
46
/**
 * Calculate the output of the FastBDT for the given input values: the inputs are decorrelated with the
 * decorrelation matrix, each decorrelated value is mapped to a bin via the feature binning of its dimension,
 * and the bin indices are evaluated by the forest.
 * @param hits the Ndims classifier input values
 * @return the output of the forest for the binned inputs
 */
50 double analyze(const std::array<double, Ndims>& hits) const;
51
/**
 * Train the classifier on the given samples.
 * NOTE(review): the implementation is not visible here. Confirm which parts (forest, feature binnings,
 * decorrelation matrix) are (re)computed.
 * @param samples the training samples (FBDTTrainSample<Ndims>)
 * @param nTree presumably the number of trees in the forest -- TODO confirm
 * @param depth presumably the depth of each tree -- TODO confirm
 * @param shrinkage presumably the FastBDT shrinkage (learning rate); defaults to 0.15 -- TODO confirm
 * @param ratio presumably the FastBDT sampling ratio per tree; defaults to 0.5 -- TODO confirm
 */
56 void train(const std::vector<Belle2::FBDTTrainSample<Ndims> >& samples,
57 int nTree, int depth, double shrinkage = 0.15, double ratio = 0.5);
58
/**
 * Read all the necessary data from a stream and fill the Forest and the FeatureBinnings.
 * NOTE: uses FastBDT's I/O functionality.
 * @param is the input stream to read from
 */
62 void readFromStream(std::istream& is);
63
/**
 * Write out the data from the Forest and the FeatureBinnings to a stream.
 * NOTE: uses FastBDT's I/O functionality.
 * @param os the output stream to write to
 */
67 void writeToStream(std::ostream& os) const;
68
70 FBDTForest getForest() const { return m_forest; }
71
73 std::vector<FastBDT::FeatureBinning<double> > getFeatureBinnings() const { return m_featBins; }
74
77
78 private:
79
/** the forest used for classification */
80 FBDTForest m_forest{};
/** the feature binnings corresponding to the BDT; analyze() expects one binning per input dimension (Ndims) */
82 std::vector<FastBDT::FeatureBinning<double> > m_featBins{};
86 // TODO: make this work with the externals -> tell Thomas Keck what is needed for this stuff to work in the externals
/**
 * Making this class a ROOT class.
 * NOTE(review): the trailing comment on the ClassDef line still says "first version: only Forest and
 * FeatureBinnings present" although the class version is 2. It is left unchanged here because rootcling may
 * use the ClassDef trailing comment as the class title -- confirm and update it deliberately.
 */
88 ClassDef(FBDTClassifier, 2); // first version: only Forest and FeatureBinnings present
89 };
90
91
92
93 // =================================== IMPLEMENTATION ==============================
94
95 template<size_t Ndims>
96 double FBDTClassifier<Ndims>::analyze(const std::array<double, Ndims>& hits) const
97 {
98 std::vector<double> positions = m_decorrMat.decorrelate(hits);
99
100 std::vector<unsigned> bins(Ndims);
101 for (size_t i = 0; i < Ndims; ++i) {
102 bins[i] = m_featBins[i].ValueToBin(positions[i]);
103 }
104
105 return m_forest.Analyse(bins);
106 }
107
109} // end namespace Belle2
Class holding a matrix that can be used to decorrelate input data for machine-learning classifiers.
FastBDT wrapped as a RelationsObject to make it storable in, and accessible via, the DataStore.
FBDTForest m_forest
the forest used for classification
void readFromStream(std::istream &is)
Read all the necessary data from a stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBD...
void writeToStream(std::ostream &os) const
Write out the data from the Forest and the FeatureBinnings to a stream. NOTE: uses FastBDT's I/O functionality.
FBDTForest getForest() const
get the forest
Belle2::DecorrelationMatrix< 9 > getDecorrelationMatrix() const
get the decorrelation matrix
ClassDef(FBDTClassifier, 2)
Makes this class a ROOT class.
FBDTClassifier(const FBDTForest &forest, const std::vector< FastBDT::FeatureBinning< double > > &fB, const Belle2::DecorrelationMatrix< 9 > &dM)
Constructor from the three main parts: forest, feature binnings and decorrelation matrix.
std::vector< FastBDT::FeatureBinning< double > > m_featBins
the feature binnings corresponding to the BDT
std::vector< FastBDT::FeatureBinning< double > > getFeatureBinnings() const
get the feature binnings
~FBDTClassifier()
TODO destructor.
Belle2::DecorrelationMatrix< Ndims > m_decorrMat
the decorrelation matrix used in this classifier
Defines interface for accessing relations of objects in StoreArray.
double analyze(const std::array< double, Ndims > &hits) const
calculate the output of the FastBDT.
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.