9#include <mva/methods/Combination.h>
10#include <mva/methods/Trivial.h>
11#include <mva/interface/Interface.h>
12#include <framework/utilities/TestHelpers.h>
14#include <gtest/gtest.h>
20 TEST(CombinationTest, CombinationOptions)
25 EXPECT_EQ(specific_options.m_weightfiles.size(), 0);
27 specific_options.m_weightfiles = {
"A",
"B"};
29 boost::property_tree::ptree pt;
30 specific_options.save(pt);
31 EXPECT_FLOAT_EQ(pt.get<
unsigned int>(
"Combination_number_of_weightfiles"), 2);
32 EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile0"),
"A");
33 EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile1"),
"B");
36 specific_options2.
load(pt);
42 EXPECT_EQ(specific_options.getMethod(), std::string(
"Combination"));
45 auto description = specific_options.getDescription();
46 EXPECT_EQ(description.options().size(), 1);
50 pt.put(
"Combination_version", 100);
52 EXPECT_B2ERROR(specific_options2.
load(pt));
56 EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
61 explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
71 [[nodiscard]]
unsigned int getNumberOfEvents()
const override {
return m_data.size(); }
72 void loadEvent(
unsigned int iEvent)
override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
74 std::vector<float>
getFeature(
unsigned int)
override {
return m_data; }
76 std::vector<float> m_data;
81 TEST(CombinationTest, CombinationInterface)
86 general_options.m_method =
"Trivial";
87 TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
92 trivial_options.m_output = 0.1;
93 auto trivial_teacher1 = trivial.getTeacher(general_options, trivial_options);
94 auto trivial_weightfile1 = trivial_teacher1->train(dataset);
97 trivial_options.m_output = 0.6;
98 auto trivial_teacher2 = trivial.getTeacher(general_options, trivial_options);
99 auto trivial_weightfile2 = trivial_teacher2->train(dataset);
103 general_options.m_method =
"Combination";
105 specific_options.m_weightfiles = {
"weightfile1.xml",
"weightfile2.xml"};
106 auto teacher = combination.
getTeacher(general_options, specific_options);
107 auto weightfile = teacher->train(dataset);
110 expert->load(weightfile);
111 auto probabilities = expert->apply(dataset);
112 EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
113 for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
114 EXPECT_FLOAT_EQ(probabilities[i], (0.1 * 0.6) / (0.1 * 0.6 + (1 - 0.1) * (1 - 0.6)));
117 trivial_weightfile2.addElement(
"method",
"DOESNOTEXIST");
120 auto weightfile2 = teacher->train(dataset);
122 EXPECT_B2ERROR(expert->load(weightfile2));
126 EXPECT_THROW(expert->load(weightfile2), std::runtime_error);
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
virtual unsigned int getNumberOfEvents() const =0
Returns the number of events in this dataset.
virtual unsigned int getNumberOfSpectators() const =0
Returns the number of spectators in this dataset.
virtual unsigned int getNumberOfFeatures() const =0
Returns the number of features in this dataset.
virtual void loadEvent(unsigned int iEvent)=0
Load the event number iEvent.
virtual std::vector< float > getFeature(unsigned int iFeature)
Returns all values of one feature in a std::vector<float>
virtual float getSignalFraction()
Returns the signal fraction of the whole sample.
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Options for the Trivial MVA method.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction...
Abstract base class for different kinds of events.