9#include <mva/methods/Combination.h> 
   10#include <mva/methods/Trivial.h> 
   11#include <mva/interface/Interface.h> 
   12#include <framework/utilities/TestHelpers.h> 
   14#include <gtest/gtest.h> 
   20  TEST(CombinationTest, CombinationOptions)
 
   25    EXPECT_EQ(specific_options.m_weightfiles.size(), 0);
 
   27    specific_options.m_weightfiles = {
"A", 
"B"};
 
   29    boost::property_tree::ptree pt;
 
   30    specific_options.save(pt);
 
   31    EXPECT_FLOAT_EQ(pt.get<
unsigned int>(
"Combination_number_of_weightfiles"), 2);
 
   32    EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile0"), 
"A");
 
   33    EXPECT_EQ(pt.get<std::string>(
"Combination_weightfile1"), 
"B");
 
   36    specific_options2.
load(pt);
 
   42    EXPECT_EQ(specific_options.getMethod(), std::string(
"Combination"));
 
   45    auto description = specific_options.getDescription();
 
   46    EXPECT_EQ(description.options().size(), 1);
 
   50    pt.put(
"Combination_version", 100);
 
   52      EXPECT_B2ERROR(specific_options2.
load(pt));
 
   56    EXPECT_THROW(specific_options2.
load(pt), std::runtime_error);
 
   61    explicit TestDataset(
const std::vector<float>& data) : MVA::Dataset(MVA::GeneralOptions()), m_data(data)
 
   69    [[nodiscard]] 
unsigned int getNumberOfFeatures()
 const override { 
return 1; }
 
   70    [[nodiscard]] 
unsigned int getNumberOfSpectators()
 const override { 
return 0; }
 
   71    [[nodiscard]] 
unsigned int getNumberOfEvents()
 const override { 
return m_data.size(); }
 
   72    void loadEvent(
unsigned int iEvent)
 override { m_input[0] = m_data[iEvent]; m_target = iEvent % 2; m_isSignal = m_target == 1; };
 
   73    float getSignalFraction()
 override { 
return 0.1; };
 
   74    std::vector<float> getFeature(
unsigned int)
 override { 
return m_data; }
 
   76    std::vector<float> m_data;
 
   81  TEST(CombinationTest, CombinationInterface)
 
   86    general_options.m_method = 
"Trivial";
 
   87    TestDataset dataset({1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 3.0});
 
   92    trivial_options.m_output = 0.1;
 
   93    auto trivial_teacher1 = trivial.getTeacher(general_options, trivial_options);
 
   94    auto trivial_weightfile1 = trivial_teacher1->train(dataset);
 
   97    trivial_options.m_output = 0.6;
 
   98    auto trivial_teacher2 = trivial.getTeacher(general_options, trivial_options);
 
   99    auto trivial_weightfile2 = trivial_teacher2->train(dataset);
 
  103    general_options.m_method = 
"Combination";
 
  105    specific_options.m_weightfiles = {
"weightfile1.xml", 
"weightfile2.xml"};
 
  106    auto teacher = combination.
getTeacher(general_options, specific_options);
 
  107    auto weightfile = teacher->train(dataset);
 
  110    expert->load(weightfile);
 
  111    auto probabilities = expert->apply(dataset);
 
  112    EXPECT_EQ(probabilities.size(), dataset.getNumberOfEvents());
 
  113    for (
unsigned int i = 0; i < dataset.getNumberOfEvents(); ++i)
 
  114      EXPECT_FLOAT_EQ(probabilities[i], (0.1 * 0.6) / (0.1 * 0.6 + (1 - 0.1) * (1 - 0.6)));
 
  117    trivial_weightfile2.addElement(
"method", 
"DOESNOTEXIST");
 
  120    auto weightfile2 = teacher->train(dataset);
 
  122      EXPECT_B2ERROR(expert->load(weightfile2));
 
  126    EXPECT_THROW(expert->load(weightfile2), std::runtime_error);
 
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
General options which are shared by all MVA trainings.
Template class to easily construct an interface for an MVA library using a library-specific Options,...
virtual std::unique_ptr< MVA::Expert > getExpert() const override
Get Expert of this MVA library.
virtual std::unique_ptr< Teacher > getTeacher(const GeneralOptions &general_options, const SpecificOptions &specific_options) const override
Get Teacher of this MVA library.
Options for the Trivial MVA method.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
Abstract base class for different kinds of events.