Belle II Software development
test_Weightfile.cc
/**************************************************************************
 * basf2 (Belle II Analysis Software Framework)                           *
 * Author: The Belle II Collaboration                                     *
 *                                                                        *
 * See git log for contributors and copyright holders.                    *
 * This file is licensed under LGPL-3.0, see LICENSE.md.                  *
 **************************************************************************/
8
#include <mva/interface/Weightfile.h>
#include <mva/interface/Options.h>
#include <framework/utilities/TestHelpers.h>

#include <framework/database/Configuration.h>
#include <framework/database/Database.h>

#include <TFile.h>

#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <map>
#include <sstream>
#include <string>

#include <gtest/gtest.h>

using namespace Belle2;
24
25namespace {
26
27 class TestOptions : public MVA::Options {
28
29 public:
30 TestOptions(const std::string& _x, const std::string& _y) : x(_x), y(_y) { }
31 void load(const boost::property_tree::ptree& pt) override { y = pt.get<std::string>(x); }
32 void save(boost::property_tree::ptree& pt) const override { pt.put(x, y); }
33 po::options_description getDescription() override
34 {
35 po::options_description description("General options");
36 description.add_options()
37 ("help", "print this message");
38 return description;
39 }
40
41 std::string x;
42 std::string y;
43 };
44
45 TEST(WeightfileTest, Options)
46 {
47 TestOptions options1("Test1", "a");
48 TestOptions options2("Test2", "b");
49 MVA::Weightfile weightfile;
50 weightfile.addOptions(options1);
51 weightfile.addOptions(options2);
52
53 EXPECT_EQ(weightfile.getElement<std::string>("Test1"), "a");
54 EXPECT_EQ(weightfile.getElement<std::string>("Test2"), "b");
55
56 TestOptions options3("Test2", "c");
57 weightfile.getOptions(options3);
58 EXPECT_EQ(options3.y, "b");
59 }
60
61 TEST(WeightfileTest, FeatureImportance)
62 {
63 std::map<std::string, float> importance;
64 importance["a"] = 1.0;
65 importance["b"] = 2.0;
66 importance["c"] = 3.0;
67 MVA::Weightfile weightfile;
68 weightfile.addFeatureImportance(importance);
69
70 EXPECT_EQ(weightfile.getElement<unsigned int>("number_of_importance_vars"), 3);
71 EXPECT_EQ(weightfile.getElement<std::string>("importance_key0"), "a");
72 EXPECT_EQ(weightfile.getElement<float>("importance_value0"), 1.0);
73
74 auto importance2 = weightfile.getFeatureImportance();
75
76 EXPECT_EQ(importance2.size(), 3);
77 EXPECT_EQ(importance2["a"], 1.0);
78 EXPECT_EQ(importance2["b"], 2.0);
79 EXPECT_EQ(importance2["c"], 3.0);
80
81 }
82
83 TEST(WeightfileTest, SignalFraction)
84 {
85
86 MVA::Weightfile weightfile;
87 weightfile.addSignalFraction(0.7);
88 EXPECT_FLOAT_EQ(weightfile.getSignalFraction(), 0.7);
89
90 }
91
92 TEST(WeightfileTest, Element)
93 {
94
95 MVA::Weightfile weightfile;
96 weightfile.addElement("Test", 1);
97 EXPECT_EQ(weightfile.getElement<int>("Test"), 1);
98
99 }
100
101 TEST(WeightfileTest, Stream)
102 {
103
104 MVA::Weightfile weightfile;
105 std::stringstream sstream("MyStream");
106 weightfile.addStream("Test", sstream);
107 EXPECT_EQ(weightfile.getStream("Test"), "MyStream");
108
109 }
110
111 TEST(WeightfileTest, File)
112 {
113
115 std::ofstream ofile("file.txt");
116 ofile << "MyFile";
117 ofile.close();
118
119 MVA::Weightfile weightfile;
120 weightfile.addFile("Test", "file.txt");
121
122 weightfile.getFile("Test", "file2.txt");
123 std::string content;
124 std::ifstream ifile("file2.txt");
125 ifile >> content;
126 ifile.close();
127
128 EXPECT_EQ(content, "MyFile");
129
130 }
131
132 TEST(WeightfileTest, StaticSaveLoadDatabase)
133 {
134
136
138 conf.overrideGlobalTags();
139 conf.prependTestingPayloadLocation("localdb/database.txt");
140
141 MVA::Weightfile weightfile;
142 weightfile.addElement("Test", "a");
143
144 MVA::Weightfile::save(weightfile, "MVAInterfaceTest");
145 auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
146 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
147
148 EXPECT_THROW(MVA::Weightfile::loadFromFile("MVAInterfaceTest"), std::runtime_error);
149
150 loaded = MVA::Weightfile::load("MVAInterfaceTest");
151 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
152
153 EXPECT_THROW(MVA::Weightfile::loadFromDatabase("DOES_NOT_EXIST"), std::runtime_error);
154
155 std::filesystem::remove_all("testPayloads");
157
158 }
159
160 TEST(WeightfileTest, StaticSaveLoadXML)
161 {
162
164
165 MVA::Weightfile weightfile;
166 weightfile.addElement("Test", "a");
167
168 MVA::Weightfile::save(weightfile, "MVAInterfaceTest.xml");
169 auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
170 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
171
172 loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.xml");
173 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
174
175 loaded = MVA::Weightfile::load("MVAInterfaceTest.xml");
176 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
177
178 EXPECT_THROW(MVA::Weightfile::loadFromXMLFile("DOES_NOT_EXIST.xml"), std::runtime_error);
179
180 }
181
182 TEST(WeightfileTest, StaticSaveLoadROOT)
183 {
184
186
187 MVA::Weightfile weightfile;
188 weightfile.addElement("Test", "a");
189
190 MVA::Weightfile::save(weightfile, "MVAInterfaceTest.root");
191 auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
192 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
193
194 loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.root");
195 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
196
197 loaded = MVA::Weightfile::load("MVAInterfaceTest.root");
198 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
199
200 EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("DOES_NOT_EXIST.root"), std::runtime_error);
201
202 {
203 std::fstream file("INVALID.root");
204 }
205 EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("INVALID.root"), std::runtime_error);
206
207 }
208
209 TEST(WeightfileTest, StaticDatabase)
210 {
211
214 conf.overrideGlobalTags();
215 conf.prependTestingPayloadLocation("localdb/database.txt");
216
217 MVA::Weightfile weightfile;
218 weightfile.addElement("Test", "a");
219
220 MVA::Weightfile::saveToDatabase(weightfile, "MVAInterfaceTest");
221
222 auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
223
224 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
225
226 std::filesystem::remove_all("testPayloads");
228
229 }
230
231 TEST(WeightfileTest, StaticDatabaseBadSymbols)
232 {
233
236 conf.overrideGlobalTags();
237 conf.prependTestingPayloadLocation("localdb/database.txt");
238
239 MVA::Weightfile weightfile;
240 weightfile.addElement("Test", "a");
241
242 std::string evilIdentifier = "==> *+:";
243 MVA::Weightfile::saveToDatabase(weightfile, evilIdentifier);
244
245 auto loaded = MVA::Weightfile::loadFromDatabase(evilIdentifier);
246
247 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
248
249 std::filesystem::remove_all("testPayloads");
251
252 }
253
254 TEST(WeightfileTest, StaticXMLFile)
255 {
256
258 MVA::Weightfile weightfile;
259 weightfile.addElement("Test", "a");
260
261 MVA::Weightfile::saveToXMLFile(weightfile, "MVAInterfaceTest.xml");
262
263 auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
264
265 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
266
267 }
268
269 TEST(WeightfileTest, StaticROOTFile)
270 {
271
273 MVA::Weightfile weightfile;
274 weightfile.addElement("Test", "a");
275
276 MVA::Weightfile::saveToROOTFile(weightfile, "MVAInterfaceTest.root");
277
278 auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
279
280 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
281
282 TFile file("invalid_weightfile.root", "RECREATE");
283 file.Close();
284 EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("invalid_weightfile.root"), std::runtime_error);
285
286 }
287
288 TEST(WeightfileTest, StaticStream)
289 {
290
292 MVA::Weightfile weightfile;
293 weightfile.addElement("Test", "a");
294
295 std::ofstream ofile("file.txt");
296 MVA::Weightfile::saveToStream(weightfile, ofile);
297 ofile.close();
298
299 std::ifstream ifile("file.txt");
300 auto loaded = MVA::Weightfile::loadFromStream(ifile);
301 EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
302
303 }
304
305 TEST(WeightfileTest, GetFileName)
306 {
307
308 MVA::Weightfile weightfile;
309 std::string filename = weightfile.generateFileName(".xml");
310 unsigned int length = filename.size();
311 EXPECT_TRUE(filename.substr(length - 4, length) == ".xml");
312
313 {
314 MVA::Weightfile weightfile2;
315 weightfile2.setRemoveTemporaryDirectories(true);
316 filename = weightfile2.generateFileName(".xml");
317 {
318 std::ofstream a(filename);
319 }
320 EXPECT_TRUE(std::filesystem::exists(filename));
321 }
322 EXPECT_FALSE(std::filesystem::exists(filename));
323
324 {
325 MVA::Weightfile weightfile2;
326 weightfile2.setRemoveTemporaryDirectories(false);
327 filename = weightfile2.generateFileName(".xml");
328 {
329 std::ofstream a(filename);
330 }
331 EXPECT_TRUE(std::filesystem::exists(filename));
332 }
333 EXPECT_TRUE(std::filesystem::exists(filename));
334 std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
335 EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
336
337 char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2Sub.XXXXXX").c_str());
338 auto tempdir = std::string(mkdtemp(directory_template));
339 setenv("TMPDIR", tempdir.c_str(), 1);
340 {
341 MVA::Weightfile weightfile2;
342 filename = weightfile2.generateFileName(".xml");
343 EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
344 }
345 free(directory_template);
346 std::filesystem::remove_all(tempdir);
347 EXPECT_FALSE(std::filesystem::exists(tempdir));
348 setenv("TMPDIR", "/tmp", 1);
349 }
350
351}
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:34
virtual po::options_description getDescription()=0
Returns a program options description for all available options.
virtual void load(const boost::property_tree::ptree &pt)=0
Load mechanism (used by Weightfile) to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const =0
Save mechanism (used by Weightfile) to store Options in a xml tree.
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:123
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:83
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
Definition: Weightfile.cc:240
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:154
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
Definition: Weightfile.h:282
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Definition: Weightfile.cc:175
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:281
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:165
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:144
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:258
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
Definition: TestHelpers.h:66
static void reset(bool keepConfig=false)
Reset the database instance.
Definition: Database.cc:50
Abstract base class for different kinds of events.