10 #include <mva/methods/PDF.h>
11 #include <mva/interface/Interface.h>
12 #include <framework/utilities/TestHelpers.h>
14 #include <gtest/gtest.h>
// Exercises the PDF method's options object:
//   - its default values;
//   - a save()/load() round trip through a boost::property_tree::ptree;
//   - getMethod() and getDescription();
//   - the error path for an unsupported "PDF_version".
// NOTE(review): the declarations of specific_options / specific_options2 are not
// shown in this excerpt; presumably both are default-constructed MVA::PDFOptions.
20 TEST(PDFTest, PDFOptions)
// Defaults of a freshly constructed options object.
25 EXPECT_EQ(specific_options.m_mode,
"probability");
26 EXPECT_EQ(specific_options.m_binning,
"frequency");
27 EXPECT_EQ(specific_options.m_nBins, 100);
// Overwrite every field with a non-default value so the round trip below
// cannot pass by accident on defaults.
29 specific_options.m_mode =
"mode";
30 specific_options.m_binning =
"binning";
31 specific_options.m_nBins = 3;
// save() must write each field under its "PDF_"-prefixed key.
33 boost::property_tree::ptree pt;
34 specific_options.save(pt);
35 EXPECT_EQ(pt.get<std::string>(
"PDF_mode"),
"mode");
36 EXPECT_EQ(pt.get<std::string>(
"PDF_binning"),
"binning");
37 EXPECT_EQ(pt.get<
unsigned int>(
"PDF_nBins"), 3);
// load() into a second object must restore exactly the values saved above.
41 specific_options2.
load(pt);
43 EXPECT_EQ(specific_options2.
m_mode,
"mode");
44 EXPECT_EQ(specific_options2.
m_binning,
"binning");
45 EXPECT_EQ(specific_options2.
m_nBins, 3);
// The method identifier reported by the options is "PDF".
47 EXPECT_EQ(specific_options.getMethod(), std::string(
"PDF"));
// getDescription() must build an options description with exactly three
// entries (presumably one per field checked above — confirm).
50 auto description = specific_options.getDescription();
51 EXPECT_EQ(description.options().size(), 3);
// A "PDF_version" of 100 is expected to be rejected: load() should both
// report a B2ERROR and throw std::runtime_error.
// NOTE(review): if load() throws inside EXPECT_B2ERROR, that exception must be
// caught (on lines not shown in this excerpt), otherwise the test aborts before
// reaching EXPECT_THROW — confirm.
55 pt.put(
"PDF_version", 100);
57 EXPECT_B2ERROR(specific_options2.
load(pt));
61 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Minimal in-memory MVA::Dataset used by the tests. It has:
//   - one feature, whose per-event values come from `data`;
//   - no spectators;
//   - signal/background targets that alternate with the event index.
// NOTE(review): the class header and constructor body are not shown in this
// excerpt. m_input, m_target and m_isSignal are not declared here and are
// presumably inherited from MVA::Dataset. loadEvent() writes m_input[0], so
// m_input is assumed to be sized to at least one element elsewhere — confirm.
66 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
// Exactly one feature and no spectators.
74 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return 1; }
75 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
// One event per stored value (the size_t from size() is narrowed to unsigned int).
76 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
// Loads event iEvent and sets the single feature to m_data[iEvent].
// Odd indices get m_target = 1 and are flagged as signal; even indices get
// m_target = 0 and are background.
77 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
// This value is hard-coded. It does not reflect the alternating
// signal/background targets that loadEvent() produces.
// NOTE(review): presumably not used by the code under test — confirm.
78 float getSignalFraction()
override {
return 0.1; };
// The feature index is ignored: there is only one feature, so all stored values
// are returned.
79 std::vector<float> getFeature(
unsigned int)
override {
return m_data; }
// Per-event values of the single feature, indexed by event number.
81 std::vector<float> m_data;
// Trains the PDF method through the generic MVA interface on an 8-event
// TestDataset using 4 bins. It then loads the resulting weightfile into an
// expert, applies it to the same dataset and checks the output for each event.
// NOTE(review): interface, general_options, specific_options and expert are
// declared on lines not shown in this excerpt.
86 TEST(PDFTest, PDFInterface)
92 specific_options.m_nBins = 4;
// TestDataset marks odd event indices as signal. So:
//   - value 1.0 (events 0-3) appears twice as signal and twice as background;
//   - value 2.0 appears only as background (events 4, 6);
//   - value 3.0 appears only as signal (events 5, 7).
93 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
95 auto teacher = interface.
getTeacher(general_options, specific_options);
96 auto weightfile = teacher->train(dataset);
99 expert->load(weightfile);
100 auto probabilities = expert->apply(dataset);
// Expect one output per event. The expected values equal the fraction of signal
// events among the training events that share each feature value:
// 0.5 for 1.0, 0.0 for 2.0 and 1.0 for 3.0.
101 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
102 EXPECT_FLOAT_EQ(probabilities[0], 0.5);
103 EXPECT_FLOAT_EQ(probabilities[1], 0.5);
104 EXPECT_FLOAT_EQ(probabilities[2], 0.5);
105 EXPECT_FLOAT_EQ(probabilities[3], 0.5);
106 EXPECT_FLOAT_EQ(probabilities[4], 0.0);
107 EXPECT_FLOAT_EQ(probabilities[5], 1.0);
108 EXPECT_FLOAT_EQ(probabilities[6], 0.0);
109 EXPECT_FLOAT_EQ(probabilities[7], 1.0);