Belle II Software  release-05-01-25
test_FastBDT.cc
1 /* BASF2 (Belle Analysis Framework 2) *
2  * Copyright(C) 2016 - Belle II Collaboration *
3  * *
4  * Author: The Belle II Collaboration *
5  * Contributors: Thomas Keck *
6  * *
7  * This software is provided "as is" without any warranty. *
8  **************************************************************************/
9 
10 #include <mva/methods/FastBDT.h>
11 #include <mva/interface/Interface.h>
12 #include <mva/interface/Dataset.h>
13 #include <framework/utilities/FileSystem.h>
14 #include <framework/utilities/TestHelpers.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace Belle2;
19 
20 namespace {
21 
22  TEST(FastBDTTest, FastBDTOptions)
23  {
24  MVA::FastBDTOptions specific_options;
25 
26  EXPECT_EQ(specific_options.m_nTrees, 200);
27  EXPECT_EQ(specific_options.m_nCuts, 8);
28  EXPECT_EQ(specific_options.m_nLevels, 3);
29  EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
30  EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
31 #if FastBDT_VERSION_MAJOR >= 5
32  EXPECT_EQ(specific_options.m_sPlot, false);
33  EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
34  EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
35  EXPECT_EQ(specific_options.m_purityTransformation, false);
36  EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
37 #endif
38 
39  specific_options.m_nTrees = 100;
40  specific_options.m_nCuts = 10;
41  specific_options.m_nLevels = 2;
42  specific_options.m_shrinkage = 0.2;
43  specific_options.m_randRatio = 0.8;
44 #if FastBDT_VERSION_MAJOR >= 5
45  specific_options.m_individual_nCuts = {2, 3, 4};
46  specific_options.m_flatnessLoss = 0.3;
47  specific_options.m_sPlot = true;
48  specific_options.m_purityTransformation = true;
49  specific_options.m_individualPurityTransformation = {true, false, true};
50 #endif
51 
52  boost::property_tree::ptree pt;
53  specific_options.save(pt);
54  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nTrees"), 100);
55  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nCuts"), 10);
56  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nLevels"), 2);
57  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_shrinkage"), 0.2);
58  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_randRatio"), 0.8);
59 #if FastBDT_VERSION_MAJOR >= 5
60  EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individual_nCuts"), 3);
61  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts0"), 2);
62  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts1"), 3);
63  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts2"), 4);
64  EXPECT_EQ(pt.get<bool>("FastBDT_sPlot"), true);
65  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_flatnessLoss"), 0.3);
66  EXPECT_EQ(pt.get<bool>("FastBDT_purityTransformation"), true);
67  EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individualPurityTransformation"), 3);
68  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation0"), true);
69  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation1"), false);
70  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation2"), true);
71 #endif
72 
73  MVA::FastBDTOptions specific_options2;
74  specific_options2.load(pt);
75 
76  EXPECT_EQ(specific_options2.m_nTrees, 100);
77  EXPECT_EQ(specific_options2.m_nCuts, 10);
78  EXPECT_EQ(specific_options2.m_nLevels, 2);
79  EXPECT_FLOAT_EQ(specific_options2.m_shrinkage, 0.2);
80  EXPECT_FLOAT_EQ(specific_options2.m_randRatio, 0.8);
81 #if FastBDT_VERSION_MAJOR >= 5
82  EXPECT_EQ(specific_options2.m_sPlot, true);
83  EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
84  EXPECT_EQ(specific_options2.m_purityTransformation, true);
85  EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
86  EXPECT_EQ(specific_options2.m_individualPurityTransformation[0], true);
87  EXPECT_EQ(specific_options2.m_individualPurityTransformation[1], false);
88  EXPECT_EQ(specific_options2.m_individualPurityTransformation[2], true);
89  EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
90  EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
91  EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
92  EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
93 #endif
94 
95  EXPECT_EQ(specific_options.getMethod(), std::string("FastBDT"));
96 
97  // Test if po::options_description is created without crashing
98  auto description = specific_options.getDescription();
99 
100  // flatnessLoss, sPlot and individualNCuts are only activated in FastBDT version 4
101 #if FastBDT_VERSION_MAJOR >= 5
102  EXPECT_EQ(description.options().size(), 10);
103 #else
104  EXPECT_EQ(description.options().size(), 5);
105 #endif
106 
107  // Check for B2ERROR and throw if version is wrong
108  // we try with version 100, surely we will never reach this!
109  pt.put("FastBDT_version", 100);
110  try {
111  EXPECT_B2ERROR(specific_options2.load(pt));
112  } catch (...) {
113 
114  }
115  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
116  }
117 
118  class TestDataset : public MVA::Dataset {
119  public:
120  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
121  {
122  m_input = {0.0};
123  m_target = 0.0;
124  m_isSignal = false;
125  m_weight = 1.0;
126  }
127 
128  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
129  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
130  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
131  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
132  float getSignalFraction() override { return 0.1; };
133  std::vector<float> getFeature(unsigned int) override { return m_data; }
134 
135  std::vector<float> m_data;
136 
137  };
138 
139 
140  TEST(FastBDTTest, FastBDTInterface)
141  {
143 
144  MVA::GeneralOptions general_options;
145  general_options.m_variables = {"A"};
146  MVA::FastBDTOptions specific_options;
147  specific_options.m_randRatio = 1.0;
148  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
149 
150  auto teacher = interface.getTeacher(general_options, specific_options);
151  auto weightfile = teacher->train(dataset);
152 
153  auto expert = interface.getExpert();
154  expert->load(weightfile);
155  auto probabilities = expert->apply(dataset);
156  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
157  for (unsigned int i = 0; i < 4; ++i) {
158  EXPECT_LE(probabilities[i], 0.6);
159  EXPECT_GE(probabilities[i], 0.4);
160  }
161  EXPECT_LE(probabilities[4], 0.2);
162  EXPECT_GE(probabilities[5], 0.8);
163  EXPECT_LE(probabilities[6], 0.2);
164  EXPECT_GE(probabilities[7], 0.8);
165 
166  }
167 
168 #if FastBDT_VERSION_MAJOR >= 5
169  TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
170  {
172 
173  MVA::GeneralOptions general_options;
174  general_options.m_variables = {"A"};
175  MVA::FastBDTOptions specific_options;
176  specific_options.m_randRatio = 1.0;
177  specific_options.m_purityTransformation = true;
178  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
179 
180  auto teacher = interface.getTeacher(general_options, specific_options);
181  auto weightfile = teacher->train(dataset);
182 
183  auto expert = interface.getExpert();
184  expert->load(weightfile);
185  auto probabilities = expert->apply(dataset);
186  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
187  for (unsigned int i = 0; i < 4; ++i) {
188  EXPECT_LE(probabilities[i], 0.6);
189  EXPECT_GE(probabilities[i], 0.4);
190  }
191  EXPECT_LE(probabilities[4], 0.2);
192  EXPECT_GE(probabilities[5], 0.8);
193  EXPECT_LE(probabilities[6], 0.2);
194  EXPECT_GE(probabilities[7], 0.8);
195 
196  }
197 #endif
198 
199  TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
200  {
202 
203  MVA::GeneralOptions general_options;
204  general_options.m_variables = {"M", "p", "pt"};
205  MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
206  {1.873689, 1.881940, 1.843310},
207  {1.863657, 1.774831, 1.753773},
208  {1.858293, 1.605311, 0.631336},
209  {1.837129, 1.575739, 1.490166},
210  {1.811395, 1.524029, 0.565220}
211  },
212  {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
213 
214  auto expert = interface.getExpert();
215 
216 #if FastBDT_VERSION_MAJOR >= 3
217  auto weightfile_v3 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv3.xml"));
218  expert->load(weightfile_v3);
219  auto probabilities_v3 = expert->apply(dataset);
220  EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
221  EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
222  EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
223  EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
224  EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
225  EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
226 #endif
227 
228 #if FastBDT_VERSION_MAJOR >= 5
229  auto weightfile_v5 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv5.xml"));
230  expert->load(weightfile_v5);
231  auto probabilities_v5 = expert->apply(dataset);
232  EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
233  EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
234  EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
235  EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
236  EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
237  EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
238 
239  // There are small differences due to floating point precision
240  // the FastBDT code is compiled with -O3 and changes slightly
241  // depending on the compiler and random things.
242  // Nevertheless the returned probabilities should be nearly the same!
243  EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
244  EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
245  EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
246  EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
247  EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
248  EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
249 #endif
250  }
251 
252 }
Belle2::MVA::FastBDTOptions
Options for the FastBDT MVA method.
Definition: FastBDT.h:55
Belle2::MVA::MultiDataset
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:187
Belle2::MVA::FastBDTOptions::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FastBDT.cc:55
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:34
Belle2::MVA::FastBDTOptions::m_randRatio
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:84
Belle2::MVA::Interface::getTeacher
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:119
Belle2::MVA::FastBDTOptions::m_nCuts
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:81
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVA::FastBDTOptions::m_nTrees
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:80
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::TEST
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Definition: utilityFunctions.cc:18
Belle2::MVA::FastBDTOptions::m_nLevels
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:82
Belle2::MVA::FastBDTOptions::m_shrinkage
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:83
Belle2::FileSystem::findFile
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:147
Belle2::MVA::Interface::getExpert
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:128
Belle2::MVA::Interface
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:101