9#include <mva/interface/Weightfile.h>
10#include <mva/interface/Options.h>
11#include <framework/utilities/TestHelpers.h>
13#include <framework/database/Configuration.h>
14#include <framework/database/Database.h>
21#include <gtest/gtest.h>
// Minimal options type used by the tests in this file: `x` is the
// property-tree key and `y` is the string value stored under it.
// NOTE(review): the class header, member declarations and closing braces were
// not captured in this listing; the `override` specifiers imply a virtual base,
// presumably MVA::Options — confirm against the full file.
30 TestOptions(
const std::string& _x,
const std::string& _y) : x(_x), y(_y) { }
// Reads the string stored under key `x` into `y`. boost's ptree::get<T>(path)
// without a default value throws if the key is absent.
31 void load(
const boost::property_tree::ptree& pt)
override { y = pt.get<std::string>(x); }
// Writes `y` into the tree under key `x`.
32 void save(boost::property_tree::ptree& pt)
const override { pt.put(x, y); }
// Builds a program-options description titled "General options" with a
// "help" entry. The remainder of this function (including its return
// statement) is not visible in this listing.
33 po::options_description getDescription()
override
35 po::options_description description(
"General options");
36 description.add_options()
37 (
"help",
"print this message");
// Round-trips TestOptions through a Weightfile: addOptions() must store each
// option's value under its key, and getOptions() must overwrite an option's
// value with the stored one.
// NOTE(review): the declaration of `weightfile` is not visible in this listing
// (presumably `MVA::Weightfile weightfile;` on a dropped line).
45 TEST(WeightfileTest, Options)
47 TestOptions options1(
"Test1",
"a");
48 TestOptions options2(
"Test2",
"b");
50 weightfile.addOptions(options1);
51 weightfile.addOptions(options2);
// Each option is retrievable as a plain string element under its key.
53 EXPECT_EQ(weightfile.getElement<std::string>(
"Test1"),
"a");
54 EXPECT_EQ(weightfile.getElement<std::string>(
"Test2"),
"b");
// options3 starts out holding "c"; loading from the weightfile must replace it
// with the value previously stored under "Test2".
56 TestOptions options3(
"Test2",
"c");
57 weightfile.getOptions(options3);
58 EXPECT_EQ(options3.y,
"b");
// Stores a feature-importance map and checks both the flattened element
// representation and the round trip through getFeatureImportance().
// The values 1.0, 2.0 and 3.0 are exactly representable as float, so the exact
// EXPECT_EQ comparisons below are safe.
61 TEST(WeightfileTest, FeatureImportance)
63 std::map<std::string, float> importance;
64 importance[
"a"] = 1.0;
65 importance[
"b"] = 2.0;
66 importance[
"c"] = 3.0;
68 weightfile.addFeatureImportance(importance);
// The map is flattened into a count plus indexed key/value elements. Index 0
// holding "a" is consistent with std::map's sorted key order — presumably the
// implementation iterates the map in order; confirm.
70 EXPECT_EQ(weightfile.getElement<
unsigned int>(
"number_of_importance_vars"), 3);
71 EXPECT_EQ(weightfile.getElement<std::string>(
"importance_key0"),
"a");
72 EXPECT_EQ(weightfile.getElement<
float>(
"importance_value0"), 1.0);
74 auto importance2 = weightfile.getFeatureImportance();
// NOTE(review): size() is unsigned while 3 is an int, which can trigger
// -Wsign-compare inside EXPECT_EQ; `3u` would avoid it. If importance2 is a
// std::map, operator[] below inserts a default value for a missing key, so this
// size check must stay before the lookups.
76 EXPECT_EQ(importance2.size(), 3);
77 EXPECT_EQ(importance2[
"a"], 1.0);
78 EXPECT_EQ(importance2[
"b"], 2.0);
79 EXPECT_EQ(importance2[
"c"], 3.0);
// Checks that a signal fraction stored with addSignalFraction() is returned by
// getSignalFraction(); EXPECT_FLOAT_EQ tolerates differences of up to 4 ULPs.
83 TEST(WeightfileTest, SignalFraction)
87 weightfile.addSignalFraction(0.7);
88 EXPECT_FLOAT_EQ(weightfile.getSignalFraction(), 0.7);
// NOTE(review): the following lines appear to belong to a separate test whose
// TEST(...) header was not captured in this listing. They store an int element
// under "Test" and read it back.
96 weightfile.addElement(
"Test", 1);
97 EXPECT_EQ(weightfile.getElement<
int>(
"Test"), 1);
// Stores the contents of a stringstream under "Test" and checks that
// getStream() returns the same text.
101 TEST(WeightfileTest,
Stream)
105 std::stringstream sstream(
"MyStream");
106 weightfile.addStream(
"Test", sstream);
107 EXPECT_EQ(weightfile.getStream(
"Test"),
"MyStream");
// Embeds an on-disk file into the weightfile under "Test", extracts it to a
// second path, and checks the extracted content.
// NOTE(review): the lines that presumably write "MyFile" into file.txt, close
// the output stream, and read `content` from ifile are not visible here. Both
// files are created relative to the current working directory; presumably a
// temporary-directory guard on a dropped line cleans them up — confirm.
111 TEST(WeightfileTest, File)
115 std::ofstream ofile(
"file.txt");
120 weightfile.addFile(
"Test",
"file.txt");
122 weightfile.getFile(
"Test",
"file2.txt");
124 std::ifstream ifile(
"file2.txt");
128 EXPECT_EQ(content,
"MyFile");
// Saves a weightfile with a "Test" element through the database path and loads
// it back twice, checking the element each time.
// The configuration overrides the global tags and prepends
// "localdb/database.txt" as a testing payload location.
// NOTE(review): the `conf` declaration (presumably
// Configuration::getInstance()) and the save/load calls are on lines not
// captured in this listing.
132 TEST(WeightfileTest, StaticSaveLoadDatabase)
138 conf.overrideGlobalTags();
139 conf.prependTestingPayloadLocation(
"localdb/database.txt");
142 weightfile.addElement(
"Test",
"a");
146 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
151 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Remove payloads written by this test.
// NOTE(review): this deletes "testPayloads" while the configured testing
// location is under "localdb/"; confirm which directory the save step actually
// writes to.
155 std::filesystem::remove_all(
"testPayloads");
// Saves a weightfile with a "Test" element and loads it back three times,
// checking the element each time.
// NOTE(review): the save/load calls are on lines not captured here; presumably
// they exercise the generic load(), loadFromXMLFile() and loadFromFile().
160 TEST(WeightfileTest, StaticSaveLoadXML)
166 weightfile.addElement(
"Test",
"a");
170 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
173 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
176 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// ROOT-file counterpart of StaticSaveLoadXML: saves a weightfile with a "Test"
// element and loads it back three times, checking the element each time (the
// save/load calls are on lines not captured here).
182 TEST(WeightfileTest, StaticSaveLoadROOT)
188 weightfile.addElement(
"Test",
"a");
192 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
195 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
198 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// NOTE(review): std::fstream's default open mode is in|out, which does not
// create a missing file. If "INVALID.root" does not already exist, the stream
// fails to open and nothing is written, so a subsequent failure check would pass
// because the file is absent rather than malformed. Consider std::ofstream (or
// adding std::ios::trunc); confirm against the following, uncaptured lines.
203 std::fstream file(
"INVALID.root");
// Stores a weightfile with a "Test" element in the testing payload database and
// loads it back, checking the element.
// The configuration overrides the global tags and prepends
// "localdb/database.txt" as a testing payload location. The `conf` declaration
// and the save/load calls are on lines not captured in this listing.
209 TEST(WeightfileTest, StaticDatabase)
214 conf.overrideGlobalTags();
215 conf.prependTestingPayloadLocation(
"localdb/database.txt");
218 weightfile.addElement(
"Test",
"a");
224 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Remove payloads written by this test.
// NOTE(review): confirm that "testPayloads" (not "localdb/") is where the save
// step writes.
226 std::filesystem::remove_all(
"testPayloads");
// Variant of StaticDatabase using an identifier full of punctuation and
// whitespace ("==> *+:"), presumably to check that such identifiers survive a
// database save/load round trip.
// NOTE(review): the save and load calls that use evilIdentifier are on lines
// not captured in this listing.
231 TEST(WeightfileTest, StaticDatabaseBadSymbols)
236 conf.overrideGlobalTags();
237 conf.prependTestingPayloadLocation(
"localdb/database.txt");
240 weightfile.addElement(
"Test",
"a");
242 std::string evilIdentifier =
"==> *+:";
247 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Remove payloads written by this test.
249 std::filesystem::remove_all(
"testPayloads");
// Checks that a weightfile with a "Test" element survives an XML-file round
// trip.
// NOTE(review): the save/load calls are on lines not captured here.
254 TEST(WeightfileTest, StaticXMLFile)
259 weightfile.addElement(
"Test",
"a");
265 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Checks that a weightfile with a "Test" element survives a ROOT-file round
// trip (the save/load calls are on lines not captured here).
269 TEST(WeightfileTest, StaticROOTFile)
274 weightfile.addElement(
"Test",
"a");
280 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Creates (with RECREATE, overwriting any existing file) a ROOT file named
// "invalid_weightfile.root", presumably without weightfile content, to check
// that loading it fails; the following lines are not captured here.
282 TFile file(
"invalid_weightfile.root",
"RECREATE");
// Round-trips a weightfile with a "Test" element through a file stream: it is
// written through an ofstream on "file.txt" and read back through an ifstream.
// NOTE(review): the stream save/load calls and the closing of ofile are on
// lines not captured here; the output stream must be flushed or closed before
// the file is reopened for reading.
288 TEST(WeightfileTest, StaticStream)
293 weightfile.addElement(
"Test",
"a");
295 std::ofstream ofile(
"file.txt");
299 std::ifstream ifile(
"file.txt");
301 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
// Tests temporary file-name generation and temporary-directory handling.
// NOTE(review): this listing drops several lines here (possibly including
// further TEST(...) headers and the scopes in which weightfile objects are
// destroyed), so the groups below may belong to separate tests.
305 TEST(WeightfileTest, GetFileName)
309 std::string filename = weightfile.generateFileName(
".xml");
// The generated name must end with the requested suffix.
// NOTE(review): if filename were shorter than 4 characters, `length - 4` would
// wrap around (unsigned) and substr() would throw std::out_of_range.
310 unsigned int length = filename.size();
311 EXPECT_TRUE(filename.substr(length - 4, length) ==
".xml");
// A file created at a generated path first exists and is later expected to be
// gone, presumably because the owning weightfile removes its temporary
// directory on destruction (the scope lines are not captured).
318 std::ofstream a(filename);
320 EXPECT_TRUE(std::filesystem::exists(filename));
322 EXPECT_FALSE(std::filesystem::exists(filename));
// Here the file is still expected to exist afterwards (presumably after
// setRemoveTemporaryDirectories(false)), so the test removes the parent
// directory itself.
329 std::ofstream a(filename);
331 EXPECT_TRUE(std::filesystem::exists(filename));
333 EXPECT_TRUE(std::filesystem::exists(filename));
334 std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
335 EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
// Points TMPDIR at a freshly created directory and checks that generated names
// are placed under it.
// NOTE(review): mkdtemp() returns nullptr on failure, and constructing a
// std::string from nullptr is undefined behaviour; check the result first.
337 char* directory_template = strdup((std::filesystem::temp_directory_path() /
"Basf2Sub.XXXXXX").c_str());
338 auto tempdir = std::string(mkdtemp(directory_template));
339 setenv(
"TMPDIR", tempdir.c_str(), 1);
343 EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
345 free(directory_template);
346 std::filesystem::remove_all(tempdir);
347 EXPECT_FALSE(std::filesystem::exists(tempdir));
// NOTE(review): this sets TMPDIR to "/tmp" rather than restoring the value it
// had before the test (which may have been unset or different), so later tests
// in the same process see a modified environment.
348 setenv(
"TMPDIR",
"/tmp", 1);
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
static void reset(bool keepConfig=false)
Reset the database instance.
Define (de)serialization methods for TObject.
Abstract base class for different kinds of events.