10 #include <mva/methods/Trivial.h>
11 #include <mva/interface/Interface.h>
12 #include <framework/utilities/FileSystem.h>
13 #include <framework/utilities/TestHelpers.h>
15 #include <gtest/gtest.h>
// Exercises MVA::TrivialOptions: default output value, save/load round-trip
// through a boost property tree, method name, option description, and
// rejection of an unsupported "Trivial_version".
21 TEST(TrivialTest, TrivialOptions)
// The default output is expected to be 0.5.
// NOTE(review): exact EXPECT_EQ on a floating-point value; EXPECT_DOUBLE_EQ
// would be more robust — confirm the declared type of m_output.
26 EXPECT_EQ(specific_options.m_output, 0.5);
28 specific_options.m_output = 0.1;
// save() is expected to store the output under the "Trivial_output" key.
30 boost::property_tree::ptree pt;
31 specific_options.save(pt);
32 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_output"), 0.1);
// Loading the saved tree into a second options object must restore the value.
35 specific_options2.
load(pt);
// NOTE(review): exact comparison after a round-trip through the ptree;
// EXPECT_DOUBLE_EQ would avoid depending on exact serialization.
37 EXPECT_EQ(specific_options2.
m_output, 0.1);
39 EXPECT_EQ(specific_options.getMethod(), std::string(
"Trivial"));
// The option description is expected to expose exactly one option.
// NOTE(review): compares size_t against the int literal 1 (sign-compare
// warning in gtest's EXPECT_EQ); 1u would avoid it.
42 auto description = specific_options.getDescription();
43 EXPECT_EQ(description.options().size(), 1);
// An unsupported version number (100) is written into the tree; load() is then
// checked for a B2ERROR and, further below, for a std::runtime_error.
// Presumably some setup between the two checks is not visible here — confirm.
47 pt.put(
"Trivial_version", 100);
49 EXPECT_B2ERROR(specific_options2.
load(pt));
53 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Builds the test dataset: the base Dataset is constructed from
// default-constructed GeneralOptions and `data` is copied into m_data.
58 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
66 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
67 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
68 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
69 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
70 float getSignalFraction()
override {
return 0.1; };
71 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
/// Per-event values of the single feature; its size is also the event count.
std::vector<float> m_data{};
// End-to-end check of the Trivial method: train on an 8-event TestDataset,
// load the resulting weightfile into an expert, apply it to the same dataset,
// and require every returned value to equal 0.5.
78 TEST(TrivialTest, TrivialInterface)
84 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
86 auto teacher = interface.
getTeacher(general_options, specific_options);
87 auto weightfile = teacher->train(dataset);
// `expert` is created on a line not visible here (presumably from `interface`
// — confirm).
90 expert->load(weightfile);
91 auto probabilities = expert->apply(dataset);
// NOTE(review): EXPECT_EQ does not stop the test on failure, so if the sizes
// differ the loop below indexes `probabilities` out of range; ASSERT_EQ would
// abort the test first.
92 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
93 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
// 0.5 presumably comes from the default TrivialOptions output, since no
// override is visible here.
94 EXPECT_FLOAT_EQ(probabilities[i], 0.5);