Belle II Software  release-08-01-10
FastBDT.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 #ifndef INCLUDE_GUARD_BELLE2_MVA_FASTBDT_HEADER
11 #define INCLUDE_GUARD_BELLE2_MVA_FASTBDT_HEADER
12 
13 #include <mva/interface/Options.h>
14 #include <mva/interface/Teacher.h>
15 #include <mva/interface/Expert.h>
16 
17 #include <FastBDT.h>
18 
19 #if FastBDT_VERSION_MAJOR >= 3
20 #include <FastBDT_IO.h>
21 #else
22 #include <IO.h>
23 #endif
24 
25 #if FastBDT_VERSION_MAJOR >= 5
26 #include <Classifier.h>
27 #endif
28 
29 // Template specialization to fix NAN sort bug of FastBDT in up to Version 3.2
30 #if FastBDT_VERSION_MAJOR <= 3 && FastBDT_VERSION_MINOR <= 2
31 namespace FastBDT {
32  template<>
33  bool compareIncludingNaN(float i, float j);
34 }
35 #endif
36 
37 namespace Belle2 {
42  namespace MVA {
43 
44 
/**
 * Checks a vector of signal flags for validity.
 * NOTE(review): the exact validity criterion lives in the out-of-line definition
 * (not visible here) — TODO document it.
 * @param signals signal flags to check
 * @return true if the flags are considered valid by the definition
 */
bool isValidSignal(const std::vector<bool>& signals);
49 
54 
55  public:
60  virtual void load(const boost::property_tree::ptree& pt) override;
61 
66  virtual void save(boost::property_tree::ptree& pt) const override;
67 
71  virtual po::options_description getDescription() override;
72 
76  virtual std::string getMethod() const override { return "FastBDT"; }
77 
/** Number of trees. */
unsigned int m_nTrees{200};
/** Number of cut levels = log_2(number of cuts). */
unsigned int m_nCuts{8};
/** Depth of tree. */
unsigned int m_nLevels{3};
/** Shrinkage during the boosting step. */
double m_shrinkage{0.1};
/** Fraction of data to use in the stochastic training. */
double m_randRatio{0.5};
#if FastBDT_VERSION_MAJOR >= 5
/** Per-feature cut levels (TODO confirm: presumably overrides m_nCuts for individual features). */
std::vector<unsigned int> m_individual_nCuts;
/** Flatness loss parameter, default -1.0 (TODO confirm: presumably a negative value disables it). */
double m_flatnessLoss{-1.0};
/** sPlot flag, default false (TODO confirm exact semantics). */
bool m_sPlot{false};
/** Global purity-transformation flag, default false (TODO confirm exact semantics). */
bool m_purityTransformation{false};
/** Per-feature purity-transformation flags (TODO confirm exact semantics). */
std::vector<bool> m_individualPurityTransformation;
#endif
92  };
93 
94 
98  class FastBDTTeacher : public Teacher {
99 
100  public:
106  FastBDTTeacher(const GeneralOptions& general_options, const FastBDTOptions& specific_options);
107 
112  virtual Weightfile train(Dataset& training_data) const override;
113 
114  private:
116  };
117 
118 
122  class FastBDTExpert : public MVA::Expert {
123 
124  public:
129  virtual void load(Weightfile& weightfile) override;
130 
135  virtual std::vector<float> apply(Dataset& test_data) const override;
136 
137  private:
139 #if FastBDT_VERSION_MAJOR >= 3
140 #if FastBDT_VERSION_MAJOR >= 5
141  bool m_use_simplified_interface = false;
142  FastBDT::Classifier m_classifier;
143 #endif
144  FastBDT::Forest<float> m_expert_forest;
145 #else
146  FastBDT::Forest m_expert_forest;
147  std::vector<FastBDT::FeatureBinning<float>> m_expert_feature_binning;
148 #endif
149  };
150 
151  }
153 }
154 #endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
Expert for the FastBDT MVA method.
Definition: FastBDT.h:122
std::vector< FastBDT::FeatureBinning< float > > m_expert_feature_binning
Forest feature binning.
Definition: FastBDT.h:147
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FastBDT.cc:415
FastBDT::Forest m_expert_forest
Forest Expert.
Definition: FastBDT.h:146
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FastBDT.cc:322
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:138
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
virtual std::string getMethod() const override
Return method name.
Definition: FastBDT.h:76
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FastBDT.cc:126
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:82
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:81
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: FastBDT.cc:53
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:80
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: FastBDT.cc:99
unsigned int m_nCuts
Number of cut levels = log_2(number of cuts).
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
Teacher for the FastBDT MVA method.
Definition: FastBDT.h:98
FastBDTTeacher(const GeneralOptions &general_options, const FastBDTOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FastBDT.cc:156
FastBDTOptions m_specific_options
Method specific options.
Definition: FastBDT.h:115
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
Definition: FastBDT.cc:160
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific Options, all method Options have to inherit from this class.
Definition: Options.h:98
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.