Belle II Software  release-06-02-00
FBDTClassifier.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
#include <framework/logging/Logger.h>

#if FastBDT_VERSION_MAJOR >= 3
#include <FastBDT_IO.h>
#else
#include <IO.h>
#endif

#include <utility>
17 
18 using namespace Belle2;
19 using namespace std;
20 
21 template<size_t Ndims>
23 {
24  m_featBins.clear(); // clear possibly present feature Binning
25  B2DEBUG(20, "Reading the FeatureBinnings");
26  is >> m_featBins;
27  B2DEBUG(20, "Reading the Forest");
28 #if FastBDT_VERSION_MAJOR >= 3
29  m_forest = FastBDT::readForestFromStream<unsigned int>(is);
30 #else
31  m_forest = FastBDT::readForestFromStream(is);
32 #endif
33  B2DEBUG(20, "Reading the DecorrelationMatrix");
34  if (!m_decorrMat.readFromStream(is)) { // for some reason this does not stop if there is no decor matrix
35  B2ERROR("Reading in the decorrelation matrix did not work! The decorrelation matrix of this classifier will be set to identity!");
36  m_decorrMat = DecorrelationMatrix<9>();
37  }
38 }
39 
40 template<size_t Ndims>
41 void FBDTClassifier<Ndims>::writeToStream(std::ostream& os) const
42 {
43  B2DEBUG(20, "Reading the FeatureBinnings");
44  os << m_featBins << std::endl;
45  B2DEBUG(20, "Reading the Forest");
46  os << m_forest << std::endl;
47  B2DEBUG(20, "Reading the DecorrelationMatrix");
48  os << m_decorrMat.print() << std::endl;
49 }
50 
51 template<size_t Ndims>
53  int nTrees, int depth, double shrinkage, double ratio)
54 {
55  if (samples.empty()) {
56  B2ERROR("No samples passed for training a FBDTClassifier.");
57  return;
58  }
59 
60  unsigned int nBinCuts = 8;
61  size_t nSamples = samples.size();
62  B2DEBUG(20, "Using for training: nBinCuts: " << nBinCuts << ", with " << Ndims << " features and " << nSamples << " samples.");
63 
64  B2DEBUG(20, "FBDTClassifier::train(): Starting to restructure the data into the format better suited for later use");
65  std::array<std::vector<double>, Ndims> data;
66  for (const auto& event : samples) {
67  for (size_t iSP = 0; iSP < Ndims; ++iSP) {
68  data[iSP].push_back(event.hits[iSP]);
69  }
70  }
71 
72  B2DEBUG(20, "FBDTClassifier::train(): Calculating the decorrelation transformation.");
73  m_decorrMat.calculateDecorrMatrix(data, false);
74  B2DEBUG(20, "FBDTClassifier::train(): Applying decorrelation transformation");
75  data = m_decorrMat.decorrelate(data);
76 
77  B2DEBUG(20, "FBDTClassifier::train(): Determining the FeatureBinnings");
78  std::vector<unsigned int> nBinningLevels;
79  m_featBins.clear(); // clear the feature binnings (if present)
80  for (auto featureVec : data) {
81 #if FastBDT_VERSION_MAJOR >= 3
82  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec));
83 #else
84  m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec.begin(), featureVec.end()));
85 #endif
86  nBinningLevels.push_back(nBinCuts);
87  }
88 
89  // have to use the decorrelated data for training!!!
90  B2DEBUG(20, "FBDTClassifier::train(): Creating the EventSamples");
91 #if FastBDT_VERSION_MAJOR >= 5
92  FastBDT::EventSample eventSample(nSamples, Ndims, 0, nBinningLevels);
93 #else
94  FastBDT::EventSample eventSample(nSamples, Ndims, nBinningLevels);
95 #endif
96  for (size_t iS = 0; iS < nSamples; ++iS) {
97  std::vector<unsigned> bins(Ndims);
98  for (size_t iF = 0; iF < Ndims; ++iF) {
99  bins[iF] = m_featBins[iF].ValueToBin(data[iF][iS]);
100  }
101  eventSample.AddEvent(bins, 1.0, samples[iS].signal);
102  }
103 
104  B2DEBUG(20, "FBDTClassifier::train(): Training the FastBDT");
105  FastBDT::ForestBuilder fbdt(eventSample, nTrees, shrinkage, ratio, depth); // train FastBDT
106 
107  B2DEBUG(20, "FBDTClassifier::train(): getting FastBDT to internal member");
108 #if FastBDT_VERSION_MAJOR >= 3
109  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage(), true);
110 #else
111  FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage());
112 #endif
113  for (const auto& tree : fbdt.GetForest()) {
114  forest.AddTree(tree);
115  }
116 
117  m_forest = forest; // check if this can be done better with move or something similar
118 }
119 
120 // explicit instantiation for SpacePoint in order to have .h and .cc file separated
121 template class Belle2::FBDTClassifier<9>;
void readFromStream(std::istream &is)
read all the necessary data from stream and fill the Forest and the FeatureBinnings NOTE: uses FastBD...
void writeToStream(std::ostream &os) const
Write out the data from the Forest and the FeatureBinnings to a stream. NOTE: uses FastBDT's IO functions.
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
train the BDT NOTE overwrites a currently existing classifier internally TODO does not work at the mo...
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.