Belle II Software light-2406-ragdoll
test_Trivial.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Trivial.h>
10#include <mva/interface/Interface.h>
11#include <framework/utilities/FileSystem.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2;
17
18namespace {
19
20 TEST(TrivialTest, TrivialOptions)
21 {
22
23 MVA::TrivialOptions specific_options;
24
25 EXPECT_EQ(specific_options.m_output, 0.5);
26
27 specific_options.m_output = 0.1;
28 specific_options.m_multiple_output = {1.0, 2.0, 3.0};
29 specific_options.m_passthrough = true;
30
31 boost::property_tree::ptree pt;
32 specific_options.save(pt);
33 EXPECT_FLOAT_EQ(pt.get<double>("Trivial_output"), 0.1);
34 EXPECT_FLOAT_EQ(pt.get<bool>("Trivial_passthrough"), true);
35
36 EXPECT_FLOAT_EQ(pt.get<unsigned int>("Trivial_number_of_multiple_outputs"), 3);
37 EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output0"), 1.0);
38 EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output1"), 2.0);
39 EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output2"), 3.0);
40
41 MVA::TrivialOptions specific_options2;
42 specific_options2.load(pt);
43
44 EXPECT_EQ(specific_options2.m_output, 0.1);
45 EXPECT_EQ(specific_options2.m_multiple_output.at(0), 1.0);
46 EXPECT_EQ(specific_options2.m_multiple_output.at(1), 2.0);
47 EXPECT_EQ(specific_options2.m_multiple_output.at(2), 3.0);
48
49 EXPECT_EQ(specific_options2.getMethod(), std::string("Trivial"));
50 EXPECT_FLOAT_EQ(specific_options2.m_passthrough, true);
51
52 // Test if po::options_description is created without crashing
53 auto description = specific_options.getDescription();
54 EXPECT_EQ(description.options().size(), 3);
55
56 // Check for B2ERROR and throw if version is wrong
57 // we try with version 100, surely we will never reach this!
58 pt.put("Trivial_version", 100);
59 try {
60 EXPECT_B2ERROR(specific_options2.load(pt));
61 } catch (...) {
62
63 }
64 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
65 }
66
67 class TestDataset : public MVA::Dataset {
68 public:
69 explicit TestDataset(const std::vector<std::vector<float>>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data),
70 m_nFeatures(data[0].size())
71 {
72 m_input = {0.0};
73 m_target = 0.0;
74 m_isSignal = false;
75 m_weight = 1.0;
76 }
77
78 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return m_nFeatures;}
79 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
80 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
81 void loadEvent(unsigned int iEvent) override { m_input = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
82 float getSignalFraction() override { return 0.1; };
83 std::vector<float> getFeature(unsigned int iFeature) override
84 {
85 std::vector<float> feature(m_data.size(), 0.0);
86 for (unsigned int iEvent = 0; iEvent << m_data.size(); iEvent++) {
87 feature[iEvent] = m_data[iEvent][iFeature];
88 }
89 return feature;
90 }
91 std::vector<std::vector<float>> m_data;
92 unsigned int m_nFeatures;
93 };
94
95 TEST(TrivialTest, TrivialInterface)
96 {
98
99 MVA::GeneralOptions general_options;
100 MVA::TrivialOptions specific_options;
101 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
102
103 auto teacher = interface.getTeacher(general_options, specific_options);
104 auto weightfile = teacher->train(dataset);
105
106 auto expert = interface.getExpert();
107 expert->load(weightfile);
108 auto probabilities = expert->apply(dataset);
109 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
110 for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
111 EXPECT_FLOAT_EQ(probabilities[i], 0.5);
112 }
113
114 TEST(TrivialTest, TrivialPassThrough)
115 {
117
118 MVA::GeneralOptions general_options;
119 MVA::TrivialOptions specific_options;
120 general_options.m_variables = {"p",};
121 specific_options.m_passthrough = true;
122 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
123
124 auto teacher = interface.getTeacher(general_options, specific_options);
125 auto weightfile = teacher->train(dataset);
126
127 auto expert = interface.getExpert();
128 expert->load(weightfile);
129 auto probabilities = expert->apply(dataset);
130 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
131 for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
132 EXPECT_FLOAT_EQ(probabilities[i], dataset.m_data[i][0]);
133 }
134
135 TEST(TrivialTest, TrivialPassThroughMulticlass)
136 {
138
139 MVA::GeneralOptions general_options;
140 MVA::TrivialOptions specific_options;
141 general_options.m_variables = {"p",};
142 specific_options.m_passthrough = true;
143 general_options.m_nClasses = 3;
144 TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
145
146 auto teacher = interface.getTeacher(general_options, specific_options);
147 auto weightfile = teacher->train(dataset);
148
149 auto expert = interface.getExpert();
150 expert->load(weightfile);
151 auto probabilities = expert->applyMulticlass(dataset);
152 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
153 for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
154 for (unsigned int j = 0; j < probabilities[i].size(); ++ j) {
155 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][0]);
156 }
157 }
158 }
159
160 TEST(TrivialTest, TrivialPassThroughMulticlassMultiVariable)
161 {
163
164 MVA::GeneralOptions general_options;
165 MVA::TrivialOptions specific_options;
166 general_options.m_variables = {"px", "py", "pz"};
167 specific_options.m_passthrough = true;
168 general_options.m_nClasses = 3;
169 TestDataset dataset({{1.0, 2.0, 3.0}, {1.0, 3.0, 4.0}, {1.0, 7.0, 13.0}});
170
171 auto teacher = interface.getTeacher(general_options, specific_options);
172 auto weightfile = teacher->train(dataset);
173
174 auto expert = interface.getExpert();
175 expert->load(weightfile);
176 auto probabilities = expert->applyMulticlass(dataset);
177 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
178 for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
179 for (unsigned int j = 0; j < probabilities[i].size(); ++ j) {
180 EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][j]);
181 }
182 }
183 }
184}
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Options for the Trivial MVA method.
Definition: Trivial.h:28
double m_output
Output of the trivial method.
Definition: Trivial.h:53
virtual std::string getMethod() const override
Return method name.
Definition: Trivial.h:51
std::vector< double > m_multiple_output
Output of the trivial method.
Definition: Trivial.h:54
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Trivial.cc:20
bool m_passthrough
Flag for passthrough setting.
Definition: Trivial.h:55
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24