10 #include <mva/methods/Combination.h>
11 #include <mva/methods/Trivial.h>
12 #include <mva/interface/Interface.h>
13 #include <framework/utilities/TestHelpers.h>
15 #include <gtest/gtest.h>
// Exercises the Combination options: initial state, round trip through a
// boost::property_tree (save/load), reported method name, the
// program-options description, and rejection of an unsupported version.
21 TEST(CombinationTest, CombinationOptions)
// Initially no weightfiles are listed.
// NOTE(review): comparing size() (unsigned) against the int literal 0 may trigger
// -Wsign-compare; 0u would avoid it.
26 EXPECT_EQ(specific_options.m_weightfiles.size(), 0);
// save() must write the number of weightfiles plus one indexed key per weightfile.
28 specific_options.m_weightfiles = {
"A",
"B"};
30 boost::property_tree::ptree pt;
31 specific_options.save(pt);
// NOTE(review): the value is read as unsigned int, so EXPECT_EQ would be the exact
// check; EXPECT_FLOAT_EQ converts both sides to float and allows a ULP tolerance.
32 EXPECT_FLOAT_EQ(pt.get<
unsigned int>(
"Combination_number_of_weightfiles"), 2);
33 EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile0"),
"A");
34 EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile1"),
"B");
// Load the saved tree into a second options object.
37 specific_options2.
load(pt);
// The options report "Combination" as their method name.
43 EXPECT_EQ(specific_options.getMethod(), std::string(
"Combination"));
// Building the description must succeed and expose exactly one option.
46 auto description = specific_options.getDescription();
47 EXPECT_EQ(description.options().size(), 1);
// An unsupported version (100) must make load() emit a B2ERROR and throw
// std::runtime_error.
51 pt.put(
"Combination_version", 100);
53 EXPECT_B2ERROR(specific_options2.
load(pt));
57 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Minimal in-memory dataset for these tests. The values are copied from `data`
// into m_data, and the base Dataset is built from default-constructed GeneralOptions.
62 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
70 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
71 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
72 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
73 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
74 float getSignalFraction()
override {
return 0.1; };
75 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
// Feature value of each event; its size is the number of events.
77 std::vector<float> m_data;
// End-to-end check: train two Trivial methods with constant outputs 0.1 and 0.6,
// combine them with the Combination method, and verify the combined per-event
// probability. Then verify that a sub-method with an unknown method name is rejected.
82 TEST(CombinationTest, CombinationInterface)
87 general_options.m_method =
"Trivial";
// Eight events with one feature each (see TestDataset).
88 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
// First Trivial sub-method: constant output 0.1.
93 trivial_options.m_output = 0.1;
94 auto trivial_teacher1 = trivial.getTeacher(general_options, trivial_options);
95 auto trivial_weightfile1 = trivial_teacher1->train(dataset);
// Second Trivial sub-method: constant output 0.6.
98 trivial_options.m_output = 0.6;
99 auto trivial_teacher2 = trivial.getTeacher(general_options, trivial_options);
100 auto trivial_weightfile2 = trivial_teacher2->train(dataset);
// Combination over the two sub-method weightfiles.
// NOTE(review): presumably the trivial weightfiles are saved under these file
// names in lines not shown here. Verify.
104 general_options.m_method =
"Combination";
106 specific_options.m_weightfiles = {
"weightfile1.xml",
"weightfile2.xml"};
107 auto teacher = combination.
getTeacher(general_options, specific_options);
108 auto weightfile = teacher->train(dataset);
// Every event must receive p1*p2 / (p1*p2 + (1-p1)*(1-p2)) with p1 = 0.1, p2 = 0.6.
111 expert->load(weightfile);
112 auto probabilities = expert->apply(dataset);
113 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
114 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
115 EXPECT_FLOAT_EQ(probabilities[i], (0.1 * 0.6) / (0.1 * 0.6 + (1 - 0.1) * (1 - 0.6)));
// Give the second sub-method an unknown method name. Loading the retrained
// combination must then emit a B2ERROR and throw std::runtime_error.
118 trivial_weightfile2.addElement(
"method",
"DOESNOTEXIST");
121 auto weightfile2 = teacher->train(dataset);
123 EXPECT_B2ERROR(expert->load(weightfile2));
127 EXPECT_THROW(expert->load(weightfile2), std::runtime_error);