Belle II Software  release-08-01-10
test_Weightfile.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <mva/interface/Weightfile.h>
#include <mva/interface/Options.h>
#include <framework/utilities/TestHelpers.h>

#include <framework/database/Configuration.h>
#include <framework/database/Database.h>

#include <TFile.h>

#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <map>
#include <sstream>
#include <string>

#include <gtest/gtest.h>
22 
23 using namespace Belle2;
24 
25 namespace {
26 
27  class TestOptions : public MVA::Options {
28 
29  public:
30  TestOptions(const std::string& _x, const std::string& _y) : x(_x), y(_y) { }
31  void load(const boost::property_tree::ptree& pt) override { y = pt.get<std::string>(x); }
32  void save(boost::property_tree::ptree& pt) const override { pt.put(x, y); }
33  po::options_description getDescription() override
34  {
35  po::options_description description("General options");
36  description.add_options()
37  ("help", "print this message");
38  return description;
39  }
40 
41  std::string x;
42  std::string y;
43  };
44 
45  TEST(WeightfileTest, Options)
46  {
47  TestOptions options1("Test1", "a");
48  TestOptions options2("Test2", "b");
49  MVA::Weightfile weightfile;
50  weightfile.addOptions(options1);
51  weightfile.addOptions(options2);
52 
53  EXPECT_EQ(weightfile.getElement<std::string>("Test1"), "a");
54  EXPECT_EQ(weightfile.getElement<std::string>("Test2"), "b");
55 
56  TestOptions options3("Test2", "c");
57  weightfile.getOptions(options3);
58  EXPECT_EQ(options3.y, "b");
59  }
60 
61  TEST(WeightfileTest, FeatureImportance)
62  {
63  std::map<std::string, float> importance;
64  importance["a"] = 1.0;
65  importance["b"] = 2.0;
66  importance["c"] = 3.0;
67  MVA::Weightfile weightfile;
68  weightfile.addFeatureImportance(importance);
69 
70  EXPECT_EQ(weightfile.getElement<unsigned int>("number_of_importance_vars"), 3);
71  EXPECT_EQ(weightfile.getElement<std::string>("importance_key0"), "a");
72  EXPECT_EQ(weightfile.getElement<float>("importance_value0"), 1.0);
73 
74  auto importance2 = weightfile.getFeatureImportance();
75 
76  EXPECT_EQ(importance2.size(), 3);
77  EXPECT_EQ(importance2["a"], 1.0);
78  EXPECT_EQ(importance2["b"], 2.0);
79  EXPECT_EQ(importance2["c"], 3.0);
80 
81  }
82 
83  TEST(WeightfileTest, SignalFraction)
84  {
85 
86  MVA::Weightfile weightfile;
87  weightfile.addSignalFraction(0.7);
88  EXPECT_FLOAT_EQ(weightfile.getSignalFraction(), 0.7);
89 
90  }
91 
92  TEST(WeightfileTest, Element)
93  {
94 
95  MVA::Weightfile weightfile;
96  weightfile.addElement("Test", 1);
97  EXPECT_EQ(weightfile.getElement<int>("Test"), 1);
98 
99  }
100 
101  TEST(WeightfileTest, Stream)
102  {
103 
104  MVA::Weightfile weightfile;
105  std::stringstream sstream("MyStream");
106  weightfile.addStream("Test", sstream);
107  EXPECT_EQ(weightfile.getStream("Test"), "MyStream");
108 
109  }
110 
111  TEST(WeightfileTest, File)
112  {
113 
115  std::ofstream ofile("file.txt");
116  ofile << "MyFile";
117  ofile.close();
118 
119  MVA::Weightfile weightfile;
120  weightfile.addFile("Test", "file.txt");
121 
122  weightfile.getFile("Test", "file2.txt");
123  std::string content;
124  std::ifstream ifile("file2.txt");
125  ifile >> content;
126  ifile.close();
127 
128  EXPECT_EQ(content, "MyFile");
129 
130  }
131 
132  TEST(WeightfileTest, StaticSaveLoadDatabase)
133  {
134 
136 
138  conf.overrideGlobalTags();
139  conf.prependTestingPayloadLocation("localdb/database.txt");
140 
141  MVA::Weightfile weightfile;
142  weightfile.addElement("Test", "a");
143 
144  MVA::Weightfile::save(weightfile, "MVAInterfaceTest");
145  auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
146  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
147 
148  EXPECT_THROW(MVA::Weightfile::loadFromFile("MVAInterfaceTest"), std::runtime_error);
149 
150  loaded = MVA::Weightfile::load("MVAInterfaceTest");
151  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
152 
153  EXPECT_THROW(MVA::Weightfile::loadFromDatabase("DOES_NOT_EXIST"), std::runtime_error);
154 
155  std::filesystem::remove_all("testPayloads");
156  Database::reset();
157 
158  }
159 
160  TEST(WeightfileTest, StaticSaveLoadXML)
161  {
162 
164 
165  MVA::Weightfile weightfile;
166  weightfile.addElement("Test", "a");
167 
168  MVA::Weightfile::save(weightfile, "MVAInterfaceTest.xml");
169  auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
170  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
171 
172  loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.xml");
173  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
174 
175  loaded = MVA::Weightfile::load("MVAInterfaceTest.xml");
176  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
177 
178  EXPECT_THROW(MVA::Weightfile::loadFromXMLFile("DOES_NOT_EXIST.xml"), std::runtime_error);
179 
180  }
181 
182  TEST(WeightfileTest, StaticSaveLoadROOT)
183  {
184 
186 
187  MVA::Weightfile weightfile;
188  weightfile.addElement("Test", "a");
189 
190  MVA::Weightfile::save(weightfile, "MVAInterfaceTest.root");
191  auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
192  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
193 
194  loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.root");
195  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
196 
197  loaded = MVA::Weightfile::load("MVAInterfaceTest.root");
198  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
199 
200  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("DOES_NOT_EXIST.root"), std::runtime_error);
201 
202  {
203  std::fstream file("INVALID.root");
204  }
205  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("INVALID.root"), std::runtime_error);
206 
207  }
208 
209  TEST(WeightfileTest, StaticDatabase)
210  {
211 
214  conf.overrideGlobalTags();
215  conf.prependTestingPayloadLocation("localdb/database.txt");
216 
217  MVA::Weightfile weightfile;
218  weightfile.addElement("Test", "a");
219 
220  MVA::Weightfile::saveToDatabase(weightfile, "MVAInterfaceTest");
221 
222  auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
223 
224  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
225 
226  std::filesystem::remove_all("testPayloads");
227  Database::reset();
228 
229  }
230 
231  TEST(WeightfileTest, StaticDatabaseBadSymbols)
232  {
233 
236  conf.overrideGlobalTags();
237  conf.prependTestingPayloadLocation("localdb/database.txt");
238 
239  MVA::Weightfile weightfile;
240  weightfile.addElement("Test", "a");
241 
242  std::string evilIdentifier = "==> *+:";
243  MVA::Weightfile::saveToDatabase(weightfile, evilIdentifier);
244 
245  auto loaded = MVA::Weightfile::loadFromDatabase(evilIdentifier);
246 
247  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
248 
249  std::filesystem::remove_all("testPayloads");
250  Database::reset();
251 
252  }
253 
254  TEST(WeightfileTest, StaticXMLFile)
255  {
256 
258  MVA::Weightfile weightfile;
259  weightfile.addElement("Test", "a");
260 
261  MVA::Weightfile::saveToXMLFile(weightfile, "MVAInterfaceTest.xml");
262 
263  auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
264 
265  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
266 
267  }
268 
269  TEST(WeightfileTest, StaticROOTFile)
270  {
271 
273  MVA::Weightfile weightfile;
274  weightfile.addElement("Test", "a");
275 
276  MVA::Weightfile::saveToROOTFile(weightfile, "MVAInterfaceTest.root");
277 
278  auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
279 
280  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
281 
282  TFile file("invalid_weightfile.root", "RECREATE");
283  file.Close();
284  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("invalid_weightfile.root"), std::runtime_error);
285 
286  }
287 
288  TEST(WeightfileTest, StaticStream)
289  {
290 
292  MVA::Weightfile weightfile;
293  weightfile.addElement("Test", "a");
294 
295  std::ofstream ofile("file.txt");
296  MVA::Weightfile::saveToStream(weightfile, ofile);
297  ofile.close();
298 
299  std::ifstream ifile("file.txt");
300  auto loaded = MVA::Weightfile::loadFromStream(ifile);
301  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
302 
303  }
304 
305  TEST(WeightfileTest, GetFileName)
306  {
307 
308  MVA::Weightfile weightfile;
309  std::string filename = weightfile.generateFileName(".xml");
310  unsigned int length = filename.size();
311  EXPECT_TRUE(filename.substr(length - 4, length) == ".xml");
312 
313  {
314  MVA::Weightfile weightfile2;
315  weightfile2.setRemoveTemporaryDirectories(true);
316  filename = weightfile2.generateFileName(".xml");
317  {
318  std::ofstream a(filename);
319  }
320  EXPECT_TRUE(std::filesystem::exists(filename));
321  }
322  EXPECT_FALSE(std::filesystem::exists(filename));
323 
324  {
325  MVA::Weightfile weightfile2;
326  weightfile2.setRemoveTemporaryDirectories(false);
327  filename = weightfile2.generateFileName(".xml");
328  {
329  std::ofstream a(filename);
330  }
331  EXPECT_TRUE(std::filesystem::exists(filename));
332  }
333  EXPECT_TRUE(std::filesystem::exists(filename));
334  std::filesystem::remove_all(std::filesystem::path(filename).parent_path());
335  EXPECT_FALSE(std::filesystem::exists(std::filesystem::path(filename).parent_path()));
336 
337  char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2Sub.XXXXXX").c_str());
338  auto tempdir = std::string(mkdtemp(directory_template));
339  setenv("TMPDIR", tempdir.c_str(), 1);
340  {
341  MVA::Weightfile weightfile2;
342  filename = weightfile2.generateFileName(".xml");
343  EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
344  }
345  free(directory_template);
346  std::filesystem::remove_all(tempdir);
347  EXPECT_FALSE(std::filesystem::exists(tempdir));
348  setenv("TMPDIR", "/tmp", 1);
349  }
350 
351 }
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:34
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:123
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:83
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:240
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:154
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
Definition: Weightfile.h:282
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
Definition: Weightfile.cc:175
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:281
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:165
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:144
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:258
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Changes the working directory into a newly created directory and removes it (and its contents) on destruction.
Definition: TestHelpers.h:66
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
static void reset(bool keepConfig=false)
Reset the database instance.
Definition: Database.cc:50
Abstract base class for different kinds of events.