9 #include <mva/methods/FastBDT.h>
10 #include <mva/interface/Interface.h>
11 #include <mva/interface/Dataset.h>
12 #include <framework/utilities/FileSystem.h>
13 #include <framework/utilities/TestHelpers.h>
15 #include <gtest/gtest.h>
// Exercises the FastBDT options object end to end: checks the initial values,
// overwrites them, saves them into a boost property tree, loads them back into
// a second options object and finally probes the version-mismatch handling.
// NOTE(review): this listing omits several original lines (braces, the
// declarations of specific_options / specific_options2, #else / #endif) --
// confirm against the complete file before relying on these notes.
21 TEST(FastBDTTest, FastBDTOptions)
// Values expected before anything is assigned (presumably the defaults of a
// freshly constructed options object -- its declaration is not visible here).
25 EXPECT_EQ(specific_options.m_nTrees, 200);
26 EXPECT_EQ(specific_options.m_nCuts, 8);
27 EXPECT_EQ(specific_options.m_nLevels, 3);
28 EXPECT_FLOAT_EQ(specific_options.m_shrinkage, 0.1);
29 EXPECT_FLOAT_EQ(specific_options.m_randRatio, 0.5);
// The following members are only checked when FastBDT_VERSION_MAJOR >= 5.
30 #if FastBDT_VERSION_MAJOR >= 5
31 EXPECT_EQ(specific_options.m_sPlot,
false);
32 EXPECT_EQ(specific_options.m_individual_nCuts.size(), 0);
33 EXPECT_EQ(specific_options.m_individualPurityTransformation.size(), 0);
34 EXPECT_EQ(specific_options.m_purityTransformation,
false);
35 EXPECT_FLOAT_EQ(specific_options.m_flatnessLoss, -1.0);
// Overwrite every option with a value that differs from the ones asserted
// above, so the save/load round trip below cannot pass by accident.
38 specific_options.m_nTrees = 100;
39 specific_options.m_nCuts = 10;
40 specific_options.m_nLevels = 2;
41 specific_options.m_shrinkage = 0.2;
42 specific_options.m_randRatio = 0.8;
43 #if FastBDT_VERSION_MAJOR >= 5
44 specific_options.m_individual_nCuts = {2, 3, 4};
45 specific_options.m_flatnessLoss = 0.3;
46 specific_options.m_sPlot =
true;
47 specific_options.m_purityTransformation =
true;
48 specific_options.m_individualPurityTransformation = {
true,
false,
true};
// Serialise the modified options into a property tree and verify each key.
51 boost::property_tree::ptree pt;
52 specific_options.save(pt);
53 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nTrees"), 100);
54 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nCuts"), 10);
55 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_nLevels"), 2);
56 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_shrinkage"), 0.2);
57 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_randRatio"), 0.8);
// Vector-valued options are stored as a "number_..." count plus one key per
// element with the element index appended to the key name.
58 #if FastBDT_VERSION_MAJOR >= 5
59 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_number_individual_nCuts"), 3);
60 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts0"), 2);
61 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts1"), 3);
62 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_individual_nCuts2"), 4);
63 EXPECT_EQ(pt.get<
bool>(
"FastBDT_sPlot"),
true);
64 EXPECT_FLOAT_EQ(pt.get<
double>(
"FastBDT_flatnessLoss"), 0.3);
65 EXPECT_EQ(pt.get<
bool>(
"FastBDT_purityTransformation"),
true);
66 EXPECT_EQ(pt.get<
unsigned int>(
"FastBDT_number_individualPurityTransformation"), 3);
67 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation0"),
true);
68 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation1"),
false);
69 EXPECT_EQ(pt.get<
bool>(
"FastBDT_individualPurityTransformation2"),
true);
// Load the tree into a second options object (declaration not visible in
// this listing) and check that it now carries the values assigned above.
73 specific_options2.
load(pt);
75 EXPECT_EQ(specific_options2.
m_nTrees, 100);
76 EXPECT_EQ(specific_options2.
m_nCuts, 10);
77 EXPECT_EQ(specific_options2.
m_nLevels, 2);
78 EXPECT_FLOAT_EQ(specific_options2.
m_shrinkage, 0.2);
79 EXPECT_FLOAT_EQ(specific_options2.
m_randRatio, 0.8);
80 #if FastBDT_VERSION_MAJOR >= 5
81 EXPECT_EQ(specific_options2.m_sPlot,
true);
82 EXPECT_FLOAT_EQ(specific_options2.m_flatnessLoss, 0.3);
83 EXPECT_EQ(specific_options2.m_purityTransformation,
true);
84 EXPECT_EQ(specific_options2.m_individualPurityTransformation.size(), 3);
85 EXPECT_EQ(specific_options2.m_individualPurityTransformation[0],
true);
86 EXPECT_EQ(specific_options2.m_individualPurityTransformation[1],
false);
87 EXPECT_EQ(specific_options2.m_individualPurityTransformation[2],
true);
88 EXPECT_EQ(specific_options2.m_individual_nCuts.size(), 3);
89 EXPECT_EQ(specific_options2.m_individual_nCuts[0], 2);
90 EXPECT_EQ(specific_options2.m_individual_nCuts[1], 3);
91 EXPECT_EQ(specific_options2.m_individual_nCuts[2], 4);
// The method name reported by the options must be "FastBDT".
94 EXPECT_EQ(specific_options.getMethod(), std::string(
"FastBDT"));
// The option description is expected to list 10 options when
// FastBDT_VERSION_MAJOR >= 5; the second count (5) presumably sits in an
// #else branch that this listing does not show -- TODO confirm.
97 auto description = specific_options.getDescription();
100 #if FastBDT_VERSION_MAJOR >= 5
101 EXPECT_EQ(description.options().size(), 10);
103 EXPECT_EQ(description.options().size(), 5);
// Store a different version number in the tree: loading it is then expected
// to be flagged by EXPECT_B2ERROR (presumably provided by TestHelpers.h) and
// to throw std::runtime_error.
108 pt.put(
"FastBDT_version", 100);
110 EXPECT_B2ERROR(specific_options2.
load(pt));
114 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Members of the TestDataset helper used by the interface tests below.
// NOTE(review): the class header and the constructor body are not visible in
// this listing; the overrides suggest it derives from MVA::Dataset -- confirm.
//
// Constructor: initialises the MVA::Dataset base with default-constructed
// MVA::GeneralOptions and copies `data` into m_data.
119 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
// The dataset always reports exactly one feature.
127 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
// The dataset reports no spectators.
128 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
// One event per entry of m_data.
129 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
// Loads event iEvent: the single input is m_data[iEvent], the target
// alternates 0, 1, 0, 1, ... with the event index, and an event counts as
// signal exactly when its target is 1. iEvent is not range-checked here.
130 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
// Returns a fixed value of 0.1, independent of the targets set by loadEvent.
131 float getSignalFraction()
override {
return 0.1; };
// Returns a copy of all stored values; the (unnamed) feature index argument
// is ignored, which matches the single-feature layout above.
132 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
// Feature value of every event, as passed to the constructor.
134 std::vector<float> m_data;
// Trains a FastBDT on a tiny one-feature dataset through the MVA interface,
// applies the resulting weightfile and checks the returned probabilities.
// NOTE(review): the declarations of interface, general_options,
// specific_options and expert are not visible in this listing.
139 TEST(FastBDTTest, FastBDTInterface)
144 general_options.m_variables = {
"A"};
// m_randRatio is set to 1.0 (per its description: the fraction of data used
// in the stochastic training).
146 specific_options.m_randRatio = 1.0;
// Eight events. TestDataset::loadEvent assigns target = index % 2, so the
// first four events share the feature value 1.0 with alternating targets,
// 2.0 only occurs with target 0 (indices 4, 6) and 3.0 only with target 1
// (indices 5, 7).
147 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
149 auto teacher = interface.
getTeacher(general_options, specific_options);
150 auto weightfile = teacher->train(dataset);
153 expert->load(weightfile);
154 auto probabilities = expert->apply(dataset);
// One probability per event.
155 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
// Events with the ambiguous feature value 1.0 must stay within [0.4, 0.6].
156 for (
unsigned int i = 0; i < 4; ++i) {
157 EXPECT_LE(probabilities[i], 0.6);
158 EXPECT_GE(probabilities[i], 0.4);
// Separable events: at most 0.2 for feature 2.0, at least 0.8 for feature 3.0.
160 EXPECT_LE(probabilities[4], 0.2);
161 EXPECT_GE(probabilities[5], 0.8);
162 EXPECT_LE(probabilities[6], 0.2);
163 EXPECT_GE(probabilities[7], 0.8);
// Same training/application check as FastBDTInterface, but with
// m_purityTransformation switched on. Only compiled when
// FastBDT_VERSION_MAJOR >= 5.
// NOTE(review): the matching #endif and the declarations of interface,
// general_options, specific_options and expert are not visible in this listing.
167 #if FastBDT_VERSION_MAJOR >= 5
168 TEST(FastBDTTest, FastBDTInterfaceWithPurityTransformation)
173 general_options.m_variables = {
"A"};
175 specific_options.m_randRatio = 1.0;
176 specific_options.m_purityTransformation =
true;
// Same eight events as in FastBDTInterface: indices 0-3 share feature 1.0
// with alternating targets, 2.0 pairs with target 0 and 3.0 with target 1
// (targets come from TestDataset::loadEvent, index % 2).
177 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
179 auto teacher = interface.
getTeacher(general_options, specific_options);
180 auto weightfile = teacher->train(dataset);
183 expert->load(weightfile);
184 auto probabilities = expert->apply(dataset);
// One probability per event.
185 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
// The ambiguous events must stay within [0.4, 0.6].
186 for (
unsigned int i = 0; i < 4; ++i) {
187 EXPECT_LE(probabilities[i], 0.6);
188 EXPECT_GE(probabilities[i], 0.4);
// Separable events: at most 0.2 for feature 2.0, at least 0.8 for feature 3.0.
190 EXPECT_LE(probabilities[4], 0.2);
191 EXPECT_GE(probabilities[5], 0.8);
192 EXPECT_LE(probabilities[6], 0.2);
193 EXPECT_GE(probabilities[7], 0.8);
// Applies stored weightfiles (referred to as v3 and v5) to the same small
// dataset and checks the probabilities against reference numbers and against
// each other.
// NOTE(review): the dataset constructor call, the weightfile loading lines,
// the expert declaration and the #endif lines are not visible in this listing.
198 TEST(FastBDTTest, WeightfilesOfDifferentVersionsAreConsistent)
203 general_options.m_variables = {
"M",
"p",
"pt"};
// Feature rows with three values each, matching the three variables above.
// Five rows are visible while six values follow in the last list, so the
// first row presumably sits on an omitted line -- TODO confirm. The trailing
// `{}` and the six 0/1 values are presumably the (empty) spectators and the
// per-event targets; verify against the dataset constructor.
205 {1.873689, 1.881940, 1.843310},
206 {1.863657, 1.774831, 1.753773},
207 {1.858293, 1.605311, 0.631336},
208 {1.837129, 1.575739, 1.490166},
209 {1.811395, 1.524029, 0.565220}
211 {}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0});
// v3 weightfile: each probability must match its reference within 1e-4.
216 #if FastBDT_VERSION_MAJOR >= 3
218 expert->load(weightfile_v3);
219 auto probabilities_v3 = expert->apply(dataset);
220 EXPECT_NEAR(probabilities_v3[0], 0.0402499, 0.0001);
221 EXPECT_NEAR(probabilities_v3[1], 0.2189, 0.0001);
222 EXPECT_NEAR(probabilities_v3[2], 0.264094, 0.0001);
223 EXPECT_NEAR(probabilities_v3[3], 0.100049, 0.0001);
224 EXPECT_NEAR(probabilities_v3[4], 0.0664554, 0.0001);
225 EXPECT_NEAR(probabilities_v3[5], 0.00886221, 0.0001);
// v5 weightfile: same check with its own reference numbers.
228 #if FastBDT_VERSION_MAJOR >= 5
230 expert->load(weightfile_v5);
231 auto probabilities_v5 = expert->apply(dataset);
232 EXPECT_NEAR(probabilities_v5[0], 0.0402498, 0.0001);
233 EXPECT_NEAR(probabilities_v5[1], 0.218899, 0.0001);
234 EXPECT_NEAR(probabilities_v5[2], 0.264093, 0.0001);
235 EXPECT_NEAR(probabilities_v5[3], 0.100048, 0.0001);
236 EXPECT_NEAR(probabilities_v5[4], 0.0664551, 0.0001);
237 EXPECT_NEAR(probabilities_v5[5], 0.00886217, 0.0001);
// Cross-check: v5 and v3 results must agree within 1e-3 for every event.
243 EXPECT_NEAR(probabilities_v5[0], probabilities_v3[0], 0.001);
244 EXPECT_NEAR(probabilities_v5[1], probabilities_v3[1], 0.001);
245 EXPECT_NEAR(probabilities_v5[2], probabilities_v3[2], 0.001);
246 EXPECT_NEAR(probabilities_v5[3], probabilities_v3[3], 0.001);
247 EXPECT_NEAR(probabilities_v5[4], probabilities_v3[4], 0.001);
248 EXPECT_NEAR(probabilities_v5[5], probabilities_v3[5], 0.001);
static std::string findFile(const std::string &path, bool silent=false)
Search for given file or directory in local or central release directory, and return absolute path if...
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Options for the FastBDT MVA method.
double m_randRatio
Fraction of data to use in the stochastic training.
double m_shrinkage
Shrinkage during the boosting step.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
unsigned int m_nLevels
Depth of tree.
unsigned int m_nCuts
Number of cut Levels = log_2(Number of Cuts)
unsigned int m_nTrees
Number of trees.
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Wraps the data of multiple events into a Dataset.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.