Belle II Software  release-05-02-19
FBDTClassifier.h
1 /*************************************************************************
2 * BASF2 (Belle Analysis Framework 2) *
3 * Copyright(C) 2015 - Belle II Collaboration *
4 * *
5 * Author: The Belle II Collaboration *
6 * Contributors: Thomas Madlener *
7 * *
8 * This software is provided "as is" without any warranty. *
9 **************************************************************************/
10 
11 #pragma once
12 
13 #include <framework/datastore/RelationsObject.h>
14 #include <FastBDT.h>
15 
16 #include <tracking/trackFindingVXD/filterTools/FBDTClassifierHelper.h>
17 #include <tracking/trackFindingVXD/filterTools/DecorrelationMatrix.h>
18 #include <array>
19 #include <vector>
20 #include <iostream>
21 
22 #if FastBDT_VERSION_MAJOR >= 3
23 typedef FastBDT::Forest<unsigned int> FBDTForest;
24 #else
25 typedef FastBDT::Forest FBDTForest;
26 #endif
27 
28 namespace Belle2 {
42  template<size_t Ndims = 9>
44 
45  public:
46 
47  FBDTClassifier() { ; }// = default; /**< default constructor */
48 
50  FBDTClassifier(const FBDTForest& forest, const std::vector<FastBDT::FeatureBinning<double> >& fB,
51  const Belle2::DecorrelationMatrix<9>& dM) : m_forest(forest), m_featBins(fB), m_decorrMat(dM) { ; }
52 
53  ~FBDTClassifier() { ; }
56  double analyze(const std::array<double, Ndims>& hits) const;
57 
62  void train(const std::vector<Belle2::FBDTTrainSample<Ndims> >& samples,
63  int nTree, int depth, double shrinkage = 0.15, double ratio = 0.5);
64 
68  void readFromStream(std::istream& is);
69 
73  void writeToStream(std::ostream& os) const;
74 
76  FBDTForest getForest() const { return m_forest; }
77 
79  std::vector<FastBDT::FeatureBinning<double> > getFeatureBinnings() const { return m_featBins; }
80 
83 
84  private:
85 
86  FBDTForest m_forest{}; /**< the forest used for classification */
88  std::vector<FastBDT::FeatureBinning<double> > m_featBins{}; /**< the feature binnings corresponding to the BDT (empty until constructed from parts, read from a stream or trained) */
92  // TODO: make this work with the externals -> tell Thomas Keck what is needed for this stuff to work in the externals
94  ClassDef(FBDTClassifier, 2); // Makes this class a ROOT class. Class version 2; the first version had only the Forest and FeatureBinnings (version 2 presumably added the decorrelation matrix -- TODO confirm)
95  };
96 
97 
98 
99  // =================================== IMPLEMENTATION ==============================
100 
101  template<size_t Ndims>
102  double FBDTClassifier<Ndims>::analyze(const std::array<double, Ndims>& hits) const
103  {
104  std::vector<double> positions = m_decorrMat.decorrelate(hits);
105 
106  std::vector<unsigned> bins(Ndims);
107  for (size_t i = 0; i < Ndims; ++i) {
108  bins[i] = m_featBins[i].ValueToBin(positions[i]);
109  }
110 
111  return m_forest.Analyse(bins);
112  }
113 
115 } // end namespace Belle2
Belle2::DecorrelationMatrix< 9 >
Belle2::FBDTClassifier::ClassDef
ClassDef(FBDTClassifier, 2)
Makes this class a ROOT class.
Belle2::FBDTClassifier::m_forest
FBDTForest m_forest
the forest used for classification
Definition: FBDTClassifier.h:86
Belle2::FBDTClassifier::analyze
double analyze(const std::array< double, Ndims > &hits) const
calculate the output of the FastBDT.
Definition: FBDTClassifier.h:102
Belle2::FBDTClassifier::m_decorrMat
Belle2::DecorrelationMatrix< Ndims > m_decorrMat
the decorrelation matrix used in this classifier
Definition: FBDTClassifier.h:90
Belle2::FBDTTrainSample
bundle together the classifier input and the target value into one struct for easier passing around.
Definition: FBDTClassifierHelper.h:36
Belle2::FBDTClassifier::getFeatureBinnings
std::vector< FastBDT::FeatureBinning< double > > getFeatureBinnings() const
get the feature binnings
Definition: FBDTClassifier.h:79
Belle2::FBDTClassifier::FBDTClassifier
FBDTClassifier(const FBDTForest &forest, const std::vector< FastBDT::FeatureBinning< double > > &fB, const Belle2::DecorrelationMatrix< 9 > &dM)
Constructor from the three main parts.
Definition: FBDTClassifier.h:50
Belle2::FBDTClassifier::readFromStream
void readFromStream(std::istream &is)
Read all the necessary data from a stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBDT's I/O functionality.
Definition: FBDTClassifier.cc:24
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::FBDTClassifier::getForest
FBDTForest getForest() const
get the forest
Definition: FBDTClassifier.h:76
Belle2::FBDTClassifier::train
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
Train the BDT. NOTE: overwrites a currently existing classifier internally. TODO: does not work at the mo...
Definition: FBDTClassifier.cc:54
Belle2::FBDTClassifier::getDecorrelationMatrix
Belle2::DecorrelationMatrix< 9 > getDecorrelationMatrix() const
get the decorrelation matrix
Definition: FBDTClassifier.h:82
Belle2::FBDTClassifier::writeToStream
void writeToStream(std::ostream &os) const
Write out the data from the Forest and the FeatureBinnings to a stream. NOTE: uses FastBDT's I/O functionality.
Definition: FBDTClassifier.cc:43
Belle2::FBDTClassifier
FastBDT wrapped as a RelationsObject to make it storable and accessible on/via the DataStore.
Definition: FBDTClassifier.h:43
Belle2::FBDTClassifier::m_featBins
std::vector< FastBDT::FeatureBinning< double > > m_featBins
the feature binnings corresponding to the BDT
Definition: FBDTClassifier.h:88
Belle2::RelationsInterface
Defines interface for accessing relations of objects in StoreArray.
Definition: RelationsObject.h:102