Belle II Software  release-08-01-10
FBDTClassifier.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
#include <framework/logging/Logger.h>

#if FastBDT_VERSION_MAJOR >= 3
#include <FastBDT_IO.h>
#else
#include <IO.h>
#endif

#include <utility>
17 
18 using namespace Belle2;
19 
20 template<size_t Ndims>
22 {
23  m_featBins.clear(); // clear possibly present feature Binning
24  B2DEBUG(20, "Reading the FeatureBinnings");
25  is >> m_featBins;
26  B2DEBUG(20, "Reading the Forest");
27 #if FastBDT_VERSION_MAJOR >= 3
28  m_forest = FastBDT::readForestFromStream<unsigned int>(is);
29 #else
30  m_forest = FastBDT::readForestFromStream(is);
31 #endif
32  B2DEBUG(20, "Reading the DecorrelationMatrix");
33  if (!m_decorrMat.readFromStream(is)) { // for some reason this does not stop if there is no decor matrix
34  B2ERROR("Reading in the decorrelation matrix did not work! The decorrelation matrix of this classifier will be set to identity!");
35  m_decorrMat = DecorrelationMatrix<9>();
36  }
37 }
38 
39 template<size_t Ndims>
40 void FBDTClassifier<Ndims>::writeToStream(std::ostream& os) const
41 {
42  B2DEBUG(20, "Reading the FeatureBinnings");
43  os << m_featBins << std::endl;
44  B2DEBUG(20, "Reading the Forest");
45  os << m_forest << std::endl;
46  B2DEBUG(20, "Reading the DecorrelationMatrix");
47  os << m_decorrMat.print() << std::endl;
48 }
49 
50 template<size_t Ndims>
52  int nTrees, int depth, double shrinkage, double ratio)
53 {
54  if (samples.empty()) {
55  B2ERROR("No samples passed for training a FBDTClassifier.");
56  return;
57  }
58 
59  unsigned int nBinCuts = 8;
60  size_t nSamples = samples.size();
61  B2DEBUG(20, "Using for training: nBinCuts: " << nBinCuts << ", with " << Ndims << " features and " << nSamples << " samples.");
62 
63  B2DEBUG(20, "FBDTClassifier::train(): Starting to restructure the data into the format better suited for later use");
64  std::array<std::vector<double>, Ndims> data;
65  for (const auto& event : samples) {
66  for (size_t iSP = 0; iSP < Ndims; ++iSP) {
67  data[iSP].push_back(event.hits[iSP]);
68  }
69  }
70 
71  B2DEBUG(20, "FBDTClassifier::train(): Calculating the decorrelation transformation.");
72  m_decorrMat.calculateDecorrMatrix(data, false);
73  B2DEBUG(20, "FBDTClassifier::train(): Applying decorrelation transformation");
74  data = m_decorrMat.decorrelate(data);
75 
76  B2DEBUG(20, "FBDTClassifier::train(): Determining the FeatureBinnings");
77  std::vector<unsigned int> nBinningLevels;
78  m_featBins.clear(); // clear the feature binnings (if present)
79  for (auto featureVec : data) {
80 #if FastBDT_VERSION_MAJOR >= 3
81  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec));
82 #else
83  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec.begin(), featureVec.end()));
84 #endif
85  nBinningLevels.push_back(nBinCuts);
86  }
87 
88  // have to use the decorrelated data for training!!!
89  B2DEBUG(20, "FBDTClassifier::train(): Creating the EventSamples");
90 #if FastBDT_VERSION_MAJOR >= 5
91  FastBDT::EventSample eventSample(nSamples, Ndims, 0, nBinningLevels);
92 #else
93  FastBDT::EventSample eventSample(nSamples, Ndims, nBinningLevels);
94 #endif
95  for (size_t iS = 0; iS < nSamples; ++iS) {
96  std::vector<unsigned> bins(Ndims);
97  for (size_t iF = 0; iF < Ndims; ++iF) {
98  bins[iF] = m_featBins[iF].ValueToBin(data[iF][iS]);
99  }
100  eventSample.AddEvent(bins, 1.0, samples[iS].signal);
101  }
102 
103  B2DEBUG(20, "FBDTClassifier::train(): Training the FastBDT");
104  FastBDT::ForestBuilder fbdt(eventSample, nTrees, shrinkage, ratio, depth); // train FastBDT
105 
106  B2DEBUG(20, "FBDTClassifier::train(): getting FastBDT to internal member");
107 #if FastBDT_VERSION_MAJOR >= 3
108  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage(), true);
109 #else
110  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage());
111 #endif
112  for (const auto& tree : fbdt.GetForest()) {
113  forest.AddTree(tree);
114  }
115 
116  m_forest = forest; // check if this can be done better with move or something similar
117 }
118 
// Explicit instantiation for SpacePoint in order to have .h and .cc file separated.
// The member function templates are defined in this .cc file, so each
// dimensionality that is to be used has to be instantiated explicitly here;
// 9 is the only one instantiated in this file.
template class Belle2::FBDTClassifier<9>;
void readFromStream(std::istream &is)
read all the necessary data from stream and fill the Forest and the FeatureBinnings NOTE: uses FastBDTs IO stuff.
void writeToStream(std::ostream &os) const
write out the data from the Forest and the FeatureBinnings to a stream NOTE: uses FastBDTs IO stuff.
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
train the BDT NOTE overwrites a currently existing classifier internally TODO does not work at the moment
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.