Belle II Software development
test_FastBDT.cc
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/

#include <mva/methods/FastBDT.h>
#include <mva/interface/Interface.h>
#include <mva/interface/Dataset.h>
#include <framework/utilities/FileSystem.h>
#include <framework/utilities/TestHelpers.h>

#include <gtest/gtest.h>

#include <stdexcept>
#include <string>
#include <vector>

using namespace Belle2;

19namespace {
20
21 TEST(FastBDTTest, FastBDTOptions)
22 {
23 MVA::FastBDTOptions specific_options;
24
25 EXPECT_EQ(specific_options.m_nTrees, 200);
26 EXPECT_EQ(specific_options.m_nCuts, 8);
27 EXPECT_EQ(specific_options.m_nLevels, 3);
28 EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
29 EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
30 EXPECT_EQ(specific_options.m_sPlot, false);
31 EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
32 EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
33 EXPECT_EQ(specific_options.m_purityTransformation, false);
34 EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
35
36 specific_options.m_nTrees = 100;
37 specific_options.m_nCuts = 10;
38 specific_options.m_nLevels = 2;
39 specific_options.m_shrinkage = 0.2;
40 specific_options.m_randRatio = 0.8;
41 specific_options.m_individual_nCuts = {2, 3, 4};
42 specific_options.m_flatnessLoss = 0.3;
43 specific_options.m_sPlot = true;
44 specific_options.m_purityTransformation = true;
45 specific_options.m_individualPurityTransformation = {true, false, true};
46
47 boost::property_tree::ptree pt;
48 specific_options.save(pt);
49 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nTrees"), 100);
50 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nCuts"), 10);
51 EXPECT_EQ(pt.get<unsigned int>("FastBDT_nLevels"), 2);
52 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_shrinkage"), 0.2);
53 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_randRatio"), 0.8);
54 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individual_nCuts"), 3);
55 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts0"), 2);
56 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts1"), 3);
57 EXPECT_EQ(pt.get<unsigned int>("FastBDT_individual_nCuts2"), 4);
58 EXPECT_EQ(pt.get<bool>("FastBDT_sPlot"), true);
59 EXPECT_FLOAT_EQ(pt.get<double>("FastBDT_flatnessLoss"), 0.3);
60 EXPECT_EQ(pt.get<bool>("FastBDT_purityTransformation"), true);
61 EXPECT_EQ(pt.get<unsigned int>("FastBDT_number_individualPurityTransformation"), 3);
62 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation0"), true);
63 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation1"), false);
64 EXPECT_EQ(pt.get<bool>("FastBDT_individualPurityTransformation2"), true);
65
66 MVA::FastBDTOptions specific_options2;
67 specific_options2.load(pt);
68
69 EXPECT_EQ(specific_options2.m_nTrees, 100);
70 EXPECT_EQ(specific_options2.m_nCuts, 10);
71 EXPECT_EQ(specific_options2.m_nLevels, 2);
72 EXPECT_FLOAT_EQ(specific_options2.m_shrinkage, 0.2);
73 EXPECT_FLOAT_EQ(specific_options2.m_randRatio, 0.8);
74 EXPECT_EQ(specific_options2.m_sPlot, true);
75 EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
76 EXPECT_EQ(specific_options2.m_purityTransformation, true);
77 EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
78 EXPECT_EQ(specific_options2.m_individualPurityTransformation[0], true);
79 EXPECT_EQ(specific_options2.m_individualPurityTransformation[1], false);
80 EXPECT_EQ(specific_options2.m_individualPurityTransformation[2], true);
81 EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
82 EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
83 EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
84 EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
85
86 EXPECT_EQ(specific_options.getMethod(), std::string("FastBDT"));
87
88 // Test if po::options_description is created without crashing
89 auto description = specific_options.getDescription();
90
91 EXPECT_EQ(description.options().size(), 10);
92
93 // Check for B2ERROR and throw if version is wrong
94 // we try with version 100, surely we will never reach this!
95 pt.put("FastBDT_version", 100);
96 try {
97 EXPECT_B2ERROR(specific_options2.load(pt));
98 } catch (...) {
99
100 }
101 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
102 }
103
104 class TestDataset : public MVA::Dataset {
105 public:
106 explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
107 {
108 m_input = {0.0};
109 m_target = 0.0;
110 m_isSignal = false;
111 m_weight = 1.0;
112 }
113
114 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
115 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
116 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
117 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
118 float getSignalFraction() override { return 0.1; };
119 std::vector<float> getFeature(unsigned int) override { return m_data; }
120
121 std::vector<float> m_data;
122
123 };
124
125
126 TEST(FastBDTTest, FastBDTInterface)
127 {
129
130 MVA::GeneralOptions general_options;
131 general_options.m_variables = {"A"};
132 MVA::FastBDTOptions specific_options;
133 specific_options.m_randRatio = 1.0;
134 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
135
136 auto teacher = interface.getTeacher(general_options, specific_options);
137 auto weightfile = teacher->train(dataset);
138
139 auto expert = interface.getExpert();
140 expert->load(weightfile);
141 auto probabilities = expert->apply(dataset);
142 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
143 for (unsigned int i = 0; i < 4; ++i) {
144 EXPECT_LE(probabilities[i], 0.6);
145 EXPECT_GE(probabilities[i], 0.4);
146 }
147 EXPECT_LE(probabilities[4], 0.2);
148 EXPECT_GE(probabilities[5], 0.8);
149 EXPECT_LE(probabilities[6], 0.2);
150 EXPECT_GE(probabilities[7], 0.8);
151
152 }
153
154 TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
155 {
157
158 MVA::GeneralOptions general_options;
159 general_options.m_variables = {"A"};
160 MVA::FastBDTOptions specific_options;
161 specific_options.m_randRatio = 1.0;
162 specific_options.m_purityTransformation = true;
163 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
164
165 auto teacher = interface.getTeacher(general_options, specific_options);
166 auto weightfile = teacher->train(dataset);
167
168 auto expert = interface.getExpert();
169 expert->load(weightfile);
170 auto probabilities = expert->apply(dataset);
171 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
172 for (unsigned int i = 0; i < 4; ++i) {
173 EXPECT_LE(probabilities[i], 0.6);
174 EXPECT_GE(probabilities[i], 0.4);
175 }
176 EXPECT_LE(probabilities[4], 0.2);
177 EXPECT_GE(probabilities[5], 0.8);
178 EXPECT_LE(probabilities[6], 0.2);
179 EXPECT_GE(probabilities[7], 0.8);
180
181 }
182
183 TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
184 {
186
187 MVA::GeneralOptions general_options;
188 general_options.m_variables = {"M", "p", "pt"};
189 MVA::MultiDataset dataset(general_options, {{1.835127, 1.179507, 1.164944},
190 {1.873689, 1.881940, 1.843310},
191 {1.863657, 1.774831, 1.753773},
192 {1.858293, 1.605311, 0.631336},
193 {1.837129, 1.575739, 1.490166},
194 {1.811395, 1.524029, 0.565220}
195 },
196 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
197
198 auto expert = interface.getExpert();
199
200 auto weightfile_v3 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv3.xml"));
201 expert->load(weightfile_v3);
202 auto probabilities_v3 = expert->apply(dataset);
203 EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
204 EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
205 EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
206 EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
207 EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
208 EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
209
210 auto weightfile_v5 = MVA::Weightfile::loadFromFile(FileSystem::findFile("mva/methods/tests/FastBDTv5.xml"));
211 expert->load(weightfile_v5);
212 auto probabilities_v5 = expert->apply(dataset);
213 EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
214 EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
215 EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
216 EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
217 EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
218 EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
219
220 // There are small differences due to floating point precision
221 // the FastBDT code is compiled with -O3 and changes slightly
222 // depending on the compiler and random things.
223 // Nevertheless the returned probabilities should be nearly the same!
224 EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
225 EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
226 EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
227 EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
228 EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
229 EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
230 }
231
232}
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Definition: FileSystem.cc:151
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
Options for the FastBDT MVA method.
Definition: FastBDT.h:37
std::vector< unsigned int > m_individual_nCuts
Number of cut Levels = log_2(Number of Cuts) for each provided feature.
Definition: FastBDT.h:68
bool m_sPlot
Activates sPlot sampling.
Definition: FastBDT.h:70
double m_randRatio
Fraction of data to use in the stochastic training.
Definition: FastBDT.h:66
double m_flatnessLoss
Flatness Loss constant.
Definition: FastBDT.h:69
double m_shrinkage
Shrinkage during the boosting step.
Definition: FastBDT.h:65
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FastBDT.cc:31
bool m_purityTransformation
Activates purity transformation globally for all features.
Definition: FastBDT.h:71
unsigned int m_nLevels
Depth of tree.
Definition: FastBDT.h:64
std::vector< bool > m_individualPurityTransformation
Vector which decides for each feature individually whether the purity transformation should be used.
Definition: FastBDT.h:73
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
Definition: FastBDT.h:63
unsigned int m_nTrees
Number of trees.
Definition: FastBDT.h:62
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Wraps the data of multiple events into a Dataset.
Definition: Dataset.h:186
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
Abstract base class for different kinds of events.