9#include <mva/interface/Weightfile.h> 
   10#include <mva/interface/Options.h> 
   11#include <framework/utilities/TestHelpers.h> 
   13#include <framework/database/Configuration.h> 
   14#include <framework/database/Database.h> 
   21#include <gtest/gtest.h> 
   30    TestOptions(
const std::string& _x, 
const std::string& _y) : x(_x), y(_y) { }
 
   31    void load(
const boost::property_tree::ptree& pt)
 override { y = pt.get<std::string>(x); }
 
   32    void save(boost::property_tree::ptree& pt)
 const override { pt.put(x, y); }
 
   33    po::options_description getDescription()
 override 
   35      po::options_description description(
"General options");
 
   36      description.add_options()
 
   37      (
"help", 
"print this message");
 
   45  TEST(WeightfileTest, Options)
 
   47    TestOptions options1(
"Test1", 
"a");
 
   48    TestOptions options2(
"Test2", 
"b");
 
   53    EXPECT_EQ(weightfile.
getElement<std::string>(
"Test1"), 
"a");
 
   54    EXPECT_EQ(weightfile.
getElement<std::string>(
"Test2"), 
"b");
 
   56    TestOptions options3(
"Test2", 
"c");
 
   58    EXPECT_EQ(options3.y, 
"b");
 
   61  TEST(WeightfileTest, FeatureImportance)
 
   63    std::map<std::string, float> importance;
 
   64    importance[
"a"] = 1.0;
 
   65    importance[
"b"] = 2.0;
 
   66    importance[
"c"] = 3.0;
 
   70    EXPECT_EQ(weightfile.
getElement<
unsigned int>(
"number_of_importance_vars"), 3);
 
   71    EXPECT_EQ(weightfile.
getElement<std::string>(
"importance_key0"), 
"a");
 
   72    EXPECT_EQ(weightfile.
getElement<
float>(
"importance_value0"), 1.0);
 
   76    EXPECT_EQ(importance2.size(), 3);
 
   77    EXPECT_EQ(importance2[
"a"], 1.0);
 
   78    EXPECT_EQ(importance2[
"b"], 2.0);
 
   79    EXPECT_EQ(importance2[
"c"], 3.0);
 
   83  TEST(WeightfileTest, SignalFraction)
 
   97    EXPECT_EQ(weightfile.
getElement<
int>(
"Test"), 1);
 
  101  TEST(WeightfileTest, 
Stream)
 
  105    std::stringstream sstream(
"MyStream");
 
  107    EXPECT_EQ(weightfile.
getStream(
"Test"), 
"MyStream");
 
  111  TEST(WeightfileTest, File)
 
  115    std::ofstream ofile(
"file.txt");
 
  120    weightfile.
addFile(
"Test", 
"file.txt");
 
  122    weightfile.
getFile(
"Test", 
"file2.txt");
 
  124    std::ifstream ifile(
"file2.txt");
 
  128    EXPECT_EQ(content, 
"MyFile");
 
  132  TEST(WeightfileTest, StaticSaveLoadDatabase)
 
  138    conf.overrideGlobalTags();
 
  139    conf.prependTestingPayloadLocation(
"localdb/database.txt");
 
  146    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  151    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  155    std::filesystem::remove_all(
"testPayloads");
 
  160  TEST(WeightfileTest, StaticSaveLoadXML)
 
  170    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  173    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  176    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  182  TEST(WeightfileTest, StaticSaveLoadROOT)
 
  192    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  195    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  198    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  203      std::fstream file(
"INVALID.root");
 
  209  TEST(WeightfileTest, StaticDatabase)
 
  214    conf.overrideGlobalTags();
 
  215    conf.prependTestingPayloadLocation(
"localdb/database.txt");
 
  224    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  226    std::filesystem::remove_all(
"testPayloads");
 
  231  TEST(WeightfileTest, StaticDatabaseBadSymbols)
 
  236    conf.overrideGlobalTags();
 
  237    conf.prependTestingPayloadLocation(
"localdb/database.txt");
 
  242    std::string evilIdentifier = 
"==> *+:";
 
  247    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  249    std::filesystem::remove_all(
"testPayloads");
 
  254  TEST(WeightfileTest, StaticXMLFile)
 
  265    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  269  TEST(WeightfileTest, StaticROOTFile)
 
  280    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  282    TFile file(
"invalid_weightfile.root", 
"RECREATE");
 
  288  TEST(WeightfileTest, StaticStream)
 
  295    std::ofstream ofile(
"file.txt");
 
  299    std::ifstream ifile(
"file.txt");
 
  301    EXPECT_EQ(loaded.getElement<std::string>(
"Test"), 
"a");
 
  305  TEST(WeightfileTest, GetFileName)
 
  310    unsigned int length = filename.size();
 
  311    EXPECT_TRUE(filename.substr(length - 4, length) == 
".xml");
 
  318        std::ofstream a(filename);
 
  320      EXPECT_TRUE(std::filesystem::exists(filename));
 
  322    EXPECT_FALSE(std::filesystem::exists(filename));
 
  329        std::ofstream a(filename);
 
  331      EXPECT_TRUE(std::filesystem::exists(filename));
 
  333    EXPECT_TRUE(std::filesystem::exists(filename));
 
  334    std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
 
  335    EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
 
  337    char* directory_template = strdup((std::filesystem::temp_directory_path() / 
"Basf2Sub.XXXXXX").c_str());
 
  338    auto tempdir = std::string(mkdtemp(directory_template));
 
  339    setenv(
"TMPDIR", tempdir.c_str(), 1);
 
  343      EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
 
  345    free(directory_template);
 
  346    std::filesystem::remove_all(tempdir);
 
  347    EXPECT_FALSE(std::filesystem::exists(tempdir));
 
  348    setenv(
"TMPDIR", 
"/tmp", 1);
 
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
The Weightfile class serializes all information about a training into an xml tree.
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
float getSignalFraction() const
Loads the signal fraction from the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Changes working directory into a newly created directory, and removes it (and contents) on destruction.
static void reset(bool keepConfig=false)
Reset the database instance.
Define (de)serialization methods for TObject.
Abstract base class for different kinds of events.