9 #include <mva/interface/Weightfile.h>
10 #include <mva/interface/Options.h>
11 #include <framework/utilities/TestHelpers.h>
13 #include <framework/database/Configuration.h>
14 #include <framework/database/Database.h>
21 #include <gtest/gtest.h>
30 TestOptions(
const std::string& _x,
const std::string& _y) : x(_x), y(_y) { }
31 void load(
const boost::property_tree::ptree& pt)
override { y = pt.get<std::string>(x); }
32 void save(boost::property_tree::ptree& pt)
const override { pt.put(x, y); }
33 po::options_description getDescription()
override
35 po::options_description description(
"General options");
36 description.add_options()
37 (
"help",
"print this message");
45 TEST(WeightfileTest, Options)
47 TestOptions options1(
"Test1",
"a");
48 TestOptions options2(
"Test2",
"b");
53 EXPECT_EQ(weightfile.
getElement<std::string>(
"Test1"),
"a");
54 EXPECT_EQ(weightfile.
getElement<std::string>(
"Test2"),
"b");
56 TestOptions options3(
"Test2",
"c");
58 EXPECT_EQ(options3.y,
"b");
61 TEST(WeightfileTest, FeatureImportance)
63 std::map<std::string, float> importance;
64 importance[
"a"] = 1.0;
65 importance[
"b"] = 2.0;
66 importance[
"c"] = 3.0;
70 EXPECT_EQ(weightfile.
getElement<
unsigned int>(
"number_of_importance_vars"), 3);
71 EXPECT_EQ(weightfile.
getElement<std::string>(
"importance_key0"),
"a");
72 EXPECT_EQ(weightfile.
getElement<
float>(
"importance_value0"), 1.0);
76 EXPECT_EQ(importance2.size(), 3);
77 EXPECT_EQ(importance2[
"a"], 1.0);
78 EXPECT_EQ(importance2[
"b"], 2.0);
79 EXPECT_EQ(importance2[
"c"], 3.0);
83 TEST(WeightfileTest, SignalFraction)
92 TEST(WeightfileTest, Element)
97 EXPECT_EQ(weightfile.
getElement<
int>(
"Test"), 1);
101 TEST(WeightfileTest, Stream)
105 std::stringstream sstream(
"MyStream");
107 EXPECT_EQ(weightfile.
getStream(
"Test"),
"MyStream");
115 std::ofstream ofile(
"file.txt");
120 weightfile.
addFile(
"Test",
"file.txt");
122 weightfile.
getFile(
"Test",
"file2.txt");
124 std::ifstream ifile(
"file2.txt");
128 EXPECT_EQ(content,
"MyFile");
132 TEST(WeightfileTest, StaticSaveLoadDatabase)
138 conf.overrideGlobalTags();
139 conf.prependTestingPayloadLocation(
"localdb/database.txt");
146 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
151 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
155 std::filesystem::remove_all(
"testPayloads");
160 TEST(WeightfileTest, StaticSaveLoadXML)
170 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
173 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
176 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
182 TEST(WeightfileTest, StaticSaveLoadROOT)
192 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
195 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
198 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
203 std::fstream file(
"INVALID.root");
209 TEST(WeightfileTest, StaticDatabase)
214 conf.overrideGlobalTags();
215 conf.prependTestingPayloadLocation(
"localdb/database.txt");
224 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
226 std::filesystem::remove_all(
"testPayloads");
231 TEST(WeightfileTest, StaticDatabaseBadSymbols)
236 conf.overrideGlobalTags();
237 conf.prependTestingPayloadLocation(
"localdb/database.txt");
242 std::string evilIdentifier =
"==> *+:";
247 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
249 std::filesystem::remove_all(
"testPayloads");
254 TEST(WeightfileTest, StaticXMLFile)
265 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
269 TEST(WeightfileTest, StaticROOTFile)
280 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
282 TFile file(
"invalid_weightfile.root",
"RECREATE");
288 TEST(WeightfileTest, StaticStream)
295 std::ofstream ofile(
"file.txt");
299 std::ifstream ifile(
"file.txt");
301 EXPECT_EQ(loaded.getElement<std::string>(
"Test"),
"a");
305 TEST(WeightfileTest, GetFileName)
310 unsigned int length = filename.size();
311 EXPECT_TRUE(filename.substr(length - 4, length) ==
".xml");
318 std::ofstream a(filename);
320 EXPECT_TRUE(std::filesystem::exists(filename));
322 EXPECT_FALSE(std::filesystem::exists(filename));
329 std::ofstream a(filename);
331 EXPECT_TRUE(std::filesystem::exists(filename));
333 EXPECT_TRUE(std::filesystem::exists(filename));
334 std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
335 EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
337 char* directory_template = strdup((std::filesystem::temp_directory_path() /
"Basf2Sub.XXXXXX").c_str());
338 auto tempdir = std::string(mkdtemp(directory_template));
339 setenv(
"TMPDIR", tempdir.c_str(), 1);
343 EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
345 free(directory_template);
346 std::filesystem::remove_all(tempdir);
347 EXPECT_FALSE(std::filesystem::exists(tempdir));
348 setenv(
"TMPDIR",
"/tmp", 1);
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
The Weightfile class serializes all information about a training into an xml tree.
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories For debugging it can be...
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
float getSignalFraction() const
Loads the signal fraction frm the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
changes working directory into a newly created directory, and removes it (and contents) on destructio...
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
static void reset(bool keepConfig=false)
Reset the database instance.
Abstract base class for different kinds of events.