Belle II Software  release-08-01-10
test_FastBDT.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/FastBDT.h>
10 #include <mva/interface/Interface.h>
11 #include <mva/interface/Dataset.h>
12 #include <framework/utilities/FileSystem.h>
13 #include <framework/utilities/TestHelpers.h>
14 
15 #include <gtest/gtest.h>
16 
17 using namespace Belle2;
18 
19 namespace {
20 
21  TEST(FastBDTTest, FastBDTOptions)
22  {
23  MVA::FastBDTOptions specific_options;
24 
25  EXPECT_EQ(specific_options.m_nTrees, 200);
26  EXPECT_EQ(specific_options.m_nCuts, 8);
27  EXPECT_EQ(specific_options.m_nLevels, 3);
28  EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
29  EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
30 #if FastBDT_VERSION_MAJOR >= 5
31  EXPECT_EQ(specific_options.m_sPlot, false);
32  EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
33  EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
34  EXPECT_EQ(specific_options.m_purityTransformation, false);
35  EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
36 #endif
37 
38  specific_options.m_nTrees = 100;
39  specific_options.m_nCuts = 10;
40  specific_options.m_nLevels = 2;
41  specific_options.m_shrinkage = 0.2;
42  specific_options.m_randRatio = 0.8;
43 #if FastBDT_VERSION_MAJOR >= 5
44  specific_options.m_individual_nCuts = {2, 3, 4};
45  specific_options.m_flatnessLoss = 0.3;
46  specific_options.m_sPlot = true;
47  specific_options.m_purityTransformation = true;
48  specific_options.m_individualPurityTransformation = {true, false, true};
49 #endif
50 
51  boost::property_tree::ptree pt;
52  specific_options.save(pt);
53  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nTrees"), 100);
54  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nCuts"), 10);
55  EXPECT_EQ(pt.get<unsigned int>("FastBDT_nLevels"), 2);
56  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_shrinkage"), 0.2);
57  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_randRatio"), 0.8);
58 #if FastBDT_VERSION_MAJOR >= 5
59  EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individual_nCuts"), 3);
60  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts0"), 2);
61  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts1"), 3);
62  EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts2"), 4);
63  EXPECT_EQ(pt.get<bool>("FastBDT_sPlot"), true);
64  EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_flatnessLoss"), 0.3);
65  EXPECT_EQ(pt.get<bool>("FastBDT_purityTransformation"), true);
66  EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individualPurityTransformation"), 3);
67  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation0"), true);
68  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation1"), false);
69  EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation2"), true);
70 #endif
71 
72  MVA::FastBDTOptions specific_options2;
73  specific_options2.load(pt);
74 
75  EXPECT_EQ(specific_options2.m_nTrees, 100);
76  EXPECT_EQ(specific_options2.m_nCuts, 10);
77  EXPECT_EQ(specific_options2.m_nLevels, 2);
78  EXPECT_FLOAT_EQ(specific_options2.m_shrinkage, 0.2);
79  EXPECT_FLOAT_EQ(specific_options2.m_randRatio, 0.8);
80 #if FastBDT_VERSION_MAJOR >= 5
81  EXPECT_EQ(specific_options2.m_sPlot, true);
82  EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
83  EXPECT_EQ(specific_options2.m_purityTransformation, true);
84  EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
85  EXPECT_EQ(specific_options2.m_individualPurityTransformation[0], true);
86  EXPECT_EQ(specific_options2.m_individualPurityTransformation[1], false);
87  EXPECT_EQ(specific_options2.m_individualPurityTransformation[2], true);
88  EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
89  EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
90  EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
91  EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
92 #endif
93 
94  EXPECT_EQ(specific_options.getMethod(), std::string("FastBDT"));
95 
96  // Test if po::options_description is created without crashing
97  auto description = specific_options.getDescription();
98 
99  // flatnessLoss, sPlot and individualNCuts are only activated in FastBDT version 4
100 #if FastBDT_VERSION_MAJOR >= 5
101  EXPECT_EQ(description.options().size(), 10);
102 #else
103  EXPECT_EQ(description.options().size(), 5);
104 #endif
105 
106  // Check for B2ERROR and throw if version is wrong
107  // we try with version 100, surely we will never reach this!
108  pt.put("FastBDT_version", 100);
109  try {
110  EXPECT_B2ERROR(specific_options2.load(pt));
111  } catch (...) {
112 
113  }
114  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
115  }
116 
117  class TestDataset : public MVA::Dataset {
118  public:
119  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
120  {
121  m_input = {0.0};
122  m_target = 0.0;
123  m_isSignal = false;
124  m_weight = 1.0;
125  }
126 
127  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
128  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
129  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
130  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
131  float getSignalFraction() override { return 0.1; };
132  std::vector<float> getFeature(unsigned int) override { return m_data; }
133 
134  std::vector<float> m_data;
135 
136  };
137 
138 
139  TEST(FastBDTTest, FastBDTInterface)
140  {
142 
143  MVA::GeneralOptions general_options;
144  general_options.m_variables = {"A"};
145  MVA::FastBDTOptions specific_options;
146  specific_options.m_randRatio = 1.0;
147  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
148 
149  auto teacher = interface.getTeacher(general_options, specific_options);
150  auto weightfile = teacher->train(dataset);
151 
152  auto expert = interface.getExpert();
153  expert->load(weightfile);
154  auto probabilities = expert->apply(dataset);
155  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
156  for (unsigned int i = 0; i < 4; ++i) {
157  EXPECT_LE(probabilities[i], 0.6);
158  EXPECT_GE(probabilities[i], 0.4);
159  }
160  EXPECT_LE(probabilities[4], 0.2);
161  EXPECT_GE(probabilities[5], 0.8);
162  EXPECT_LE(probabilities[6], 0.2);
163  EXPECT_GE(probabilities[7], 0.8);
164 
165  }
166 
#if FastBDT_VERSION_MAJOR >= 5
/// Same as FastBDTInterface, but with the purity transformation enabled
/// (an option only available from FastBDT version 5 on).
TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
{
  // Declaration of the MVA interface under test; its use below requires it.
  MVA::Interface<MVA::FastBDTOptions, MVA::FastBDTTeacher, MVA::FastBDTExpert> interface;

  MVA::GeneralOptions general_options;
  general_options.m_variables = {"A"};
  MVA::FastBDTOptions specific_options;
  // Train on the full sample in each step (presumably to avoid sampling randomness).
  specific_options.m_randRatio = 1.0;
  specific_options.m_purityTransformation = true;
  // Targets alternate background/signal (see TestDataset::loadEvent):
  // value 1.0 occurs for both classes (events 0-3), 2.0 only for background
  // (events 4, 6) and 3.0 only for signal (events 5, 7).
  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});

  auto teacher = interface.getTeacher(general_options, specific_options);
  auto weightfile = teacher->train(dataset);

  auto expert = interface.getExpert();
  expert->load(weightfile);
  auto probabilities = expert->apply(dataset);
  // Fatal check: the indexed accesses below would read out of bounds otherwise.
  ASSERT_EQ(probabilities.size(), dataset.getNumberOfEvents());
  for (unsigned int i = 0; i < 4; ++i) {
    EXPECT_LE(probabilities[i], 0.6);
    EXPECT_GE(probabilities[i], 0.4);
  }
  EXPECT_LE(probabilities[4], 0.2);
  EXPECT_GE(probabilities[5], 0.8);
  EXPECT_LE(probabilities[6], 0.2);
  EXPECT_GE(probabilities[7], 0.8);
}
#endif
197 
198  TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
199  {
201 
202  MVA::GeneralOptions general_options;
203  general_options.m_variables = {"M", "p", "pt"};
204  MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
205  {1.873689, 1.881940, 1.843310},
206  {1.863657, 1.774831, 1.753773},
207  {1.858293, 1.605311, 0.631336},
208  {1.837129, 1.575739, 1.490166},
209  {1.811395, 1.524029, 0.565220}
210  },
211  {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
212 
213  // cppcheck-suppress unreadVariable
214  auto expert = interface.getExpert();
215 
216 #if FastBDT_VERSION_MAJOR >= 3
217  auto weightfile_v3 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv3.xml"));
218  expert->load(weightfile_v3);
219  auto probabilities_v3 = expert->apply(dataset);
220  EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
221  EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
222  EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
223  EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
224  EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
225  EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
226 #endif
227 
228 #if FastBDT_VERSION_MAJOR >= 5
229  auto weightfile_v5 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv5.xml"));
230  expert->load(weightfile_v5);
231  auto probabilities_v5 = expert->apply(dataset);
232  EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
233  EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
234  EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
235  EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
236  EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
237  EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
238 
239  // There are small differences due to floating point precision
240  // the FastBDT code is compiled with -O3 and changes slightly
241  // depending on the compiler and random things.
242  // Nevertheless the returned probabilities should be nearly the same!
243  EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
244  EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
245  EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
246  EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
247  EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
248  EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
249 #endif
250  }
251 
252 }
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:148
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:82
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:81
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FastBDT.cc:53
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:80
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.