9#include <mva/interface/Weightfile.h>
10#include <mva/interface/Options.h>
11#include <framework/utilities/TestHelpers.h>
13#include <framework/database/Configuration.h>
14#include <framework/database/Database.h>
21#include <gtest/gtest.h>
30 TestOptions(
const std::string& _x,
const std::string& _y) : x(_x), y(_y) { }
31 void load(
const boost::property_tree::ptree& pt)
override { y = pt.get<std::string>(x); }
32 void save(boost::property_tree::ptree& pt)
const override { pt.put(x, y); }
35 po::options_description description(
"General options");
36 description.add_options()
37 (
"help",
"print this message");
45 TEST(WeightfileTest, Options)
47 TestOptions options1(
"Test1",
"a");
48 TestOptions options2(
"Test2",
"b");
53 EXPECT_EQ(weightfile.
getElement<std::string>(
"Test1"),
"a");
54 EXPECT_EQ(weightfile.
getElement<std::string>(
"Test2"),
"b");
56 TestOptions options3(
"Test2",
"c");
58 EXPECT_EQ(options3.y,
"b");
61 TEST(WeightfileTest, FeatureImportance)
63 std::map<std::string, float> importance;
64 importance[
"a"] = 1.0;
65 importance[
"b"] = 2.0;
66 importance[
"c"] = 3.0;
70 EXPECT_EQ(weightfile.
getElement<
unsigned int>(
"number_of_importance_vars"), 3);
71 EXPECT_EQ(weightfile.
getElement<std::string>(
"importance_key0"),
"a");
72 EXPECT_EQ(weightfile.
getElement<
float>(
"importance_value0"), 1.0);
76 EXPECT_EQ(importance2.size(), 3);
77 EXPECT_EQ(importance2[
"a"], 1.0);
78 EXPECT_EQ(importance2[
"b"], 2.0);
79 EXPECT_EQ(importance2[
"c"], 3.0);
83 TEST(WeightfileTest, SignalFraction)
92 TEST(WeightfileTest, Element)
97 EXPECT_EQ(weightfile.
getElement<
int>(
"Test"), 1);
101 TEST(WeightfileTest, Stream)
105 std::stringstream sstream(
"MyStream");
107 EXPECT_EQ(weightfile.
getStream(
"Test"),
"MyStream");
111 TEST(WeightfileTest,
File)
115 std::ofstream ofile(
"file.txt");
120 weightfile.
addFile(
"Test",
"file.txt");
122 weightfile.
getFile(
"Test",
"file2.txt");
124 std::ifstream ifile(
"file2.txt");
128 EXPECT_EQ(content,
"MyFile");
132 TEST(WeightfileTest, StaticSaveLoadDatabase)
138 conf.overrideGlobalTags();
139 conf.prependTestingPayloadLocation(
"localdb/database.txt");
146 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
151 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
155 std::filesystem::remove_all(
"testPayloads");
160 TEST(WeightfileTest, StaticSaveLoadXML)
170 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
173 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
176 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
182 TEST(WeightfileTest, StaticSaveLoadROOT)
192 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
195 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
198 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
203 std::fstream file(
"INVALID.root");
209 TEST(WeightfileTest, StaticDatabase)
214 conf.overrideGlobalTags();
215 conf.prependTestingPayloadLocation(
"localdb/database.txt");
224 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
226 std::filesystem::remove_all(
"testPayloads");
231 TEST(WeightfileTest, StaticDatabaseBadSymbols)
236 conf.overrideGlobalTags();
237 conf.prependTestingPayloadLocation(
"localdb/database.txt");
242 std::string evilIdentifier =
"==> *+:";
247 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
249 std::filesystem::remove_all(
"testPayloads");
254 TEST(WeightfileTest, StaticXMLFile)
265 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
269 TEST(WeightfileTest, StaticROOTFile)
280 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
282 TFile file(
"invalid_weightfile.root",
"RECREATE");
288 TEST(WeightfileTest, StaticStream)
295 std::ofstream ofile(
"file.txt");
299 std::ifstream ifile(
"file.txt");
301 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
305 TEST(WeightfileTest, GetFileName)
310 unsigned int length = filename.size();
311 EXPECT_TRUE(filename.substr(length - 4, length) ==
".xml");
318 std::ofstream a(filename);
320 EXPECT_TRUE(std::filesystem::exists(filename));
322 EXPECT_FALSE(std::filesystem::exists(filename));
329 std::ofstream a(filename);
331 EXPECT_TRUE(std::filesystem::exists(filename));
333 EXPECT_TRUE(std::filesystem::exists(filename));
334 std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
335 EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
337 char* directory_template = strdup((std::filesystem::temp_directory_path() /
"Basf2Sub.XXXXXX").c_str());
338 auto tempdir = std::string(mkdtemp(directory_template));
339 setenv(
"TMPDIR", tempdir.c_str(), 1);
343 EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
345 free(directory_template);
346 std::filesystem::remove_all(tempdir);
347 EXPECT_FALSE(std::filesystem::exists(tempdir));
348 setenv(
"TMPDIR",
"/tmp", 1);
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
virtual po::options_description getDescription()=0
Returns a program options description for all available options.
virtual void load(const boost::property_tree::ptree &pt)=0
Load mechanism (used by Weightfile) to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const =0
Save mechanism (used by Weightfile) to store Options in a xml tree.
The Weightfile class serializes all information about a training into an xml tree.
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories For debugging it can be...
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
float getSignalFraction() const
Loads the signal fraction frm the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
changes working directory into a newly created directory, and removes it (and contents) on destructio...
static void reset(bool keepConfig=false)
Reset the database instance.
Abstract base class for different kinds of events.