11 #include <tracking/trackFindingVXD/filterTools/FBDTClassifier.h>
12 #include <framework/logging/Logger.h>
14 #if FastBDT_VERSION_MAJOR >= 3
15 #include <FastBDT_IO.h>
// Restores the classifier state from the input stream `is`.
// Per the debug messages below, the expected content order is: the FeatureBinnings,
// then the Forest, then the DecorrelationMatrix. This is the same order in which the
// writer in this file emits m_featBins, m_forest and m_decorrMat.print().
// NOTE(review): the statement that reads m_featBins is not visible in this view;
// confirm it precedes the forest read.
// NOTE(review): no validation of the stream state/format is visible here.
23 template<
size_t Ndims>
27 B2DEBUG(20,
"Reading the FeatureBinnings");
29 B2DEBUG(20,
"Reading the Forest");
// FastBDT major version >= 3 provides a reader templated on the value type
// (instantiated with unsigned int here); older versions use the non-templated reader.
30 #if FastBDT_VERSION_MAJOR >= 3
31 m_forest = FastBDT::readForestFromStream<unsigned int>(is);
33 m_forest = FastBDT::readForestFromStream(is);
35 B2DEBUG(20,
"Reading the DecorrelationMatrix");
// A failed decorrelation-matrix read is not fatal: only an error is logged.
// According to the message the matrix ends up as identity. This is presumably done
// inside m_decorrMat.readFromStream; verify there.
36 if (!m_decorrMat.readFromStream(is)) {
37 B2ERROR(
"Reading in the decorrelation matrix did not work! The decorrelation matrix of this classifier will be set to identity!");
42 template<
size_t Ndims>
45 B2DEBUG(20,
"Reading the FeatureBinnings");
46 os << m_featBins << std::endl;
47 B2DEBUG(20,
"Reading the Forest");
48 os << m_forest << std::endl;
49 B2DEBUG(20,
"Reading the DecorrelationMatrix");
50 os << m_decorrMat.print() << std::endl;
53 template<
size_t Ndims>
55 int nTrees,
int depth,
double shrinkage,
double ratio)
57 if (samples.empty()) {
58 B2ERROR(
"No samples passed for training a FBDTClassifier.");
62 unsigned int nBinCuts = 8;
63 size_t nSamples = samples.size();
64 B2DEBUG(20,
"Using for training: nBinCuts: " << nBinCuts <<
", with " << Ndims <<
" features and " << nSamples <<
" samples.");
66 B2DEBUG(20,
"FBDTClassifier::train(): Starting to restructure the data into the format better suited for later use");
67 std::array<std::vector<double>, Ndims> data;
68 for (
const auto& event : samples) {
69 for (
size_t iSP = 0; iSP < Ndims; ++iSP) {
70 data[iSP].push_back(event.hits[iSP]);
74 B2DEBUG(20,
"FBDTClassifier::train(): Calculating the decorrelation transformation.");
75 m_decorrMat.calculateDecorrMatrix(data,
false);
76 B2DEBUG(20,
"FBDTClassifier::train(): Applying decorrelation transformation");
77 data = m_decorrMat.decorrelate(data);
79 B2DEBUG(20,
"FBDTClassifier::train(): Determining the FeatureBinnings");
80 std::vector<unsigned int> nBinningLevels;
82 for (
auto featureVec : data) {
83 #if FastBDT_VERSION_MAJOR >= 3
84 m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec));
86 m_featBins.push_back(FastBDT::FeatureBinning<double>(nBinCuts, featureVec.begin(), featureVec.end()));
88 nBinningLevels.push_back(nBinCuts);
92 B2DEBUG(20,
"FBDTClassifier::train(): Creating the EventSamples");
93 #if FastBDT_VERSION_MAJOR >= 5
94 FastBDT::EventSample eventSample(nSamples, Ndims, 0, nBinningLevels);
96 FastBDT::EventSample eventSample(nSamples, Ndims, nBinningLevels);
98 for (
size_t iS = 0; iS < nSamples; ++iS) {
99 std::vector<unsigned> bins(Ndims);
100 for (
size_t iF = 0; iF < Ndims; ++iF) {
101 bins[iF] = m_featBins[iF].ValueToBin(data[iF][iS]);
103 eventSample.AddEvent(bins, 1.0, samples[iS].signal);
106 B2DEBUG(20,
"FBDTClassifier::train(): Training the FastBDT");
107 FastBDT::ForestBuilder fbdt(eventSample, nTrees, shrinkage, ratio, depth);
109 B2DEBUG(20,
"FBDTClassifier::train(): getting FastBDT to internal member");
110 #if FastBDT_VERSION_MAJOR >= 3
111 FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage(),
true);
113 FBDTForest forest(fbdt.GetF0(), fbdt.GetShrinkage());
115 for (
const auto& tree : fbdt.GetForest()) {
116 forest.AddTree(tree);