Belle II Software development
test_FastBDT.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/FastBDT.h>
10#include <mva/interface/Interface.h>
11#include <mva/interface/Dataset.h>
12#include <framework/utilities/FileSystem.h>
13#include <framework/utilities/TestHelpers.h>
14
15#include <gtest/gtest.h>
16
17using namespace Belle2;
18
19namespace {
20
21 TEST(FastBDTTest, FastBDTOptions)
22 {
23 MVA::FastBDTOptions specific_options;
24
25 EXPECT_EQ(specific_options.m_nTrees, 200);
26 EXPECT_EQ(specific_options.m_nCuts, 8);
27 EXPECT_EQ(specific_options.m_nLevels, 3);
28 EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
29 EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
30 EXPECT_EQ(specific_options.m_sPlot, false);
31 EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
32 EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
33 EXPECT_EQ(specific_options.m_purityTransformation, false);
34 EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
35
36 specific_options.m_nTrees = 100;
37 specific_options.m_nCuts = 10;
38 specific_options.m_nLevels = 2;
39 specific_options.m_shrinkage = 0.2;
40 specific_options.m_randRatio = 0.8;
41 specific_options.m_individual_nCuts = {2, 3, 4};
42 specific_options.m_flatnessLoss = 0.3;
43 specific_options.m_sPlot = true;
44 specific_options.m_purityTransformation = true;
45 specific_options.m_individualPurityTransformation = {true, false, true};
46
47 boost::property_tree::ptree pt;
48 specific_options.save(pt);
49 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nTrees"), 100);
50 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nCuts"), 10);
51 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nLevels"), 2);
52 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_shrinkage"), 0.2);
53 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_randRatio"), 0.8);
54 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individual_nCuts"), 3);
55 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts0"), 2);
56 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts1"), 3);
57 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts2"), 4);
58 EXPECT_EQ(pt.get<bool>("FastBDT_sPlot"), true);
59 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_flatnessLoss"), 0.3);
60 EXPECT_EQ(pt.get<bool>("FastBDT_purityTransformation"), true);
61 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individualPurityTransformation"), 3);
62 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation0"), true);
63 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation1"), false);
64 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation2"), true);
65
66 MVA::FastBDTOptions specific_options2;
67 specific_options2.load(pt);
68
69 EXPECT_EQ(specific_options2.m_nTrees, 100);
70 EXPECT_EQ(specific_options2.m_nCuts, 10);
71 EXPECT_EQ(specific_options2.m_nLevels, 2);
72 EXPECT_FLOAT_EQ(specific_options2.m_shrinkage, 0.2);
73 EXPECT_FLOAT_EQ(specific_options2.m_randRatio, 0.8);
74 EXPECT_EQ(specific_options2.m_sPlot, true);
75 EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
76 EXPECT_EQ(specific_options2.m_purityTransformation, true);
77 EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
78 EXPECT_EQ(specific_options2.m_individualPurityTransformation[0], true);
79 EXPECT_EQ(specific_options2.m_individualPurityTransformation[1], false);
80 EXPECT_EQ(specific_options2.m_individualPurityTransformation[2], true);
81 EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
82 EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
83 EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
84 EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
85
86 EXPECT_EQ(specific_options.getMethod(), std::string("FastBDT"));
87
88 // Test if po::options_description is created without crashing
89 auto description = specific_options.getDescription();
90
91 EXPECT_EQ(description.options().size(), 10);
92
93 // Check for B2ERROR and throw if version is wrong
94 // we try with version 100, surely we will never reach this!
95 pt.put("FastBDT_version", 100);
96 try {
97 EXPECT_B2ERROR(specific_options2.load(pt));
98 } catch (...) {
99
100 }
101 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
102 }
103
104 class TestDataset : public MVA::Dataset {
105 public:
106 explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
107 {
108 m_input = {0.0};
109 m_target = 0.0;
110 m_isSignal = false;
111 m_weight = 1.0;
112 }
113
114 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
115 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
116 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
117 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
118 float getSignalFraction() override { return 0.1; };
119 std::vector<float> getFeature(unsigned int) override { return m_data; }
120
121 std::vector<float> m_data;
122
123 };
124
125
126 TEST(FastBDTTest, FastBDTInterface)
127 {
129
130 MVA::GeneralOptions general_options;
131 general_options.m_variables = {"A"};
132 MVA::FastBDTOptions specific_options;
133 specific_options.m_randRatio = 1.0;
134 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
135
136 auto teacher = interface.getTeacher(general_options, specific_options);
137 auto weightfile = teacher->train(dataset);
138
139 auto expert = interface.getExpert();
140 expert->load(weightfile);
141 auto probabilities = expert->apply(dataset);
142 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
143 for (unsigned int i = 0; i < 4; ++i) {
144 EXPECT_LE(probabilities[i], 0.6);
145 EXPECT_GE(probabilities[i], 0.4);
146 }
147 EXPECT_LE(probabilities[4], 0.2);
148 EXPECT_GE(probabilities[5], 0.8);
149 EXPECT_LE(probabilities[6], 0.2);
150 EXPECT_GE(probabilities[7], 0.8);
151
152 }
153
154 TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
155 {
157
158 MVA::GeneralOptions general_options;
159 general_options.m_variables = {"A"};
160 MVA::FastBDTOptions specific_options;
161 specific_options.m_randRatio = 1.0;
162 specific_options.m_purityTransformation = true;
163 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
164
165 auto teacher = interface.getTeacher(general_options, specific_options);
166 auto weightfile = teacher->train(dataset);
167
168 auto expert = interface.getExpert();
169 expert->load(weightfile);
170 auto probabilities = expert->apply(dataset);
171 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
172 for (unsigned int i = 0; i < 4; ++i) {
173 EXPECT_LE(probabilities[i], 0.6);
174 EXPECT_GE(probabilities[i], 0.4);
175 }
176 EXPECT_LE(probabilities[4], 0.2);
177 EXPECT_GE(probabilities[5], 0.8);
178 EXPECT_LE(probabilities[6], 0.2);
179 EXPECT_GE(probabilities[7], 0.8);
180
181 }
182
183 TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
184 {
186
187 MVA::GeneralOptions general_options;
188 general_options.m_variables = {"M", "p", "pt"};
189 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
190 {1.873689, 1.881940, 1.843310},
191 {1.863657, 1.774831, 1.753773},
192 {1.858293, 1.605311, 0.631336},
193 {1.837129, 1.575739, 1.490166},
194 {1.811395, 1.524029, 0.565220}
195 },
196 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
197
198 auto expert = interface.getExpert();
199
200 auto weightfile_v3 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv3.xml"));
201 expert->load(weightfile_v3);
202 auto probabilities_v3 = expert->apply(dataset);
203 EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
204 EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
205 EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
206 EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
207 EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
208 EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
209
210 auto weightfile_v5 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv5.xml"));
211 expert->load(weightfile_v5);
212 auto probabilities_v5 = expert->apply(dataset);
213 EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
214 EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
215 EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
216 EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
217 EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
218 EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
219
220 // There are small differences due to floating point precision
221 // the FastBDT code is compiled with -O3 and changes slightly
222 // depending on the compiler and random things.
223 // Nevertheless the returned probabilities should be nearly the same!
224 EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
225 EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
226 EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
227 EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
228 EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
229 EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
230 }
231
232}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition Dataset.h:33
Options for the FastBDT MVA method.
Definition FastBDT.h:37
std::vector< unsigned int > m_individual_nCuts
Number of cut Levels = log_2(Number of Cuts) for each provided feature.
Definition FastBDT.h:68
bool m_sPlot
Activates sPlot sampling.
Definition FastBDT.h:70
double m_randRatio
Fraction of data to use in the stochastic training.
Definition FastBDT.h:66
double m_flatnessLoss
Flatness Loss constant.
Definition FastBDT.h:69
double m_shrinkage
Shrinkage during the boosting step.
Definition FastBDT.h:65
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition FastBDT.cc:31
bool m_purityTransformation
Activates purity transformation globally for all features.
Definition FastBDT.h:71
unsigned int m_nLevels
Depth of tree.
Definition FastBDT.h:64
std::vector< bool > m_individualPurityTransformation
Vector which decides, for each feature individually, whether the purity transformation should be used.
Definition FastBDT.h:73
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition FastBDT.h:63
unsigned int m_nTrees
Number of trees.
Definition FastBDT.h:62
General options which are shared by all MVA trainings.
Definition Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition Interface.h:116
Wraps the data of multiple events into a Dataset.
Definition Dataset.h:186
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Abstract base class for different kinds of events.