9#include <mva/methods/Trivial.h>
10#include <mva/interface/Interface.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
14#include <gtest/gtest.h>
// Checks TrivialOptions: the default m_output value, save() into a
// boost property tree (one "Trivial_*" key per option), a load() round trip
// into a second options object, getMethod(), getDescription(), and the
// error path when "Trivial_version" is set to 100.
// NOTE(review): this listing is incomplete -- the opening/closing braces and
// the declarations of specific_options / specific_options2 are not present
// here; restore them from the upstream source before compiling.
20 TEST(TrivialTest, TrivialOptions)
// Presumably a freshly constructed options object: default output is 0.5.
25 EXPECT_EQ(specific_options.m_output, 0.5);
27 specific_options.m_output = 0.1;
28 specific_options.m_multiple_output = {1.0, 2.0, 3.0};
29 specific_options.m_passthrough =
true;
// Serialise into a property tree and check every key that save() writes.
31 boost::property_tree::ptree pt;
32 specific_options.save(pt);
33 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_output"), 0.1);
// NOTE(review): EXPECT_FLOAT_EQ on a bool / unsigned int only works through
// implicit conversion to float; EXPECT_TRUE / EXPECT_EQ would state intent.
34 EXPECT_FLOAT_EQ(pt.get<
bool>(
"Trivial_passthrough"),
true);
// m_multiple_output is stored as a count plus one indexed key per entry.
36 EXPECT_FLOAT_EQ(pt.get<
unsigned int>(
"Trivial_number_of_multiple_outputs"), 3);
37 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output0"), 1.0);
38 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output1"), 2.0);
39 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output2"), 3.0);
// Round trip: load the tree into a second options object.
42 specific_options2.
load(pt);
// NOTE(review): exact EXPECT_EQ on a double relies on the tree's text
// representation reproducing 0.1 bit-for-bit; EXPECT_DOUBLE_EQ would not.
44 EXPECT_EQ(specific_options2.
m_output, 0.1);
49 EXPECT_EQ(specific_options2.
getMethod(), std::string(
"Trivial"));
// getDescription() is expected to expose exactly three options.
// NOTE(review): size() is unsigned and 3 is an int literal; this may raise a
// sign-compare warning inside gtest -- 3u avoids it.
53 auto description = specific_options.getDescription();
54 EXPECT_EQ(description.options().size(), 3);
// Loading a tree whose "Trivial_version" is 100 is expected to emit a B2ERROR.
58 pt.put(
"Trivial_version", 100);
60 EXPECT_B2ERROR(specific_options2.
load(pt));
// Loading the same tree is also expected to throw std::runtime_error.
// NOTE(review): any statements between this check and the previous one are
// not visible in this listing.
64 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
// Builds an in-memory test dataset from a feature table with one row per
// event; the base MVA::Dataset is constructed with default GeneralOptions.
// NOTE(review): data[0] is read unconditionally, so an empty `data` is
// undefined behaviour (every caller in this file passes at least one row).
// NOTE(review): the constructor body is missing from this listing.
69 explicit TestDataset(
const std::vector<std::vector<float>>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data),
70 m_nFeatures(data[0].size())
80 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
81 void loadEvent(
unsigned int iEvent)
override { m_input = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
83 std::vector<float>
getFeature(
unsigned int iFeature)
override
85 std::vector<float> feature(m_data.size(), 0.0);
86 for (
unsigned int iEvent = 0; iEvent << m_data.size(); iEvent++) {
87 feature[iEvent] = m_data[iEvent][iFeature];
// Feature table: one row per event, one column per feature. loadEvent copies
// a row into m_input, and getFeature reads a column.
91 std::vector<std::vector<float>> m_data;
// Number of features, taken from the size of the first row in the constructor.
92 unsigned int m_nFeatures;
// Trains the Trivial method through its interface on an eight-event,
// single-feature dataset. Checks that the expert returns one value per event,
// and that every value is 0.5.
// NOTE(review): 0.5 is presumably the default m_output; the options set-up and
// the declarations of interface / general_options / specific_options / expert
// (plus the braces) are missing from this listing.
95 TEST(TrivialTest, TrivialInterface)
101 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
103 auto teacher = interface.
getTeacher(general_options, specific_options);
104 auto weightfile = teacher->train(dataset);
107 expert->load(weightfile);
108 auto probabilities = expert->apply(dataset);
109 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
110 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
111 EXPECT_FLOAT_EQ(probabilities[i], 0.5);
// With m_passthrough enabled and a single variable "p", the expert's output
// for each event is expected to equal that event's (only) feature value.
// NOTE(review): braces and the declarations of interface / general_options /
// specific_options / expert are missing from this listing.
114 TEST(TrivialTest, TrivialPassThrough)
120 general_options.m_variables = {
"p",};
121 specific_options.m_passthrough =
true;
122 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
124 auto teacher = interface.
getTeacher(general_options, specific_options);
125 auto weightfile = teacher->train(dataset);
128 expert->load(weightfile);
129 auto probabilities = expert->apply(dataset);
130 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
131 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
132 EXPECT_FLOAT_EQ(probabilities[i], dataset.m_data[i][0]);
// Passthrough with three classes but a single variable "p": every per-class
// output of applyMulticlass is expected to equal the event's single feature.
// NOTE(review): the inner loop is bounded by probabilities[i].size(), so an
// empty per-event vector would pass vacuously; consider first asserting
// probabilities[i].size() == 3.
// NOTE(review): braces and the declarations of interface / general_options /
// specific_options / expert are missing from this listing.
135 TEST(TrivialTest, TrivialPassThroughMulticlass)
141 general_options.m_variables = {
"p",};
142 specific_options.m_passthrough =
true;
143 general_options.m_nClasses = 3;
144 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
146 auto teacher = interface.
getTeacher(general_options, specific_options);
147 auto weightfile = teacher->train(dataset);
150 expert->load(weightfile);
151 auto probabilities = expert->applyMulticlass(dataset);
152 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
153 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
154 for (
unsigned int j = 0; j < probabilities[i].size(); ++ j) {
155 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][0]);
// Passthrough with three classes and three variables: the j-th class output
// for event i is expected to equal feature j of that event.
// NOTE(review): the inner loop is bounded by probabilities[i].size(), so it
// passes vacuously on an empty result. It would also index m_data[i] out of
// range if more than three outputs came back. Consider asserting
// probabilities[i].size() == 3 first.
// NOTE(review): braces and the declarations of interface / general_options /
// specific_options / expert are missing from this listing.
160 TEST(TrivialTest, TrivialPassThroughMulticlassMultiVariable)
166 general_options.m_variables = {
"px",
"py",
"pz"};
167 specific_options.m_passthrough =
true;
168 general_options.m_nClasses = 3;
169 TestDataset dataset({{1.0, 2.0, 3.0}, {1.0, 3.0, 4.0}, {1.0, 7.0, 13.0}});
171 auto teacher = interface.
getTeacher(general_options, specific_options);
172 auto weightfile = teacher->train(dataset);
175 expert->load(weightfile);
176 auto probabilities = expert->applyMulticlass(dataset);
177 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
178 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
179 for (
unsigned int j = 0; j < probabilities[i].size(); ++ j) {
180 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][j]);
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Options for the Trivial MVA method.
double m_output
Output of the trivial method.
virtual std::string getMethod() const override
Return method name.
std::vector< double > m_multiple_output
Output of the trivial method.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
bool m_passthrough
Flag for passthrough setting.
Abstract base class for different kinds of events.