Belle II Software light-2405-quaxo
test_FastBDT.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/FastBDT.h>
10#include <mva/interface/Interface.h>
11#include <mva/interface/Dataset.h>
12#include <framework/utilities/FileSystem.h>
13#include <framework/utilities/TestHelpers.h>
14
15#include <gtest/gtest.h>
16
17using namespace Belle2;
18
19namespace {
20
21 TEST(FastBDTTest, FastBDTOptions)
22 {
23 MVA::FastBDTOptions specific_options;
24
25 EXPECT_EQ(specific_options.m_nTrees, 200);
26 EXPECT_EQ(specific_options.m_nCuts, 8);
27 EXPECT_EQ(specific_options.m_nLevels, 3);
28 EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
29 EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
30#if FastBDT_VERSION_MAJOR >= 5
31 EXPECT_EQ(specific_options.m_sPlot, false);
32 EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
33 EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
34 EXPECT_EQ(specific_options.m_purityTransformation, false);
35 EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
36#endif
37
38 specific_options.m_nTrees = 100;
39 specific_options.m_nCuts = 10;
40 specific_options.m_nLevels = 2;
41 specific_options.m_shrinkage = 0.2;
42 specific_options.m_randRatio = 0.8;
43#if FastBDT_VERSION_MAJOR >= 5
44 specific_options.m_individual_nCuts = {2, 3, 4};
45 specific_options.m_flatnessLoss = 0.3;
46 specific_options.m_sPlot = true;
47 specific_options.m_purityTransformation = true;
48 specific_options.m_individualPurityTransformation = {true, false, true};
49#endif
50
51 boost::property_tree::ptree pt;
52 specific_options.save(pt);
53 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nTrees"), 100);
54 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nCuts"), 10);
55 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nLevels"), 2);
56 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_shrinkage"), 0.2);
57 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_randRatio"), 0.8);
58#if FastBDT_VERSION_MAJOR >= 5
59 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individual_nCuts"), 3);
60 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts0"), 2);
61 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts1"), 3);
62 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts2"), 4);
63 EXPECT_EQ(pt.get<bool>("FastBDT_sPlot"), true);
64 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_flatnessLoss"), 0.3);
65 EXPECT_EQ(pt.get<bool>("FastBDT_purityTransformation"), true);
66 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individualPurityTransformation"), 3);
67 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation0"), true);
68 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation1"), false);
69 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation2"), true);
70#endif
71
72 MVA::FastBDTOptions specific_options2;
73 specific_options2.load(pt);
74
75 EXPECT_EQ(specific_options2.m_nTrees, 100);
76 EXPECT_EQ(specific_options2.m_nCuts, 10);
77 EXPECT_EQ(specific_options2.m_nLevels, 2);
78 EXPECT_FLOAT_EQ(specific_options2.m_shrinkage, 0.2);
79 EXPECT_FLOAT_EQ(specific_options2.m_randRatio, 0.8);
80#if FastBDT_VERSION_MAJOR >= 5
81 EXPECT_EQ(specific_options2.m_sPlot, true);
82 EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
83 EXPECT_EQ(specific_options2.m_purityTransformation, true);
84 EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
85 EXPECT_EQ(specific_options2.m_individualPurityTransformation[0], true);
86 EXPECT_EQ(specific_options2.m_individualPurityTransformation[1], false);
87 EXPECT_EQ(specific_options2.m_individualPurityTransformation[2], true);
88 EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
89 EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
90 EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
91 EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
92#endif
93
94 EXPECT_EQ(specific_options.getMethod(), std::string("FastBDT"));
95
96 // Test if po::options_description is created without crashing
97 auto description = specific_options.getDescription();
98
99 // flatnessLoss, sPlot and individualNCuts are only activated in FastBDT version 4
100#if FastBDT_VERSION_MAJOR >= 5
101 EXPECT_EQ(description.options().size(), 10);
102#else
103 EXPECT_EQ(description.options().size(), 5);
104#endif
105
106 // Check for B2ERROR and throw if version is wrong
107 // we try with version 100, surely we will never reach this!
108 pt.put("FastBDT_version", 100);
109 try {
110 EXPECT_B2ERROR(specific_options2.load(pt));
111 } catch (...) {
112
113 }
114 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
115 }
116
117 class TestDataset : public MVA::Dataset {
118 public:
119 explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
120 {
121 m_input = {0.0};
122 m_target = 0.0;
123 m_isSignal = false;
124 m_weight = 1.0;
125 }
126
127 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
128 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
129 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
130 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
131 float getSignalFraction() override { return 0.1; };
132 std::vector<float> getFeature(unsigned int) override { return m_data; }
133
134 std::vector<float> m_data;
135
136 };
137
138
139 TEST(FastBDTTest, FastBDTInterface)
140 {
142
143 MVA::GeneralOptions general_options;
144 general_options.m_variables = {"A"};
145 MVA::FastBDTOptions specific_options;
146 specific_options.m_randRatio = 1.0;
147 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
148
149 auto teacher = interface.getTeacher(general_options, specific_options);
150 auto weightfile = teacher->train(dataset);
151
152 auto expert = interface.getExpert();
153 expert->load(weightfile);
154 auto probabilities = expert->apply(dataset);
155 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
156 for (unsigned int i = 0; i < 4; ++i) {
157 EXPECT_LE(probabilities[i], 0.6);
158 EXPECT_GE(probabilities[i], 0.4);
159 }
160 EXPECT_LE(probabilities[4], 0.2);
161 EXPECT_GE(probabilities[5], 0.8);
162 EXPECT_LE(probabilities[6], 0.2);
163 EXPECT_GE(probabilities[7], 0.8);
164
165 }
166
#if FastBDT_VERSION_MAJOR >= 5
/**
 * Same scenario as FastBDTInterface, but with the purity transformation enabled
 * (an option only available from FastBDT version 5 on). The trained classifier must give
 * the same qualitative separation.
 */
TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
{
  MVA::Interface<MVA::FastBDTOptions, MVA::FastBDTTeacher, MVA::FastBDTExpert> interface;

  MVA::GeneralOptions general_options;
  general_options.m_variables = {"A"};
  MVA::FastBDTOptions specific_options;
  // Use the full sample in every stochastic-training step instead of a random fraction.
  specific_options.m_randRatio = 1.0;
  specific_options.m_purityTransformation = true;
  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});

  auto teacher = interface.getTeacher(general_options, specific_options);
  auto weightfile = teacher->train(dataset);

  auto expert = interface.getExpert();
  expert->load(weightfile);
  auto probabilities = expert->apply(dataset);
  // ASSERT (not EXPECT) so the fixed indices below can never read past the end.
  ASSERT_EQ(probabilities.size(), dataset.getNumberOfEvents());

  // Indistinguishable events: no preference either way.
  for (unsigned int i = 0; i < 4; ++i) {
    EXPECT_LE(probabilities[i], 0.6);
    EXPECT_GE(probabilities[i], 0.4);
  }
  // Separable events: background low, signal high.
  EXPECT_LE(probabilities[4], 0.2);
  EXPECT_GE(probabilities[5], 0.8);
  EXPECT_LE(probabilities[6], 0.2);
  EXPECT_GE(probabilities[7], 0.8);
}
#endif
197
198 TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
199 {
201
202 MVA::GeneralOptions general_options;
203 general_options.m_variables = {"M", "p", "pt"};
204 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
205 {1.873689, 1.881940, 1.843310},
206 {1.863657, 1.774831, 1.753773},
207 {1.858293, 1.605311, 0.631336},
208 {1.837129, 1.575739, 1.490166},
209 {1.811395, 1.524029, 0.565220}
210 },
211 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
212
213 // cppcheck-suppress unreadVariable
214 auto expert = interface.getExpert();
215
216#if FastBDT_VERSION_MAJOR >= 3
217 auto weightfile_v3 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv3.xml"));
218 expert->load(weightfile_v3);
219 auto probabilities_v3 = expert->apply(dataset);
220 EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
221 EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
222 EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
223 EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
224 EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
225 EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
226#endif
227
228#if FastBDT_VERSION_MAJOR >= 5
229 auto weightfile_v5 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv5.xml"));
230 expert->load(weightfile_v5);
231 auto probabilities_v5 = expert->apply(dataset);
232 EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
233 EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
234 EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
235 EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
236 EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
237 EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
238
239 // There are small differences due to floating point precision
240 // the FastBDT code is compiled with -O3 and changes slightly
241 // depending on the compiler and random things.
242 // Nevertheless the returned probabilities should be nearly the same!
243 EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
244 EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
245 EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
246 EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
247 EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
248 EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
249#endif
250 }
251
252}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
Options for the FastBDT MVA method.
Definition: FastBDT.h:53
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:82
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:81
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FastBDT.cc:53
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:80
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:79
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:78
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24