9 #include <mva/methods/Trivial.h>
10 #include <mva/interface/Interface.h>
11 #include <framework/utilities/FileSystem.h>
12 #include <framework/utilities/TestHelpers.h>
14 #include <gtest/gtest.h>
20 TEST(TrivialTest, TrivialOptions)
25 EXPECT_EQ(specific_options.m_output, 0.5);
27 specific_options.m_output = 0.1;
28 specific_options.m_multiple_output = {1.0, 2.0, 3.0};
29 specific_options.m_passthrough =
true;
31 boost::property_tree::ptree pt;
32 specific_options.save(pt);
33 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_output"), 0.1);
34 EXPECT_FLOAT_EQ(pt.get<
bool>(
"Trivial_passthrough"),
true);
36 EXPECT_FLOAT_EQ(pt.get<
unsigned int>(
"Trivial_number_of_multiple_outputs"), 3);
37 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output0"), 1.0);
38 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output1"), 2.0);
39 EXPECT_FLOAT_EQ(pt.get<
double>(
"Trivial_multiple_output2"), 3.0);
42 specific_options2.
load(pt);
44 EXPECT_EQ(specific_options2.
m_output, 0.1);
49 EXPECT_EQ(specific_options2.
getMethod(), std::string(
"Trivial"));
53 auto description = specific_options.getDescription();
54 EXPECT_EQ(description.options().size(), 3);
58 pt.put(
"Trivial_version", 100);
60 EXPECT_B2ERROR(specific_options2.
load(pt));
64 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
69 explicit TestDataset(
const std::vector<std::vector<float>>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data),
70 m_nFeatures(data[0].size())
78 [[nodiscard]]
unsigned int getNumberOfFeatures()
const override {
return m_nFeatures;}
79 [[nodiscard]]
unsigned int getNumberOfSpectators()
const override {
return 0; }
80 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
81 void loadEvent(
unsigned int iEvent)
override { m_input = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
82 float getSignalFraction()
override {
return 0.1; };
83 std::vector<float> getFeature(
unsigned int iFeature)
override
85 std::vector<float> feature(m_data.size(), 0.0);
86 for (
unsigned int iEvent = 0; iEvent << m_data.size(); iEvent++) {
87 feature[iEvent] = m_data[iEvent][iFeature];
91 std::vector<std::vector<float>> m_data;
92 unsigned int m_nFeatures;
95 TEST(TrivialTest, TrivialInterface)
101 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
103 auto teacher = interface.
getTeacher(general_options, specific_options);
104 auto weightfile = teacher->train(dataset);
107 expert->load(weightfile);
108 auto probabilities = expert->apply(dataset);
109 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
110 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
111 EXPECT_FLOAT_EQ(probabilities[i], 0.5);
114 TEST(TrivialTest, TrivialPassThrough)
120 general_options.m_variables = {
"p",};
121 specific_options.m_passthrough =
true;
122 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
124 auto teacher = interface.
getTeacher(general_options, specific_options);
125 auto weightfile = teacher->train(dataset);
128 expert->load(weightfile);
129 auto probabilities = expert->apply(dataset);
130 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
131 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
132 EXPECT_FLOAT_EQ(probabilities[i], dataset.m_data[i][0]);
135 TEST(TrivialTest, TrivialPassThroughMulticlass)
141 general_options.m_variables = {
"p",};
142 specific_options.m_passthrough =
true;
143 general_options.m_nClasses = 3;
144 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
146 auto teacher = interface.
getTeacher(general_options, specific_options);
147 auto weightfile = teacher->train(dataset);
150 expert->load(weightfile);
151 auto probabilities = expert->applyMulticlass(dataset);
152 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
153 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
154 for (
unsigned int j = 0; j < probabilities[i].size(); ++ j) {
155 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][0]);
160 TEST(TrivialTest, TrivialPassThroughMulticlassMultiVariable)
166 general_options.m_variables = {
"px",
"py",
"pz"};
167 specific_options.m_passthrough =
true;
168 general_options.m_nClasses = 3;
169 TestDataset dataset({{1.0, 2.0, 3.0}, {1.0, 3.0, 4.0}, {1.0, 7.0, 13.0}});
171 auto teacher = interface.
getTeacher(general_options, specific_options);
172 auto weightfile = teacher->train(dataset);
175 expert->load(weightfile);
176 auto probabilities = expert->applyMulticlass(dataset);
177 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
178 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
179 for (
unsigned int j = 0; j < probabilities[i].size(); ++ j) {
180 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][j]);
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Options for the Trivial MVA method.
double m_output
Output of the trivial method.
virtual std::string getMethod() const override
Return method name.
std::vector< double > m_multiple_output
Multiple outputs of the trivial method (one value per entry).
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
bool m_passthrough
Flag for passthrough setting.
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.