Belle II Software  release-06-00-14
test_Trivial.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/Trivial.h>
10 #include <mva/interface/Interface.h>
11 #include <framework/utilities/FileSystem.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <gtest/gtest.h>
15 
16 using namespace Belle2;
17 
18 namespace {
19 
20  TEST(TrivialTest, TrivialOptions)
21  {
22 
23  MVA::TrivialOptions specific_options;
24 
25  EXPECT_EQ(specific_options.m_output, 0.5);
26 
27  specific_options.m_output = 0.1;
28 
29  boost::property_tree::ptree pt;
30  specific_options.save(pt);
31  EXPECT_FLOAT_EQ(pt.get<double>("Trivial_output"), 0.1);
32 
33  MVA::TrivialOptions specific_options2;
34  specific_options2.load(pt);
35 
36  EXPECT_EQ(specific_options2.m_output, 0.1);
37 
38  EXPECT_EQ(specific_options.getMethod(), std::string("Trivial"));
39 
40  // Test if po::options_description is created without crashing
41  auto description = specific_options.getDescription();
42  EXPECT_EQ(description.options().size(), 1);
43 
44  // Check for B2ERROR and throw if version is wrong
45  // we try with version 100, surely we will never reach this!
46  pt.put("Trivial_version", 100);
47  try {
48  EXPECT_B2ERROR(specific_options2.load(pt));
49  } catch (...) {
50 
51  }
52  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
53  }
54 
55  class TestDataset : public MVA::Dataset {
56  public:
57  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
58  {
59  m_input = {0.0};
60  m_target = 0.0;
61  m_isSignal = false;
62  m_weight = 1.0;
63  }
64 
65  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
66  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
67  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
68  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
69  float getSignalFraction() override { return 0.1; };
70  std::vector<float> getFeature(unsigned int) override { return m_data; }
71 
72  std::vector<float> m_data;
73 
74  };
75 
76 
77  TEST(TrivialTest, TrivialInterface)
78  {
80 
81  MVA::GeneralOptions general_options;
82  MVA::TrivialOptions specific_options;
83  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
84 
85  auto teacher = interface.getTeacher(general_options, specific_options);
86  auto weightfile = teacher->train(dataset);
87 
88  auto expert = interface.getExpert();
89  expert->load(weightfile);
90  auto probabilities = expert->apply(dataset);
91  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
92  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
93  EXPECT_FLOAT_EQ(probabilities[i], 0.5);
94 
95  }
96 
97 }
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:31
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Options for the Trivial MVA method.
Definition: Trivial.h:28
double m_output
Output of the trivial method.
Definition: Trivial.h:53
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Trivial.cc:20
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.