Belle II Software  release-08-01-10
test_Combination.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/Combination.h>
10 #include <mva/methods/Trivial.h>
11 #include <mva/interface/Interface.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <gtest/gtest.h>
15 
16 using namespace Belle2;
17 
18 namespace {
19 
20  TEST(CombinationTest, CombinationOptions)
21  {
22 
23  MVA::CombinationOptions specific_options;
24 
25  EXPECT_EQ(specific_options.m_weightfiles.size(), 0);
26 
27  specific_options.m_weightfiles = {"A", "B"};
28 
29  boost::property_tree::ptree pt;
30  specific_options.save(pt);
31  EXPECT_FLOAT_EQ(pt.get<unsigned int>("Combination_number_of_weightfiles"), 2);
32  EXPECT_EQ(pt.get<std::string>("Combination_weightfile0"), "A");
33  EXPECT_EQ(pt.get<std::string>("Combination_weightfile1"), "B");
34 
35  MVA::CombinationOptions specific_options2;
36  specific_options2.load(pt);
37 
38  EXPECT_EQ(specific_options2.m_weightfiles.size(), 2);
39  EXPECT_EQ(specific_options2.m_weightfiles[0], "A");
40  EXPECT_EQ(specific_options2.m_weightfiles[1], "B");
41 
42  EXPECT_EQ(specific_options.getMethod(), std::string("Combination"));
43 
44  // Test if po::options_description is created without crashing
45  auto description = specific_options.getDescription();
46  EXPECT_EQ(description.options().size(), 1);
47 
48  // Check for B2ERROR and throw if version is wrong
49  // we try with version 100, surely we will never reach this!
50  pt.put("Combination_version", 100);
51  try {
52  EXPECT_B2ERROR(specific_options2.load(pt));
53  } catch (...) {
54 
55  }
56  EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
57  }
58 
59  class TestDataset : public MVA::Dataset {
60  public:
61  explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
62  {
63  m_input = {0.0};
64  m_target = 0.0;
65  m_isSignal = false;
66  m_weight = 1.0;
67  }
68 
69  [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
70  [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
71  [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
72  void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
73  float getSignalFraction() override { return 0.1; };
74  std::vector<float> getFeature(unsigned int) override { return m_data; }
75 
76  std::vector<float> m_data;
77 
78  };
79 
80 
81  TEST(CombinationTest, CombinationInterface)
82  {
84 
85  MVA::GeneralOptions general_options;
86  general_options.m_method = "Trivial";
87  TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
88 
90 
91  MVA::TrivialOptions trivial_options;
92  trivial_options.m_output = 0.1;
93  auto trivial_teacher1 = trivial.getTeacher(general_options, trivial_options);
94  auto trivial_weightfile1 = trivial_teacher1->train(dataset);
95  MVA::Weightfile::saveToXMLFile(trivial_weightfile1, "weightfile1.xml");
96 
97  trivial_options.m_output = 0.6;
98  auto trivial_teacher2 = trivial.getTeacher(general_options, trivial_options);
99  auto trivial_weightfile2 = trivial_teacher2->train(dataset);
100  MVA::Weightfile::saveToXMLFile(trivial_weightfile2, "weightfile2.xml");
101 
103  general_options.m_method = "Combination";
104  MVA::CombinationOptions specific_options;
105  specific_options.m_weightfiles = {"weightfile1.xml", "weightfile2.xml"};
106  auto teacher = combination.getTeacher(general_options, specific_options);
107  auto weightfile = teacher->train(dataset);
108 
109  auto expert = combination.getExpert();
110  expert->load(weightfile);
111  auto probabilities = expert->apply(dataset);
112  EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
113  for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
114  EXPECT_FLOAT_EQ(probabilities[i], (0.1 * 0.6) / (0.1 * 0.6 + (1 - 0.1) * (1 - 0.6)));
115 
116  // Error and throw runtime error if method does not exist
117  trivial_weightfile2.addElement("method", "DOESNOTEXIST");
118  MVA::Weightfile::saveToXMLFile(trivial_weightfile2, "weightfile2.xml");
119 
120  auto weightfile2 = teacher->train(dataset);
121  try {
122  EXPECT_B2ERROR(expert->load(weightfile2));
123  } catch (...) {
124 
125  }
126  EXPECT_THROW(expert->load(weightfile2), std::runtime_error);
127 
128 
129  }
130 
131 }
Options for the Combination MVA method.
Definition: Combination.h:28
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Definition: Combination.h:53
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: Combination.cc:21
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:117
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:126
Options for the Trivial MVA method.
Definition: Trivial.h:28
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Definition: Weightfile.cc:175
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
Definition: TestHelpers.h:66
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Abstract base class for different kinds of events.