Belle II Software development
FBDTClassifier.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
10#include <framework/logging/Logger.h>
11
12#include <FastBDT_IO.h>
13
14using namespace Belle2;
15
16template<size_t Ndims>
18{
19 m_featBins.clear(); // clear possibly present feature Binning
20 B2DEBUG(20, "Reading the FeatureBinnings");
21 is >> m_featBins;
22 B2DEBUG(20, "Reading the Forest");
23 m_forest = FastBDT::readForestFromStream<unsigned int>(is);
24 B2DEBUG(20, "Reading the DecorrelationMatrix");
25 if (!m_decorrMat.readFromStream(is)) { // for some reason this does not stop if there is no decor matrix
26 B2ERROR("Reading in the decorrelation matrix did not work! The decorrelation matrix of this classifier will be set to identity!");
27 m_decorrMat = DecorrelationMatrix<9>();
28 }
29}
30
31template<size_t Ndims>
32void FBDTClassifier<Ndims>::writeToStream(std::ostream& os) const
33{
34 B2DEBUG(20, "Reading the FeatureBinnings");
35 os << m_featBins << std::endl;
36 B2DEBUG(20, "Reading the Forest");
37 os << m_forest << std::endl;
38 B2DEBUG(20, "Reading the DecorrelationMatrix");
39 os << m_decorrMat.print() << std::endl;
40}
41
42template<size_t Ndims>
44 int nTrees, int depth, double shrinkage, double ratio)
45{
46 if (samples.empty()) {
47 B2ERROR("No samples passed for training a FBDTClassifier.");
48 return;
49 }
50
51 unsigned int nBinCuts = 8;
52 size_t nSamples = samples.size();
53 B2DEBUG(20, "Using for training: nBinCuts: " << nBinCuts << ", with " << Ndims << " features and " << nSamples << " samples.");
54
55 B2DEBUG(20, "FBDTClassifier::train(): Starting to restructure the data into the format better suited for later use");
56 std::array<std::vector<double>, Ndims> data;
57 for (const auto& event : samples) {
58 for (size_t iSP = 0; iSP < Ndims; ++iSP) {
59 data[iSP].push_back(event.hits[iSP]);
60 }
61 }
63 B2DEBUG(20, "FBDTClassifier::train(): Calculating the decorrelation transformation.");
64 m_decorrMat.calculateDecorrMatrix(data, false);
65 B2DEBUG(20, "FBDTClassifier::train(): Applying decorrelation transformation");
66 data = m_decorrMat.decorrelate(data);
68 B2DEBUG(20, "FBDTClassifier::train(): Determining the FeatureBinnings");
69 std::vector<unsigned int> nBinningLevels;
70 m_featBins.clear(); // clear the feature binnings (if present)
71 for (auto featureVec : data) {
72 m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec));
73 nBinningLevels.push_back(nBinCuts);
74 }
75
76 // have to use the decorrelated data for training!!!
77 B2DEBUG(20, "FBDTClassifier::train(): Creating the EventSamples");
78 FastBDT::EventSample eventSample(nSamples, Ndims, 0, nBinningLevels);
79 for (size_t iS = 0; iS < nSamples; ++iS) {
80 std::vector<unsigned> bins(Ndims);
81 for (size_t iF = 0; iF < Ndims; ++iF) {
82 bins[iF] = m_featBins[iF].ValueToBin(data[iF][iS]);
83 }
84 eventSample.AddEvent(bins, 1.0, samples[iS].signal);
85 }
86
87 B2DEBUG(20, "FBDTClassifier::train(): Training the FastBDT");
88 FastBDT::ForestBuilder fbdt(eventSample, nTrees, shrinkage, ratio, depth); // train FastBDT
89
90 B2DEBUG(20, "FBDTClassifier::train(): getting FastBDT to internal member");
91 FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage(), true);
92 for (const auto& tree : fbdt.GetForest()) {
93 forest.AddTree(tree);
94 }
95
96 m_forest = forest; // check if this can be done better with move or something similar
97}
98
99// explicit instantiation for SpacePoint in order to have .h and .cc file separated
100template class Belle2::FBDTClassifier<9>;
Class holding a Matrix that can be used to decorrelate input data to Machine Learning classifiers.
FastBDT as RelationsObject to make it storable and accessible on/via the DataStore.
void readFromStream(std::istream &is)
read all the necessary data from stream and fill the Forest and the FeatureBinnings NOTE: uses FastBD...
void writeToStream(std::ostream &os) const
write out the data from the Forest and the FeatureBinnings to a stream NOTE: uses FastBDTs IO stuff.
void train(const std::vector< Belle2::FBDTTrainSample< Ndims > > &samples, int nTree, int depth, double shrinkage=0.15, double ratio=0.5)
train the BDT NOTE overwrites a currently existing classifier internally TODO does not work at the mo...
Abstract base class for different kinds of events.
bundle together the classifier input and the target value into one struct for easier passing around.