Belle II Software  release-08-01-10
FBDTClassifier.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 
11 #include <framework/datastore/RelationsObject.h>
12 #include <FastBDT.h>
13 
14 #include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h>
15 #include <tracking/trackFindingVXD/filterTools/DecorrelationMatrix.h>
16 #include <array>
17 #include <vector>
18 #include <iostream>
19 
20 #if FastBDT_VERSION_MAJOR >= 3
21 typedef FastBDT::Forest<unsigned int> FBDTForest;
22 #else
23 typedef FastBDT::Forest FBDTForest;
24 #endif
25 
26 namespace Belle2 {
40  template<size_t Ndims = 9>
42 
43  public:
44 
45  FBDTClassifier() { ; }// = default; /**< default constructor */
46 
48  FBDTClassifier(const FBDTForest& forest, const std::vector<FastBDT::FeatureBinning<double> >& fB,
49  const Belle2::DecorrelationMatrix<9>& dM) : m_forest(forest), m_featBins(fB), m_decorrMat(dM) { ; }
50 
51  ~FBDTClassifier() { ; }
54  double analyze(const std::array<double, Ndims>& hits) const;
55 
60  void train(const std::vector<Belle2::FBDTTrainSample<Ndims> >& samples,
61  int nTree, int depth, double shrinkage = 0.15, double ratio = 0.5);
62 
66  void readFromStream(std::istream& is);
67 
71  void writeToStream(std::ostream& os) const;
72 
74  FBDTForest getForest() const { return m_forest; }
75 
77  std::vector<FastBDT::FeatureBinning<double> > getFeatureBinnings() const { return m_featBins; }
78 
81 
82  private:
83 
84  FBDTForest m_forest{};
86  std::vector<FastBDT::FeatureBinning<double> > m_featBins{};
90  // TODO: make this work with the externals -> tell Thomas Keck what is needed for this stuff to work in the externals
92  ClassDef(FBDTClassifier, 2); // first version: only Forest and FeatureBinnings present
93  };
94 
95 
96 
97  // =================================== IMPLEMENTATION ==============================
98 
99  template<size_t Ndims>
100  double FBDTClassifier<Ndims>::analyze(const std::array<double, Ndims>& hits) const
101  {
102  std::vector<double> positions = m_decorrMat.decorrelate(hits);
103 
104  std::vector<unsigned> bins(Ndims);
105  for (size_t i = 0; i < Ndims; ++i) {
106  bins[i] = m_featBins[i].ValueToBin(positions[i]);
107  }
108 
109  return m_forest.Analyse(bins);
110  }
111 
113 } // end namespace Belle2
FastBDT wrapped as a RelationsObject so that it can be stored in and accessed via the DataStore.
FBDTForest m_forest
the forest used for classification
void readFromStream(std::istream &is)
Read all the necessary data from a stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBDT's I/O functionality.
Belle2::DecorrelationMatrix< 9 > getDecorrelationMatrix() const
get the decorrelation matrix
void writeToStream(std::ostream &os) const
Write out the data from the Forest and the FeatureBinnings to a stream. NOTE: uses FastBDT's I/O functionality.
FBDTForest getForest() const
get the forest
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
Train the BDT. NOTE: overwrites a currently existing classifier internally. TODO: does not work at the moment.
std::vector< FastBDT::FeatureBinning< double > > getFeatureBinnings() const
get the feature binnings
ClassDef(FBDTClassifier, 2)
Making this Class a ROOT class.
FBDTClassifier(const FBDTForest &forest, const std::vector< FastBDT::FeatureBinning< double > > &fB, const Belle2::DecorrelationMatrix< 9 > &dM)
constructor from three main parts.
std::vector< FastBDT::FeatureBinning< double > > m_featBins
the feature binnings corresponding to the BDT
~FBDTClassifier()
TODO destructor.
Belle2::DecorrelationMatrix< Ndims > m_decorrMat
the decorrelation matrix used in this classifier
Defines interface for accessing relations of objects in StoreArray.
double analyze(const std::array< double, Ndims > &hits) const
calculate the output of the FastBDT.
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.