Belle II Software light-2406-ragdoll
test_Combination.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Combination.h>
10#include <mva/methods/Trivial.h>
11#include <mva/interface/Interface.h>
12#include <framework/utilities/TestHelpers.h>
13
14#include <gtest/gtest.h>
15
16using namespace Belle2;
17
18namespace {
19
20 TEST(CombinationTest, CombinationOptions)
21 {
22
23 MVA::CombinationOptions specific_options;
24
25 EXPECT_EQ(specific_options.m_weightfiles.size(), 0);
26
27 specific_options.m_weightfiles = {"A", "B"};
28
29 boost::property_tree::ptree pt;
30 specific_options.save(pt);
31 EXPECT_FLOAT_EQ(pt.get<unsigned int>("Combination_number_of_weightfiles"), 2);
32 EXPECT_EQ(pt.get<std::string>("Combination_weightfile0"), "A");
33 EXPECT_EQ(pt.get<std::string>("Combination_weightfile1"), "B");
34
35 MVA::CombinationOptions specific_options2;
36 specific_options2.load(pt);
37
38 EXPECT_EQ(specific_options2.m_weightfiles.size(), 2);
39 EXPECT_EQ(specific_options2.m_weightfiles[0], "A");
40 EXPECT_EQ(specific_options2.m_weightfiles[1], "B");
41
42 EXPECT_EQ(specific_options.getMethod(), std::string("Combination"));
43
44 // Test if po::options_description is created without crashing
45 auto description = specific_options.getDescription();
46 EXPECT_EQ(description.options().size(), 1);
47
48 // Check for B2ERROR and throw if version is wrong
49 // we try with version 100, surely we will never reach this!
50 pt.put("Combination_version", 100);
51 try {
52 EXPECT_B2ERROR(specific_options2.load(pt));
53 } catch (...) {
54
55 }
56 EXPECT_THROW(specific_options2.load(pt), std::runtime_error);
57 }
58
59 class TestDataset : public MVA::Dataset {
60 public:
61 explicit TestDataset(const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
62 {
63 m_input = {0.0};
64 m_target = 0.0;
65 m_isSignal = false;
66 m_weight = 1.0;
67 }
68
69 [[nodiscard]] unsigned int getNumberOfFeatures() const override { return 1; }
70 [[nodiscard]] unsigned int getNumberOfSpectators() const override { return 0; }
71 [[nodiscard]] unsigned int getNumberOfEvents() const override { return m_data.size(); }
72 void loadEvent(unsigned int iEvent) override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
73 float getSignalFraction() override { return 0.1; };
74 std::vector<float> getFeature(unsigned int) override { return m_data; }
75
76 std::vector<float> m_data;
77
78 };
79
80
81 TEST(CombinationTest, CombinationInterface)
82 {
84
85 MVA::GeneralOptions general_options;
86 general_options.m_method = "Trivial";
87 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
88
90
91 MVA::TrivialOptions trivial_options;
92 trivial_options.m_output = 0.1;
93 auto trivial_teacher1 = trivial.getTeacher(general_options, trivial_options);
94 auto trivial_weightfile1 = trivial_teacher1->train(dataset);
95 MVA::Weightfile::saveToXMLFile(trivial_weightfile1, "weightfile1.xml");
96
97 trivial_options.m_output = 0.6;
98 auto trivial_teacher2 = trivial.getTeacher(general_options, trivial_options);
99 auto trivial_weightfile2 = trivial_teacher2->train(dataset);
100 MVA::Weightfile::saveToXMLFile(trivial_weightfile2, "weightfile2.xml");
101
103 general_options.m_method = "Combination";
104 MVA::CombinationOptions specific_options;
105 specific_options.m_weightfiles = {"weightfile1.xml", "weightfile2.xml"};
106 auto teacher = combination.getTeacher(general_options, specific_options);
107 auto weightfile = teacher->train(dataset);
108
109 auto expert = combination.getExpert();
110 expert->load(weightfile);
111 auto probabilities = expert->apply(dataset);
112 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
113 for (unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
114 EXPECT_FLOAT_EQ(probabilities[i], (0.1 * 0.6) / (0.1 * 0.6 + (1 - 0.1) * (1 - 0.6)));
115
116 // Error and throw runtime error if method does not exist
117 trivial_weightfile2.addElement("method", "DOESNOTEXIST");
118 MVA::Weightfile::saveToXMLFile(trivial_weightfile2, "weightfile2.xml");
119
120 auto weightfile2 = teacher->train(dataset);
121 try {
122 EXPECT_B2ERROR(expert->load(weightfile2));
123 } catch (...) {
124
125 }
126 EXPECT_THROW(expert->load(weightfile2), std::runtime_error);
127
128
129 }
130
131}
Options for the Combination MVA method.
Definition: Combination.h:28
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Definition: Combination.h:53
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Combination.cc:21
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
Definition: Dataset.cc:74
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
Definition: Dataset.cc:35
General options which are shared by all MVA trainings.
Definition: Options.h:62
Template class to easily construct an interface for an MVA library using a library-specific Options,...
Definition: Interface.h:99
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
Definition: Interface.h:125
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Definition: Interface.h:116
Options for the Trivial MVA method.
Definition: Trivial.h:28
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
Definition: Weightfile.cc:175
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
Definition: TestHelpers.h:66
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24