Belle II Software  release-05-02-19
Weightfile.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #include <mva/interface/Weightfile.h>
12 
13 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
14 #include <framework/database/Database.h>
15 #include <framework/database/DBImportArray.h>
16 
17 #include <boost/archive/iterators/base64_from_binary.hpp>
18 #include <boost/archive/iterators/binary_from_base64.hpp>
19 #include <boost/archive/iterators/transform_width.hpp>
20 
21 #include <boost/property_tree/xml_parser.hpp>
22 #include <boost/filesystem/operations.hpp>
23 #include <boost/algorithm/string/predicate.hpp>
24 #include <boost/algorithm/string.hpp>
25 #include <boost/algorithm/string/replace.hpp>
26 #include <boost/regex.hpp>
27 
28 #include <TFile.h>
29 
30 #include <sstream>
31 
32 
33 namespace Belle2 {
38  namespace MVA {
39 
/**
 * Converts an identifier into the form used as a name in the database calls of this file.
 * Every space character is replaced by the token "__sp"; all other characters are
 * passed through unchanged.
 * @param str identifier to convert
 * @return the converted identifier
 */
std::string makeSaveForDatabase(std::string str)
{
  std::string converted;
  converted.reserve(str.size());
  for (const char character : str) {
    if (character == ' ') {
      converted += "__sp";
    } else {
      converted += character;
    }
  }
  return converted;
}
52 
54  {
55  for (auto& filename : m_filenames) {
56  if (boost::filesystem::exists(filename)) {
58  boost::filesystem::remove_all(filename);
59  }
60  }
61  }
62 
63  void Weightfile::addOptions(const Options& options)
64  {
65  options.save(m_pt);
66  }
67 
68  void Weightfile::getOptions(Options& options) const
69  {
70  options.load(m_pt);
71  }
72 
73  void Weightfile::addFeatureImportance(const std::map<std::string, float>& importance)
74  {
75  m_pt.put("number_of_importance_vars", importance.size());
76  unsigned int i = 0;
77  for (auto& pair : importance) {
78  m_pt.put(std::string("importance_key") + std::to_string(i), pair.first);
79  m_pt.put(std::string("importance_value") + std::to_string(i), pair.second);
80  ++i;
81  }
82  }
83 
84  std::map<std::string, float> Weightfile::getFeatureImportance() const
85  {
86  std::map<std::string, float> importance;
87  unsigned int numberOfImportanceVars = m_pt.get<unsigned int>("number_of_importance_vars", 0);
88  for (unsigned int i = 0; i < numberOfImportanceVars; ++i) {
89  auto key = m_pt.get<std::string>(std::string("importance_key") + std::to_string(i));
90  auto value = m_pt.get<float>(std::string("importance_value") + std::to_string(i));
91  importance[key] = value;
92  }
93  return importance;
94  }
95 
96  void Weightfile::addSignalFraction(float signal_fraction)
97  {
98  m_pt.put("signal_fraction", signal_fraction);
99  }
100 
101  float Weightfile::getSignalFraction() const
102  {
103  return m_pt.get<float>("signal_fraction");
104  }
105 
106  std::string Weightfile::generateFileName(const std::string& suffix)
107  {
108  char* directory_template = strdup((m_temporary_directory + "/Basf2MVA.XXXXXX").c_str());
109  auto directory = mkdtemp(directory_template);
110  std::string tmpfile = std::string(directory) + std::string("/weightfile") + suffix;
111  m_filenames.emplace_back(directory);
112  free(directory_template);
113  return tmpfile;
114  }
115 
116  void Weightfile::addFile(const std::string& identifier, const std::string& custom_weightfile)
117  {
118  // TODO Test if file is valid
119  std::ifstream in(custom_weightfile, std::ios::in | std::ios::binary);
120  addStream(identifier, in);
121  in.close();
122  }
123 
124  void Weightfile::addStream(const std::string& identifier, std::istream& in)
125  {
126  using base64_t = boost::archive::iterators::base64_from_binary <
127  boost::archive::iterators::transform_width<std::string::const_iterator, 6, 8 >>;
128 
129  std::string contents;
130  in.seekg(0, std::ios::end);
131  contents.resize(in.tellg());
132  in.seekg(0, std::ios::beg);
133  in.read(&contents[0], contents.size());
134  std::string enc(base64_t(contents.begin()), base64_t(contents.end()));
135 
136  m_pt.put(identifier, enc);
137  }
138 
139  void Weightfile::getFile(const std::string& identifier, const std::string& custom_weightfile)
140  {
141  std::ofstream out(custom_weightfile, std::ios::out | std::ios::binary);
142  out << getStream(identifier);
143  }
144 
145  std::string Weightfile::getStream(const std::string& identifier) const
146  {
147  using binary_t = boost::archive::iterators::transform_width <
148  boost::archive::iterators::binary_from_base64<std::string::const_iterator>, 8, 6 >;
149 
150  auto contents = m_pt.get<std::string>(identifier);
151  std::string dec(binary_t(contents.begin()), binary_t(contents.end()));
152  return dec;
153  }
154 
155  void Weightfile::save(Weightfile& weightfile, const std::string& filename, const Belle2::IntervalOfValidity& iov)
156  {
157  if (boost::ends_with(filename, ".root")) {
158  return saveToROOTFile(weightfile, filename);
159  } else if (boost::ends_with(filename, ".xml")) {
160  return saveToXMLFile(weightfile, filename);
161  } else {
162  return saveToDatabase(weightfile, filename, iov);
163  }
164  }
165 
166  void Weightfile::saveToROOTFile(Weightfile& weightfile, const std::string& filename)
167  {
168  std::stringstream ss;
169  Weightfile::saveToStream(weightfile, ss);
170  DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
171  database_representation_of_weightfile.m_data = ss.str();
172  TFile file(filename.c_str(), "RECREATE");
173  file.WriteObject(&database_representation_of_weightfile, "Weightfile");
174  }
175 
176  void Weightfile::saveToXMLFile(Weightfile& weightfile, const std::string& filename)
177  {
178 #if BOOST_VERSION < 105600
179  boost::property_tree::xml_writer_settings<char> settings('\t', 1);
180 #else
181  boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
182 #endif
183  boost::property_tree::xml_parser::write_xml(filename, weightfile.m_pt, std::locale(), settings);
184  }
185 
186  void Weightfile::saveToStream(Weightfile& weightfile, std::ostream& stream)
187  {
188 #if BOOST_VERSION < 105600
189  boost::property_tree::xml_writer_settings<char> settings('\t', 1);
190 #else
191  boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
192 #endif
193  boost::property_tree::xml_parser::write_xml(stream, weightfile.m_pt, settings);
194  }
195 
196  Weightfile Weightfile::load(const std::string& filename, const Belle2::EventMetaData& emd)
197  {
198  if (boost::ends_with(filename, ".root")) {
199  return loadFromROOTFile(filename);
200  } else if (boost::ends_with(filename, ".xml")) {
201  return loadFromXMLFile(filename);
202  } else {
203  return loadFromDatabase(filename, emd);
204  }
205  }
206 
207  Weightfile Weightfile::loadFromFile(const std::string& filename)
208  {
209  if (boost::ends_with(filename, ".root")) {
210  return loadFromROOTFile(filename);
211  } else if (boost::ends_with(filename, ".xml")) {
212  return loadFromXMLFile(filename);
213  } else {
214  throw std::runtime_error("Cannot load file " + filename + " because file extension is not supported");
215  }
216  }
217 
218  Weightfile Weightfile::loadFromROOTFile(const std::string& filename)
219  {
220 
221  if (not boost::filesystem::exists(filename)) {
222  throw std::runtime_error("Given filename does not exist: " + filename);
223  }
224 
225  TFile file(filename.c_str(), "READ");
226  if (file.IsZombie() or not file.IsOpen()) {
227  throw std::runtime_error("Error during open of ROOT file named " + filename);
228  }
229 
230  DatabaseRepresentationOfWeightfile* database_representation_of_weightfile = nullptr;
231  file.GetObject("Weightfile", database_representation_of_weightfile);
232 
233  if (database_representation_of_weightfile == nullptr) {
234  throw std::runtime_error("The provided ROOT file " + filename + " does not contain a valid MVA weightfile.");
235  }
236  std::stringstream ss(database_representation_of_weightfile->m_data);
237  delete database_representation_of_weightfile;
238  return Weightfile::loadFromStream(ss);
239  }
240 
241  Weightfile Weightfile::loadFromXMLFile(const std::string& filename)
242  {
243  if (not boost::filesystem::exists(filename)) {
244  throw std::runtime_error("Given filename does not exist: " + filename);
245  }
246 
247  Weightfile weightfile;
248  boost::property_tree::xml_parser::read_xml(filename, weightfile.m_pt);
249  return weightfile;
250  }
251 
252  Weightfile Weightfile::loadFromStream(std::istream& stream)
253  {
254  Weightfile weightfile;
255  boost::property_tree::xml_parser::read_xml(stream, weightfile.m_pt);
256  return weightfile;
257  }
258 
259  void Weightfile::saveToDatabase(Weightfile& weightfile, const std::string& identifier, const Belle2::IntervalOfValidity& iov)
260  {
261  std::stringstream ss;
262  Weightfile::saveToStream(weightfile, ss);
263  DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
264  database_representation_of_weightfile.m_data = ss.str();
265  Belle2::Database::Instance().storeData(makeSaveForDatabase(identifier), &database_representation_of_weightfile, iov);
266  }
267 
268  void Weightfile::saveArrayToDatabase(std::vector<Weightfile>& weightfiles, const std::string& identifier,
269  const Belle2::IntervalOfValidity& iov)
270  {
271  DBImportArray<DatabaseRepresentationOfWeightfile> dbArray(makeSaveForDatabase(identifier));
272 
273  for (auto weightfile : weightfiles) {
274  std::stringstream ss;
275  Weightfile::saveToStream(weightfile, ss);
276  dbArray.appendNew(ss.str());
277  }
278 
279  dbArray.import(iov);
280  }
281 
282  Weightfile Weightfile::loadFromDatabase(const std::string& identifier, const Belle2::EventMetaData& emd)
283  {
284  auto pair = Belle2::Database::Instance().getData(emd, makeSaveForDatabase(identifier));
285 
286  if (pair.first == 0) {
287  throw std::runtime_error("Given identifier cannot be loaded from the database: " + identifier);
288  }
289 
290  DatabaseRepresentationOfWeightfile database_representation_of_weightfile = *static_cast<DatabaseRepresentationOfWeightfile*>
291  (pair.first);
292  std::stringstream ss(database_representation_of_weightfile.m_data);
293  return Weightfile::loadFromStream(ss);
294  }
295 
296  }
298 }
Belle2::IntervalOfValidity
A class that describes the interval of experiments/runs for which an object in the database is valid.
Definition: IntervalOfValidity.h:35
Belle2::Database::getData
std::pair< TObject *, IntervalOfValidity > getData(const EventMetaData &event, const std::string &name)
Request an object from the database.
Definition: Database.cc:84
Belle2::DBImportArray::appendNew
T * appendNew()
Construct a new T object at the end of the array.
Definition: DBImportArray.h:77
Belle2::MVA::Weightfile::addSignalFraction
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:104
Belle2::MVA::Weightfile::getOptions
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:76
Belle2::MVA::Weightfile::saveToDatabase
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:267
Belle2::MVA::Weightfile::m_pt
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
Definition: Weightfile.h:291
Belle2::MVA::Weightfile::addFile
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:124
Belle2::MVA::Weightfile
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:40
Belle2::Database::storeData
bool storeData(const std::string &name, TObject *object, const IntervalOfValidity &iov)
Store an object in the database.
Definition: Database.cc:152
Belle2::DBImportBase::import
bool import(const IntervalOfValidity &iov)
Import the object to database.
Definition: DBImportBase.cc:38
Belle2::MVA::Weightfile::m_temporary_directory
std::string m_temporary_directory
temporary directory which is used to store temporary directories
Definition: Weightfile.h:297
Belle2::MVA::Weightfile::addFeatureImportance
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:81
Belle2::MVA::Weightfile::m_remove_temporary_directories
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
Definition: Weightfile.h:296
Belle2::MVA::Weightfile::addStream
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:132
Belle2::DatabaseRepresentationOfWeightfile::m_data
std::string m_data
Serialized weightfile.
Definition: DatabaseRepresentationOfWeightfile.h:42
Belle2::MVA::Weightfile::getStream
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:153
Belle2::MVA::Weightfile::loadFromROOTFile
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:226
Belle2::MVA::Options
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:36
Belle2::DatabaseRepresentationOfWeightfile
Database representation of a Weightfile object.
Definition: DatabaseRepresentationOfWeightfile.h:30
Belle2::MVA::Weightfile::getSignalFraction
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:109
Belle2::MVA::Weightfile::load
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:204
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Weightfile::addOptions
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:71
Belle2::MVA::Weightfile::loadFromFile
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:215
Belle2::MVA::Weightfile::loadFromXMLFile
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:249
Belle2::MVA::Weightfile::m_filenames
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
Definition: Weightfile.h:295
Belle2::MVA::Weightfile::saveToStream
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:194
Belle2::MVA::Weightfile::Weightfile
Weightfile()
Construct an empty weightfile.
Definition: Weightfile.h:46
Belle2::MVA::Weightfile::getFile
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:147
Belle2::MVA::Weightfile::saveToXMLFile
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
Definition: Weightfile.cc:184
Belle2::MVA::Weightfile::save
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:163
Belle2::MVA::Weightfile::loadFromDatabase
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:290
Belle2::MVA::Weightfile::getFeatureImportance
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:92
Belle2::Database::Instance
static Database & Instance()
Instance of a singleton Database.
Definition: Database.cc:54
Belle2::MVA::Weightfile::loadFromStream
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:260
Belle2::EventMetaData
Store event, run, and experiment numbers.
Definition: EventMetaData.h:43
Belle2::MVA::Weightfile::generateFileName
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:114
Belle2::MVA::Weightfile::saveToROOTFile
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:174
Belle2::DBImportArray
Class for importing array of objects to the database.
Definition: DBImportArray.h:35
Belle2::MVA::Weightfile::~Weightfile
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
Definition: Weightfile.cc:61
Belle2::MVA::Weightfile::saveArrayToDatabase
static void saveArrayToDatabase(std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
Definition: Weightfile.cc:276