Belle II Software  release-08-01-10
test_Trivial.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/Trivial.h>
10 #include <mva/interface/Interface.h>
11 #include <framework/utilities/FileSystem.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <gtest/gtest.h>
15 
16 using namespace Belle2;
17 
18 namespace {
19 
20  TEST(TrivialTest, TrivialOptions)
21  {
22 
23  MVA::TrivialOptions specific_options;
24 
25  EXPECT_EQ(specific_options.m_output, 0.5);
26 
27  specific_options.m_output = 0.1;
28  specific_options.m_multiple_output = {1.0, 2.0, 3.0};
29  specific_options.m_passthrough = true;
30 
31  boost::property_tree::ptree pt;
32  specific_options.save(pt);
33  EXPECT_FLOAT_EQ(pt.get<double>("Trivial_output"), 0.1);
34  EXPECT_FLOAT_EQ(pt.get<bool>("Trivial_passthrough"), true);
35 
36  EXPECT_FLOAT_EQ(pt.get<unsigned int>("Trivial_number_of_multiple_outputs"), 3);
37  EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output0"), 1.0);
38  EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output1"), 2.0);
39  EXPECT_FLOAT_EQ(pt.get<double>("Trivial_multiple_output2"), 3.0);
40 
41  MVA::TrivialOptions specific_options2;
42  specific_options2.load(pt);
43 
44  EXPECT_EQ(specific_options2.m_output, 0.1);
45  EXPECT_EQ(specific_options2.m_multiple_output.at(0), 1.0);
46  EXPECT_EQ(specific_options2.m_multiple_output.at(1), 2.0);
47  EXPECT_EQ(specific_options2.m_multiple_output.at(2), 3.0);
48 
49  EXPECT_EQ(specific_options2.getMethod(), std::string("Trivial"));
50  EXPECT_FLOAT_EQ(specific_options2.m_passthrough, true);
51 
52  // Test if po::options_description is created without crashing
53  auto description = specific_options.getDescription();
54  EXPECT_EQ(description.options().size(), 3);
55 
56  // Check for B2ERROR and throw if version is wrong
57  // we try with version 100, surely we will never reach this!
58  pt.put("Trivial_version", 100);
59  try {
60  EXPECT_B2ERROR(specific_options2.load(pt));
61  } catch (...) {
62 
63  }
64  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
65  }
66 
67  class TestDataset : public MVA::Dataset {
68  public:
69  explicit TestDataset(const std::vector<std::vector<float>>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data),
70  m_nFeatures(data[0].size())
71  {
72  m_input = {0.0};
73  m_target = 0.0;
74  m_isSignal = false;
75  m_weight = 1.0;
76  }
77 
78  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return m_nFeatures;}
79  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
80  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
81  void loadEvent(unsigned int iEvent) override { m_input = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
82  float getSignalFraction() override { return 0.1; };
83  std::vector<float> getFeature(unsigned int iFeature) override
84  {
85  std::vector<float> feature(m_data.size(), 0.0);
86  for (unsigned int iEvent = 0; iEvent << m_data.size(); iEvent++) {
87  feature[iEvent] = m_data[iEvent][iFeature];
88  }
89  return feature;
90  }
91  std::vector<std::vector<float>> m_data;
92  unsigned int m_nFeatures;
93  };
94 
95  TEST(TrivialTest, TrivialInterface)
96  {
98 
99  MVA::GeneralOptions general_options;
100  MVA::TrivialOptions specific_options;
101  TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
102 
103  auto teacher = interface.getTeacher(general_options, specific_options);
104  auto weightfile = teacher->train(dataset);
105 
106  auto expert = interface.getExpert();
107  expert->load(weightfile);
108  auto probabilities = expert->apply(dataset);
109  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
110  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
111  EXPECT_FLOAT_EQ(probabilities[i], 0.5);
112  }
113 
114  TEST(TrivialTest, TrivialPassThrough)
115  {
117 
118  MVA::GeneralOptions general_options;
119  MVA::TrivialOptions specific_options;
120  general_options.m_variables = {"p",};
121  specific_options.m_passthrough = true;
122  TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
123 
124  auto teacher = interface.getTeacher(general_options, specific_options);
125  auto weightfile = teacher->train(dataset);
126 
127  auto expert = interface.getExpert();
128  expert->load(weightfile);
129  auto probabilities = expert->apply(dataset);
130  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
131  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
132  EXPECT_FLOAT_EQ(probabilities[i], dataset.m_data[i][0]);
133  }
134 
135  TEST(TrivialTest, TrivialPassThroughMulticlass)
136  {
138 
139  MVA::GeneralOptions general_options;
140  MVA::TrivialOptions specific_options;
141  general_options.m_variables = {"p",};
142  specific_options.m_passthrough = true;
143  general_options.m_nClasses = 3;
144  TestDataset dataset({{1.0,}, {1.0,}, {1.0,}, {1.0,}, {2.0,}, {3.0,}, {2.0,}, {3.0,}});
145 
146  auto teacher = interface.getTeacher(general_options, specific_options);
147  auto weightfile = teacher->train(dataset);
148 
149  auto expert = interface.getExpert();
150  expert->load(weightfile);
151  auto probabilities = expert->applyMulticlass(dataset);
152  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
153  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
154  for (unsigned int j = 0; j < probabilities[i].size(); ++ j) {
155  EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][0]);
156  }
157  }
158  }
159 
160  TEST(TrivialTest, TrivialPassThroughMulticlassMultiVariable)
161  {
163 
164  MVA::GeneralOptions general_options;
165  MVA::TrivialOptions specific_options;
166  general_options.m_variables = {"px", "py", "pz"};
167  specific_options.m_passthrough = true;
168  general_options.m_nClasses = 3;
169  TestDataset dataset({{1.0, 2.0, 3.0}, {1.0, 3.0, 4.0}, {1.0, 7.0, 13.0}});
170 
171  auto teacher = interface.getTeacher(general_options, specific_options);
172  auto weightfile = teacher->train(dataset);
173 
174  auto expert = interface.getExpert();
175  expert->load(weightfile);
176  auto probabilities = expert->applyMulticlass(dataset);
177  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
178  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i) {
179  for (unsigned int j = 0; j < probabilities[i].size(); ++ j) {
180  EXPECT_FLOAT_EQ(probabilities[i][j], dataset.m_data[i][j]);
181  }
182  }
183  }
184 }
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Options for the Trivial MVA method.
Definition: Trivial.h:28
double m_output
Output of the trivial method.
Definition: Trivial.h:53
virtual std::string getMethod() const override
Return method name.
Definition: Trivial.h:51
std::vector< double > m_multiple_output
Output of the trivial method.
Definition: Trivial.h:54
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Trivial.cc:20
bool m_passthrough
Flag for passthrough setting.
Definition: Trivial.h:55
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.