10 #include <mva/methods/FastBDT.h>
11 #include <mva/interface/Interface.h>
12 #include <mva/interface/Dataset.h>
13 #include <framework/utilities/FileSystem.h>
14 #include <framework/utilities/TestHelpers.h>
16 #include <gtest/gtest.h>
// Test case FastBDTOptions of suite FastBDTTest. It checks the option values present before
// the test assigns anything, a save()/load() round trip through a boost property tree, the
// method name and description size, and that load() rejects an unexpected FastBDT_version.
// NOTE(review): this listing looks like an extraction that dropped lines: the embedded
// numbers jump (e.g. 22 -> 26) and no braces or closing preprocessor directives are visible.
// The declarations of specific_options and specific_options2 are among the missing lines.
22 TEST(FastBDTTest, FastBDTOptions)
// Values expected on specific_options before this test modifies it.
26 EXPECT_EQ(specific_options.m_nTrees, 200);
27 EXPECT_EQ(specific_options.m_nCuts, 8);
28 EXPECT_EQ(specific_options.m_nLevels, 3);
29 EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
30 EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
// Further initial values, checked only when FastBDT_VERSION_MAJOR >= 5.
31 #if FastBDT_VERSION_MAJOR >= 5
32 EXPECT_EQ(specific_options.m_sPlot,
false);
33 EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
34 EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
35 EXPECT_EQ(specific_options.m_purityTransformation,
false);
36 EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
// Overwrite every option with a value different from the ones checked above, so that the
// round trip below cannot pass by accident on untouched values.
39 specific_options.m_nTrees = 100;
40 specific_options.m_nCuts = 10;
41 specific_options.m_nLevels = 2;
42 specific_options.m_shrinkage = 0.2;
43 specific_options.m_randRatio = 0.8;
44 #if FastBDT_VERSION_MAJOR >= 5
45 specific_options.m_individual_nCuts = {2, 3, 4};
46 specific_options.m_flatnessLoss = 0.3;
47 specific_options.m_sPlot =
true;
48 specific_options.m_purityTransformation =
true;
49 specific_options.m_individualPurityTransformation = {
true,
false,
true};
// Write the options into a property tree and check each key that save() stored.
52 boost::property_tree::ptree pt;
53 specific_options.save(pt);
54 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nTrees"), 100);
55 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nCuts"), 10);
56 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nLevels"), 2);
57 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_shrinkage"), 0.2);
58 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_randRatio"), 0.8);
// The vector-valued options are read back as a FastBDT_number_<name> element count plus
// one key per element, named by appending the element index to FastBDT_<name>.
59 #if FastBDT_VERSION_MAJOR >= 5
60 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts"), 3);
61 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts0"), 2);
62 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts1"), 3);
63 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts2"), 4);
64 EXPECT_EQ(pt.get<
bool>(
"FastBDT_sPlot"),
true);
65 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_flatnessLoss"), 0.3);
66 EXPECT_EQ(pt.get<
bool>(
"FastBDT_purityTransformation"),
true);
67 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation"), 3);
68 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation0"),
true);
69 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation1"),
false);
70 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation2"),
true);
// Load the same tree into specific_options2 and check that every value set above survived.
74 specific_options2.
load(pt);
76 EXPECT_EQ(specific_options2.
m_nTrees, 100);
77 EXPECT_EQ(specific_options2.
m_nCuts, 10);
78 EXPECT_EQ(specific_options2.
m_nLevels, 2);
79 EXPECT_FLOAT_EQ(specific_options2.
m_shrinkage, 0.2);
80 EXPECT_FLOAT_EQ(specific_options2.
m_randRatio, 0.8);
81 #if FastBDT_VERSION_MAJOR >= 5
82 EXPECT_EQ(specific_options2.m_sPlot,
true);
83 EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
84 EXPECT_EQ(specific_options2.m_purityTransformation,
true);
85 EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
86 EXPECT_EQ(specific_options2.m_individualPurityTransformation[0],
true);
87 EXPECT_EQ(specific_options2.m_individualPurityTransformation[1],
false);
88 EXPECT_EQ(specific_options2.m_individualPurityTransformation[2],
true);
89 EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
90 EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
91 EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
92 EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
// getMethod() is expected to return the string "FastBDT".
95 EXPECT_EQ(specific_options.getMethod(), std::string(
"FastBDT"));
// The description is expected to list 10 options when FastBDT_VERSION_MAJOR >= 5.
// NOTE(review): the check for 5 presumably sits in an alternative preprocessor branch whose
// directive is not visible in this listing - confirm against the full file.
98 auto description = specific_options.getDescription();
101 #if FastBDT_VERSION_MAJOR >= 5
102 EXPECT_EQ(description.options().size(), 10);
104 EXPECT_EQ(description.options().size(), 5);
// Store 100 under FastBDT_version and expect load() to refuse the tree: the last check
// requires load() to throw std::runtime_error.
// NOTE(review): EXPECT_B2ERROR presumably asserts that a B2ERROR is emitted - confirm in
// framework/utilities/TestHelpers.h.
109 pt.put(
"FastBDT_version", 100);
111 EXPECT_B2ERROR(specific_options2.
load(pt));
115 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Members of TestDataset, the single-feature dataset the interface tests in this file
// construct from a list of floats.
// NOTE(review): the class header and the constructor body are not visible in this listing
// (the embedded numbers jump from 120 to 128).
// Constructor: passes default-constructed MVA::GeneralOptions to MVA::Dataset and copies
// data into m_data.
120 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
// Always reports exactly one feature.
128 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
// Always reports zero spectators.
129 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
// One event per element of m_data.
130 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
// Copies the iEvent-th value into m_input[0]. The target alternates 0, 1, 0, 1, ... with the
// event index, and events with target 1 (odd indices) are flagged as signal.
131 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
// Returns the constant 0.1; the value is not computed from m_data.
132 float getSignalFraction()
override {
return 0.1; };
// Ignores the feature index argument and returns a copy of the whole data vector.
133 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
// Value of the single feature for each event, indexed by event number.
135 std::vector<float> m_data;
// Test case FastBDTInterface of suite FastBDTTest: trains on a small TestDataset, loads the
// resulting weightfile into an expert and checks the probabilities it returns.
// NOTE(review): the declarations of interface, general_options, specific_options and expert,
// and the closing braces, are not visible in this listing.
140 TEST(FastBDTTest, FastBDTInterface)
// One variable named "A", matching the single feature TestDataset reports.
145 general_options.m_variables = {
"A"};
147 specific_options.m_randRatio = 1.0;
// Events 0-3 all carry 1.0; events 4 and 6 carry 2.0 and events 5 and 7 carry 3.0.
// TestDataset::loadEvent sets the target to iEvent % 2, so 2.0 pairs with target 0 and
// 3.0 with target 1, while the value 1.0 occurs equally often with both targets.
148 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
150 auto teacher = interface.
getTeacher(general_options, specific_options);
151 auto weightfile = teacher->train(dataset);
154 expert->load(weightfile);
155 auto probabilities = expert->apply(dataset);
// One probability per event is expected.
156 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
// The four indistinguishable events must score between 0.4 and 0.6.
157 for (
unsigned int i = 0; i < 4; ++i) {
158 EXPECT_LE(probabilities[i], 0.6);
159 EXPECT_GE(probabilities[i], 0.4);
// The separable events must score at most 0.2 (value 2.0) or at least 0.8 (value 3.0).
161 EXPECT_LE(probabilities[4], 0.2);
162 EXPECT_GE(probabilities[5], 0.8);
163 EXPECT_LE(probabilities[6], 0.2);
164 EXPECT_GE(probabilities[7], 0.8);
// Test case FastBDTInterfaceWithPurityTransformation of suite FastBDTTest, compiled only
// when FastBDT_VERSION_MAJOR >= 5. It runs the same train / load / apply sequence and the
// same probability checks as FastBDTInterface, with m_purityTransformation set to true.
// NOTE(review): the declarations of interface, general_options, specific_options and expert,
// the closing braces and the end of the preprocessor guard are not visible in this listing.
168 #if FastBDT_VERSION_MAJOR >= 5
169 TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
// One variable named "A", matching the single feature TestDataset reports.
174 general_options.m_variables = {
"A"};
176 specific_options.m_randRatio = 1.0;
// The only option that differs from the FastBDTInterface test.
177 specific_options.m_purityTransformation =
true;
// Events 0-3 all carry 1.0; events 4 and 6 carry 2.0 and events 5 and 7 carry 3.0.
// TestDataset::loadEvent sets the target to iEvent % 2, so 2.0 pairs with target 0 and
// 3.0 with target 1, while the value 1.0 occurs equally often with both targets.
178 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
180 auto teacher = interface.
getTeacher(general_options, specific_options);
181 auto weightfile = teacher->train(dataset);
184 expert->load(weightfile);
185 auto probabilities = expert->apply(dataset);
// One probability per event is expected.
186 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
// The four indistinguishable events must score between 0.4 and 0.6.
187 for (
unsigned int i = 0; i < 4; ++i) {
188 EXPECT_LE(probabilities[i], 0.6);
189 EXPECT_GE(probabilities[i], 0.4);
// The separable events must score at most 0.2 (value 2.0) or at least 0.8 (value 3.0).
191 EXPECT_LE(probabilities[4], 0.2);
192 EXPECT_GE(probabilities[5], 0.8);
193 EXPECT_LE(probabilities[6], 0.2);
194 EXPECT_GE(probabilities[7], 0.8);
// Test case WeightfilesOfDifferentVersionsAreConsistent of suite FastBDTTest: applies two
// stored weightfiles (weightfile_v3 and weightfile_v5) to the same dataset, checks each
// result against reference probabilities, and checks that the two results agree.
// NOTE(review): the declarations of general_options, dataset, expert, weightfile_v3 and
// weightfile_v5, and the closing braces and preprocessor directives, are not visible here.
199 TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
// Three input variables; each feature row below has three values.
204 general_options.m_variables = {
"M",
"p",
"pt"};
// Feature rows passed to the dataset, followed by an empty list and six values alternating
// 0.0 and 1.0. NOTE(review): only five rows are visible although six probabilities are
// checked below; the first row presumably sits on a line missing from this listing. The
// six alternating values are presumably the per-event targets - confirm against the full file.
206 {1.873689, 1.881940, 1.843310},
207 {1.863657, 1.774831, 1.753773},
208 {1.858293, 1.605311, 0.631336},
209 {1.837129, 1.575739, 1.490166},
210 {1.811395, 1.524029, 0.565220}
212 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
// Reference probabilities for weightfile_v3, each matched to within 0.0001.
216 #if FastBDT_VERSION_MAJOR >= 3
218 expert->load(weightfile_v3);
219 auto probabilities_v3 = expert->apply(dataset);
220 EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
221 EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
222 EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
223 EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
224 EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
225 EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
// Reference probabilities for weightfile_v5, each matched to within 0.0001.
228 #if FastBDT_VERSION_MAJOR >= 5
230 expert->load(weightfile_v5);
231 auto probabilities_v5 = expert->apply(dataset);
232 EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
233 EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
234 EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
235 EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
236 EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
237 EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
// Cross-check: both weightfiles must give the same probability per event to within 0.001.
// NOTE(review): these lines use both probabilities_v3 and probabilities_v5, so they
// presumably sit inside the >= 5 guard - confirm against the full file.
243 EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
244 EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
245 EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
246 EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
247 EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
248 EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);