Belle II Software  release-05-02-19
FBDTClassifier.cc
1 /*************************************************************************
2 * BASF2 (Belle Analysis Framework 2) *
3 * Copyright(C) 2015 - Belle II Collaboration *
4 * *
5 * Author: The Belle II Collaboration *
6 * Contributors: Thomas Madlener *
7 * *
8 * This software is provided "as is" without any warranty. *
9 **************************************************************************/
10 
#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
#include <framework/logging/Logger.h>

#if FastBDT_VERSION_MAJOR >= 3
#include <FastBDT_IO.h>
#else
#include <IO.h>
#endif

#include <utility>
19 
20 using namespace Belle2;
21 using namespace std;
22 
23 template<size_t Ndims>
25 {
26  m_featBins.clear(); // clear possibly present feature Binning
27  B2DEBUG(20, "Reading the FeatureBinnings");
28  is >> m_featBins;
29  B2DEBUG(20, "Reading the Forest");
30 #if FastBDT_VERSION_MAJOR >= 3
31  m_forest = FastBDT::readForestFromStream<unsigned int>(is);
32 #else
33  m_forest = FastBDT::readForestFromStream(is);
34 #endif
35  B2DEBUG(20, "Reading the DecorrelationMatrix");
36  if (!m_decorrMat.readFromStream(is)) { // for some reason this does not stop if there is no decor matrix
37  B2ERROR("Reading in the decorrelation matrix did not work! The decorrelation matrix of this classifier will be set to identity!");
38  m_decorrMat = DecorrelationMatrix<9>();
39  }
40 }
41 
42 template<size_t Ndims>
43 void FBDTClassifier<Ndims>::writeToStream(std::ostream& os) const
44 {
45  B2DEBUG(20, "Reading the FeatureBinnings");
46  os << m_featBins << std::endl;
47  B2DEBUG(20, "Reading the Forest");
48  os << m_forest << std::endl;
49  B2DEBUG(20, "Reading the DecorrelationMatrix");
50  os << m_decorrMat.print() << std::endl;
51 }
52 
53 template<size_t Ndims>
55  int nTrees, int depth, double shrinkage, double ratio)
56 {
57  if (samples.empty()) {
58  B2ERROR("No samples passed for training a FBDTClassifier.");
59  return;
60  }
61 
62  unsigned int nBinCuts = 8;
63  size_t nSamples = samples.size();
64  B2DEBUG(20, "Using for training: nBinCuts: " << nBinCuts << ", with " << Ndims << " features and " << nSamples << " samples.");
65 
66  B2DEBUG(20, "FBDTClassifier::train(): Starting to restructure the data into the format better suited for later use");
67  std::array<std::vector<double>, Ndims> data;
68  for (const auto& event : samples) {
69  for (size_t iSP = 0; iSP < Ndims; ++iSP) {
70  data[iSP].push_back(event.hits[iSP]);
71  }
72  }
73 
74  B2DEBUG(20, "FBDTClassifier::train(): Calculating the decorrelation transformation.");
75  m_decorrMat.calculateDecorrMatrix(data, false);
76  B2DEBUG(20, "FBDTClassifier::train(): Applying decorrelation transformation");
77  data = m_decorrMat.decorrelate(data);
78 
79  B2DEBUG(20, "FBDTClassifier::train(): Determining the FeatureBinnings");
80  std::vector<unsigned int> nBinningLevels;
81  m_featBins.clear(); // clear the feature binnings (if present)
82  for (auto featureVec : data) {
83 #if FastBDT_VERSION_MAJOR >= 3
84  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec));
85 #else
86  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec.begin(), featureVec.end()));
87 #endif
88  nBinningLevels.push_back(nBinCuts);
89  }
90 
91  // have to use the decorrelated data for training!!!
92  B2DEBUG(20, "FBDTClassifier::train(): Creating the EventSamples");
93 #if FastBDT_VERSION_MAJOR >= 5
94  FastBDT::EventSample eventSample(nSamples, Ndims, 0, nBinningLevels);
95 #else
96  FastBDT::EventSample eventSample(nSamples, Ndims, nBinningLevels);
97 #endif
98  for (size_t iS = 0; iS < nSamples; ++iS) {
99  std::vector<unsigned> bins(Ndims);
100  for (size_t iF = 0; iF < Ndims; ++iF) {
101  bins[iF] = m_featBins[iF].ValueToBin(data[iF][iS]);
102  }
103  eventSample.AddEvent(bins, 1.0, samples[iS].signal);
104  }
105 
106  B2DEBUG(20, "FBDTClassifier::train(): Training the FastBDT");
107  FastBDT::ForestBuilder fbdt(eventSample, nTrees, shrinkage, ratio, depth); // train FastBDT
108 
109  B2DEBUG(20, "FBDTClassifier::train(): getting FastBDT to internal member");
110 #if FastBDT_VERSION_MAJOR >= 3
111  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage(), true);
112 #else
113  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage());
114 #endif
115  for (const auto& tree : fbdt.GetForest()) {
116  forest.AddTree(tree);
117  }
118 
119  m_forest = forest; // check if this can be done better with move or something similar
120 }
121 
122 // Explicit instantiation for SpacePoint (Ndims = 9), so that the template definitions can stay in this .cc file, separated from the .h file.
123 template class Belle2::FBDTClassifier<9>;
Belle2::DecorrelationMatrix< 9 >
Belle2::FBDTTrainSample
bundle together the classifier input and the target value into one struct for easier passing around.
Definition: FBDTClassifierHelper.h:36
Belle2::FBDTClassifier::readFromStream
void readFromStream(std::istream &is)
Read all the necessary data from the stream and fill the Forest and the FeatureBinnings. NOTE: uses FastBDTs IO stuff.
Definition: FBDTClassifier.cc:24
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::FBDTClassifier::train
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
Train the BDT. NOTE: overwrites a currently existing classifier internally. TODO: does not work at the moment...
Definition: FBDTClassifier.cc:54
Belle2::FBDTClassifier::writeToStream
void writeToStream(std::ostream &os) const
write out the data from the Forest and the FeatureBinnings to a stream NOTE: uses FastBDTs IO stuff.
Definition: FBDTClassifier.cc:43
Belle2::FBDTClassifier< 9 >