Belle II Software light-2405-quaxo
FastBDT.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10#ifndef INCLUDE_GUARD_BELLE2_MVA_FASTBDT_HEADER
11#define INCLUDE_GUARD_BELLE2_MVA_FASTBDT_HEADER
12
13#include <mva/interface/Options.h>
14#include <mva/interface/Teacher.h>
15#include <mva/interface/Expert.h>
16
17#include <FastBDT.h>
18
19#if FastBDT_VERSION_MAJOR >= 3
20#include <FastBDT_IO.h>
21#else
22#include <IO.h>
23#endif
24
25#if FastBDT_VERSION_MAJOR >= 5
26#include <Classifier.h>
27#endif
28
29// Template specialization to fix NAN sort bug of FastBDT in up to Version 3.2
30#if FastBDT_VERSION_MAJOR <= 3 && FastBDT_VERSION_MINOR <= 2
31namespace FastBDT {
32 template<>
33 bool compareIncludingNaN(float i, float j);
34}
35#endif
36
37namespace Belle2 {
42 namespace MVA {
43
44
48 bool isValidSignal(const std::vector<bool>& Signals);
49
54
55 public:
60 virtual void load(const boost::property_tree::ptree& pt) override;
61
66 virtual void save(boost::property_tree::ptree& pt) const override;
67
71 virtual po::options_description getDescription() override;
72
76 virtual std::string getMethod() const override { return "FastBDT"; }
77
78 unsigned int m_nTrees = 200;
79 unsigned int m_nCuts = 8;
80 unsigned int m_nLevels = 3;
81 double m_shrinkage = 0.1;
82 double m_randRatio = 0.5;
83#if FastBDT_VERSION_MAJOR >= 5
84 std::vector<unsigned int>
85 m_individual_nCuts;
86 double m_flatnessLoss = -1.0;
87 bool m_sPlot = false;
88 bool m_purityTransformation = false;
89 std::vector<bool>
90 m_individualPurityTransformation;
91#endif
92 };
93
94
98 class FastBDTTeacher : public Teacher {
99
100 public:
106 FastBDTTeacher(const GeneralOptions& general_options, const FastBDTOptions& specific_options);
107
112 virtual Weightfile train(Dataset& training_data) const override;
113
114 private:
116 };
117
118
122 class FastBDTExpert : public MVA::Expert {
123
124 public:
129 virtual void load(Weightfile& weightfile) override;
130
135 virtual std::vector<float> apply(Dataset& test_data) const override;
136
137 private:
139#if FastBDT_VERSION_MAJOR >= 3
140#if FastBDT_VERSION_MAJOR >= 5
141 bool m_use_simplified_interface = false;
142 FastBDT::Classifier m_classifier;
143#endif
144 FastBDT::Forest<float> m_expert_forest;
145#else
146 FastBDT::Forest m_expert_forest;
147 std::vector<FastBDT::FeatureBinning<float>> m_expert_feature_binning;
148#endif
149 };
150
151 }
153}
154#endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
Expert for the FastBDT MVA method.
Definition: FastBDT.h:122
std::vector< FastBDT::FeatureBinning< float > > m_expert_feature_binning
Forest feature binning.
Definition: FastBDT.h:147
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:415
FastBDT::Forest m_expert_forest
Forest Expert.
Definition: FastBDT.h:146
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:322
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:138
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
virtual std::string getMethod() const override
Return method name.
Definition: FastBDT.h:76
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FastBDT.cc:126
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:82
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:81
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: FastBDT.cc:53
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:80
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: FastBDT.cc:99
unsigned int m_nCuts
Number of cut levels = log_2(number of cuts).
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
Teacher for the FastBDT MVA method.
Definition: FastBDT.h:98
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:115
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
Definition: FastBDT.cc:160
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific Options, all method Options have to inherit from this class.
Definition: Options.h:98
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24