Belle II Software  release-05-01-25
test_Weightfile.cc
1 /* BASF2 (Belle Analysis Framework 2) *
2  * Copyright(C) 2016 - Belle II Collaboration *
3  * *
4  * Author: The Belle II Collaboration *
5  * Contributors: Thomas Keck *
6  * *
7  * This software is provided "as is" without any warranty. *
8  **************************************************************************/
9 
10 #include <mva/interface/Weightfile.h>
11 #include <mva/interface/Options.h>
12 #include <framework/utilities/TestHelpers.h>
13 
14 #include <framework/database/Configuration.h>
15 #include <framework/database/Database.h>
16 #include <boost/filesystem/operations.hpp>
17 
18 #include <TFile.h>
19 
20 #include <fstream>
21 
22 #include <gtest/gtest.h>
23 
24 using namespace Belle2;
25 
26 namespace {
27 
28  class TestOptions : public MVA::Options {
29 
30  public:
31  TestOptions(const std::string& _x, const std::string& _y) : x(_x), y(_y) { }
32  void load(const boost::property_tree::ptree& pt) override { y = pt.get<std::string>(x); }
33  void save(boost::property_tree::ptree& pt) const override { pt.put(x, y); }
34  po::options_description getDescription() override
35  {
36  po::options_description description("General options");
37  description.add_options()
38  ("help", "print this message");
39  return description;
40  }
41 
42  std::string x;
43  std::string y;
44  };
45 
46  TEST(WeightfileTest, Options)
47  {
48  TestOptions options1("Test1", "a");
49  TestOptions options2("Test2", "b");
50  MVA::Weightfile weightfile;
51  weightfile.addOptions(options1);
52  weightfile.addOptions(options2);
53 
54  EXPECT_EQ(weightfile.getElement<std::string>("Test1"), "a");
55  EXPECT_EQ(weightfile.getElement<std::string>("Test2"), "b");
56 
57  TestOptions options3("Test2", "c");
58  weightfile.getOptions(options3);
59  EXPECT_EQ(options3.y, "b");
60  }
61 
62  TEST(WeightfileTest, FeatureImportance)
63  {
64  std::map<std::string, float> importance;
65  importance["a"] = 1.0;
66  importance["b"] = 2.0;
67  importance["c"] = 3.0;
68  MVA::Weightfile weightfile;
69  weightfile.addFeatureImportance(importance);
70 
71  EXPECT_EQ(weightfile.getElement<unsigned int>("number_of_importance_vars"), 3);
72  EXPECT_EQ(weightfile.getElement<std::string>("importance_key0"), "a");
73  EXPECT_EQ(weightfile.getElement<float>("importance_value0"), 1.0);
74 
75  auto importance2 = weightfile.getFeatureImportance();
76 
77  EXPECT_EQ(importance2.size(), 3);
78  EXPECT_EQ(importance2["a"], 1.0);
79  EXPECT_EQ(importance2["b"], 2.0);
80  EXPECT_EQ(importance2["c"], 3.0);
81 
82  }
83 
84  TEST(WeightfileTest, SignalFraction)
85  {
86 
87  MVA::Weightfile weightfile;
88  weightfile.addSignalFraction(0.7);
89  EXPECT_FLOAT_EQ(weightfile.getSignalFraction(), 0.7);
90 
91  }
92 
93  TEST(WeightfileTest, Element)
94  {
95 
96  MVA::Weightfile weightfile;
97  weightfile.addElement("Test", 1);
98  EXPECT_EQ(weightfile.getElement<int>("Test"), 1);
99 
100  }
101 
102  TEST(WeightfileTest, Stream)
103  {
104 
105  MVA::Weightfile weightfile;
106  std::stringstream sstream("MyStream");
107  weightfile.addStream("Test", sstream);
108  EXPECT_EQ(weightfile.getStream("Test"), "MyStream");
109 
110  }
111 
112  TEST(WeightfileTest, File)
113  {
114 
116  std::ofstream ofile("file.txt");
117  ofile << "MyFile";
118  ofile.close();
119 
120  MVA::Weightfile weightfile;
121  weightfile.addFile("Test", "file.txt");
122 
123  weightfile.getFile("Test", "file2.txt");
124  std::string content;
125  std::ifstream ifile("file2.txt");
126  ifile >> content;
127  ifile.close();
128 
129  EXPECT_EQ(content, "MyFile");
130 
131  }
132 
133  TEST(WeightfileTest, StaticSaveLoadDatabase)
134  {
135 
137 
139  conf.overrideGlobalTags();
140  conf.prependTestingPayloadLocation("localdb/database.txt");
141 
142  MVA::Weightfile weightfile;
143  weightfile.addElement("Test", "a");
144 
145  MVA::Weightfile::save(weightfile, "MVAInterfaceTest");
146  auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
147  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
148 
149  EXPECT_THROW(MVA::Weightfile::loadFromFile("MVAInterfaceTest"), std::runtime_error);
150 
151  loaded = MVA::Weightfile::load("MVAInterfaceTest");
152  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
153 
154  EXPECT_THROW(MVA::Weightfile::loadFromDatabase("DOES_NOT_EXIST"), std::runtime_error);
155 
156  boost::filesystem::remove_all("testPayloads");
157  Database::reset();
158 
159  }
160 
161  TEST(WeightfileTest, StaticSaveLoadXML)
162  {
163 
165 
166  MVA::Weightfile weightfile;
167  weightfile.addElement("Test", "a");
168 
169  MVA::Weightfile::save(weightfile, "MVAInterfaceTest.xml");
170  auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
171  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
172 
173  loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.xml");
174  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
175 
176  loaded = MVA::Weightfile::load("MVAInterfaceTest.xml");
177  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
178 
179  EXPECT_THROW(MVA::Weightfile::loadFromXMLFile("DOES_NOT_EXIST.xml"), std::runtime_error);
180 
181  }
182 
183  TEST(WeightfileTest, StaticSaveLoadROOT)
184  {
185 
187 
188  MVA::Weightfile weightfile;
189  weightfile.addElement("Test", "a");
190 
191  MVA::Weightfile::save(weightfile, "MVAInterfaceTest.root");
192  auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
193  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
194 
195  loaded = MVA::Weightfile::loadFromFile("MVAInterfaceTest.root");
196  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
197 
198  loaded = MVA::Weightfile::load("MVAInterfaceTest.root");
199  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
200 
201  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("DOES_NOT_EXIST.root"), std::runtime_error);
202 
203  {
204  // cppcheck-suppress unreadVariable
205  std::fstream file("INVALID.root");
206  }
207  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("INVALID.root"), std::runtime_error);
208 
209  }
210 
211  TEST(WeightfileTest, StaticDatabase)
212  {
213 
216  conf.overrideGlobalTags();
217  conf.prependTestingPayloadLocation("localdb/database.txt");
218 
219  MVA::Weightfile weightfile;
220  weightfile.addElement("Test", "a");
221 
222  MVA::Weightfile::saveToDatabase(weightfile, "MVAInterfaceTest");
223 
224  auto loaded = MVA::Weightfile::loadFromDatabase("MVAInterfaceTest");
225 
226  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
227 
228  boost::filesystem::remove_all("testPayloads");
229  Database::reset();
230 
231  }
232 
233  TEST(WeightfileTest, StaticDatabaseBadSymbols)
234  {
235 
238  conf.overrideGlobalTags();
239  conf.prependTestingPayloadLocation("localdb/database.txt");
240 
241  MVA::Weightfile weightfile;
242  weightfile.addElement("Test", "a");
243 
244  std::string evilIdentifier = "==> *+:";
245  MVA::Weightfile::saveToDatabase(weightfile, evilIdentifier);
246 
247  auto loaded = MVA::Weightfile::loadFromDatabase(evilIdentifier);
248 
249  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
250 
251  boost::filesystem::remove_all("testPayloads");
252  Database::reset();
253 
254  }
255 
256  TEST(WeightfileTest, StaticXMLFile)
257  {
258 
260  MVA::Weightfile weightfile;
261  weightfile.addElement("Test", "a");
262 
263  MVA::Weightfile::saveToXMLFile(weightfile, "MVAInterfaceTest.xml");
264 
265  auto loaded = MVA::Weightfile::loadFromXMLFile("MVAInterfaceTest.xml");
266 
267  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
268 
269  }
270 
271  TEST(WeightfileTest, StaticROOTFile)
272  {
273 
275  MVA::Weightfile weightfile;
276  weightfile.addElement("Test", "a");
277 
278  MVA::Weightfile::saveToROOTFile(weightfile, "MVAInterfaceTest.root");
279 
280  auto loaded = MVA::Weightfile::loadFromROOTFile("MVAInterfaceTest.root");
281 
282  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
283 
284  TFile file("invalid_weightfile.root", "RECREATE");
285  file.Close();
286  EXPECT_THROW(MVA::Weightfile::loadFromROOTFile("invalid_weightfile.root"), std::runtime_error);
287 
288  }
289 
290  TEST(WeightfileTest, StaticStream)
291  {
292 
294  MVA::Weightfile weightfile;
295  weightfile.addElement("Test", "a");
296 
297  std::ofstream ofile("file.txt");
298  MVA::Weightfile::saveToStream(weightfile, ofile);
299  ofile.close();
300 
301  std::ifstream ifile("file.txt");
302  auto loaded = MVA::Weightfile::loadFromStream(ifile);
303  EXPECT_EQ(loaded.getElement<std::string>("Test"), "a");
304 
305  }
306 
307  TEST(WeightfileTest, GetFileName)
308  {
309 
310  MVA::Weightfile weightfile;
311  std::string filename = weightfile.generateFileName(".xml");
312  unsigned int length = filename.size();
313  EXPECT_TRUE(filename.substr(length - 4, length) == ".xml");
314 
315  {
316  MVA::Weightfile weightfile2;
317  weightfile2.setRemoveTemporaryDirectories(true);
318  filename = weightfile2.generateFileName(".xml");
319  {
320  // cppcheck-suppress unreadVariable
321  std::ofstream a(filename);
322  }
323  EXPECT_TRUE(boost::filesystem::exists(filename));
324  }
325  EXPECT_FALSE(boost::filesystem::exists(filename));
326 
327  {
328  MVA::Weightfile weightfile2;
329  weightfile2.setRemoveTemporaryDirectories(false);
330  filename = weightfile2.generateFileName(".xml");
331  {
332  // cppcheck-suppress unreadVariable
333  std::ofstream a(filename);
334  }
335  EXPECT_TRUE(boost::filesystem::exists(filename));
336  }
337  EXPECT_TRUE(boost::filesystem::exists(filename));
338  boost::filesystem::remove_all(boost::filesystem::path(filename).parent_path());
339  EXPECT_FALSE(boost::filesystem::exists(boost::filesystem::path(filename).parent_path()));
340 
341  char* directory_template = strdup("/tmp/Basf2Sub.XXXXXX");
342  auto tempdir = std::string(mkdtemp(directory_template));
343  {
344  MVA::Weightfile weightfile2;
345  weightfile2.setTemporaryDirectory(tempdir);
346  filename = weightfile2.generateFileName(".xml");
347  EXPECT_EQ(filename.substr(0, tempdir.size()), tempdir);
348  }
349  free(directory_template);
350  boost::filesystem::remove_all(tempdir);
351  EXPECT_FALSE(boost::filesystem::exists(tempdir));
352 
353  }
354 
355 }
Belle2::MVA::Weightfile::addSignalFraction
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:104
Belle2::MVA::Weightfile::getOptions
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:76
Belle2::MVA::Weightfile::saveToDatabase
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:267
Belle2::MVA::Weightfile::setTemporaryDirectory
void setTemporaryDirectory(const std::string &temporary_directory)
Set the temporary directory which is used to store temporary directories.
Definition: Weightfile.h:286
Belle2::MVA::Weightfile::addFile
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:124
Belle2::MVA::Weightfile
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:40
Belle2::TestHelpers::TempDirCreator
Changes the working directory into a newly created directory, and removes it (and its contents) on destruction.
Definition: TestHelpers.h:57
Belle2::MVA::Weightfile::addFeatureImportance
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:81
Belle2::Database::reset
static void reset(bool keepConfig=false)
Reset the database instance.
Definition: Database.cc:62
Belle2::MVA::Weightfile::addStream
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:132
Belle2::File
Definition: File.h:14
Belle2::MVA::Weightfile::getElement
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:153
Belle2::MVA::Weightfile::getStream
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:153
Belle2::MVA::Weightfile::loadFromROOTFile
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:226
Belle2::MVA::Options
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:36
Belle2::MVA::Weightfile::getSignalFraction
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:109
Belle2::Conditions::Configuration::getInstance
static Configuration & getInstance()
Get a reference to the instance which will be used when the Database is initialized.
Belle2::MVA::Weightfile::load
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:204
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Weightfile::addOptions
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:71
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVA::Weightfile::loadFromXMLFile
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:249
Belle2::MVA::Weightfile::saveToStream
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:194
Belle2::MVA::Weightfile::getFile
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:147
Belle2::MVA::Weightfile::saveToXMLFile
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
Definition: Weightfile.cc:184
Belle2::TEST
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
Definition: utilityFunctions.cc:18
Belle2::MVA::Weightfile::save
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:163
Belle2::MVA::Weightfile::addElement
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:116
Belle2::MVA::Weightfile::loadFromDatabase
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:290
Belle2::MVA::Weightfile::getFeatureImportance
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:92
alignment.constraints_generator.filename
filename
File name.
Definition: constraints_generator.py:224
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::MVA::Weightfile::generateFileName
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:114
Belle2::MVA::Weightfile::setRemoveTemporaryDirectories
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
Definition: Weightfile.h:281
Belle2::MVA::Weightfile::saveToROOTFile
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:174