Belle II Software  light-2309-munchkin
Weightfile.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Weightfile.h>
10 
11 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12 #include <framework/database/Database.h>
13 #include <framework/database/DBImportArray.h>
14 
15 #include <boost/archive/iterators/base64_from_binary.hpp>
16 #include <boost/archive/iterators/binary_from_base64.hpp>
17 #include <boost/archive/iterators/transform_width.hpp>
18 
19 #include <boost/property_tree/xml_parser.hpp>
20 #include <boost/algorithm/string/predicate.hpp>
21 #include <boost/algorithm/string.hpp>
22 #include <boost/algorithm/string/replace.hpp>
23 #include <boost/regex.hpp>
24 
25 #include <TFile.h>
26 
27 #include <sstream>
28 #include <filesystem>
29 
30 namespace Belle2 {
35  namespace MVA {
36 
/**
 * Make an identifier usable as a database payload name.
 * Every space character is replaced by the token "__sp"; all other
 * characters are kept unchanged.
 * @param str identifier to convert (taken by value and edited in place)
 * @return the converted identifier
 */
std::string makeSaveForDatabase(std::string str)
{
  const std::string needle = " ";
  const std::string replacement = "__sp";
  // The replacement contains no space, so the search can resume right behind it.
  for (std::string::size_type pos = str.find(needle); pos != std::string::npos;
       pos = str.find(needle, pos + replacement.size())) {
    str.replace(pos, needle.size(), replacement);
  }
  return str;
}
49 
51  {
52  for (auto& filename : m_filenames) {
53  if (std::filesystem::exists(filename)) {
55  std::filesystem::remove_all(filename);
56  }
57  }
58  }
59 
60  void Weightfile::addOptions(const Options& options)
61  {
62  options.save(m_pt);
63  }
64 
65  void Weightfile::getOptions(Options& options) const
66  {
67  options.load(m_pt);
68  }
69 
70  void Weightfile::addFeatureImportance(const std::map<std::string, float>& importance)
71  {
72  m_pt.put("number_of_importance_vars", importance.size());
73  unsigned int i = 0;
74  for (auto& pair : importance) {
75  m_pt.put(std::string("importance_key") + std::to_string(i), pair.first);
76  m_pt.put(std::string("importance_value") + std::to_string(i), pair.second);
77  ++i;
78  }
79  }
80 
81  std::map<std::string, float> Weightfile::getFeatureImportance() const
82  {
83  std::map<std::string, float> importance;
84  unsigned int numberOfImportanceVars = m_pt.get<unsigned int>("number_of_importance_vars", 0);
85  for (unsigned int i = 0; i < numberOfImportanceVars; ++i) {
86  auto key = m_pt.get<std::string>(std::string("importance_key") + std::to_string(i));
87  auto value = m_pt.get<float>(std::string("importance_value") + std::to_string(i));
88  importance[key] = value;
89  }
90  return importance;
91  }
92 
93  void Weightfile::addSignalFraction(float signal_fraction)
94  {
95  m_pt.put("signal_fraction", signal_fraction);
96  }
97 
99  {
100  return m_pt.get<float>("signal_fraction");
101  }
102 
103  std::string Weightfile::generateFileName(const std::string& suffix)
104  {
105  char* directory_template = strdup((m_temporary_directory + "/Basf2MVA.XXXXXX").c_str());
106  auto directory = mkdtemp(directory_template);
107  std::string tmpfile = std::string(directory) + std::string("/weightfile") + suffix;
108  m_filenames.emplace_back(directory);
109  free(directory_template);
110  return tmpfile;
111  }
112 
113  void Weightfile::addFile(const std::string& identifier, const std::string& custom_weightfile)
114  {
115  // TODO Test if file is valid
116  std::ifstream in(custom_weightfile, std::ios::in | std::ios::binary);
117  addStream(identifier, in);
118  in.close();
119  }
120 
121  void Weightfile::addStream(const std::string& identifier, std::istream& in)
122  {
123  using base64_t = boost::archive::iterators::base64_from_binary <
124  boost::archive::iterators::transform_width<std::string::const_iterator, 6, 8 >>;
125 
126  std::string contents;
127  in.seekg(0, std::ios::end);
128  contents.resize(in.tellg());
129  in.seekg(0, std::ios::beg);
130  in.read(&contents[0], contents.size());
131  std::string enc(base64_t(contents.begin()), base64_t(contents.end()));
132 
133  m_pt.put(identifier, enc);
134  }
135 
136  void Weightfile::getFile(const std::string& identifier, const std::string& custom_weightfile)
137  {
138  std::ofstream out(custom_weightfile, std::ios::out | std::ios::binary);
139  out << getStream(identifier);
140  }
141 
142  std::string Weightfile::getStream(const std::string& identifier) const
143  {
144  using binary_t = boost::archive::iterators::transform_width <
145  boost::archive::iterators::binary_from_base64<std::string::const_iterator>, 8, 6 >;
146 
147  auto contents = m_pt.get<std::string>(identifier);
148  std::string dec(binary_t(contents.begin()), binary_t(contents.end()));
149  return dec;
150  }
151 
152  void Weightfile::save(Weightfile& weightfile, const std::string& filename, const Belle2::IntervalOfValidity& iov)
153  {
154  if (boost::ends_with(filename, ".root")) {
155  return saveToROOTFile(weightfile, filename);
156  } else if (boost::ends_with(filename, ".xml")) {
157  return saveToXMLFile(weightfile, filename);
158  } else {
159  return saveToDatabase(weightfile, filename, iov);
160  }
161  }
162 
163  void Weightfile::saveToROOTFile(Weightfile& weightfile, const std::string& filename)
164  {
165  std::stringstream ss;
166  Weightfile::saveToStream(weightfile, ss);
167  DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
168  database_representation_of_weightfile.m_data = ss.str();
169  TFile file(filename.c_str(), "RECREATE");
170  file.WriteObject(&database_representation_of_weightfile, "Weightfile");
171  }
172 
173  void Weightfile::saveToXMLFile(Weightfile& weightfile, const std::string& filename)
174  {
175 #if BOOST_VERSION < 105600
176  boost::property_tree::xml_writer_settings<char> settings('\t', 1);
177 #else
178  boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
179 #endif
180  boost::property_tree::xml_parser::write_xml(filename, weightfile.m_pt, std::locale(), settings);
181  }
182 
183  void Weightfile::saveToStream(Weightfile& weightfile, std::ostream& stream)
184  {
185 #if BOOST_VERSION < 105600
186  boost::property_tree::xml_writer_settings<char> settings('\t', 1);
187 #else
188  boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
189 #endif
190  boost::property_tree::xml_parser::write_xml(stream, weightfile.m_pt, settings);
191  }
192 
193  Weightfile Weightfile::load(const std::string& filename, const Belle2::EventMetaData& emd)
194  {
195  if (boost::ends_with(filename, ".root")) {
196  return loadFromROOTFile(filename);
197  } else if (boost::ends_with(filename, ".xml")) {
198  return loadFromXMLFile(filename);
199  } else {
200  return loadFromDatabase(filename, emd);
201  }
202  }
203 
204  Weightfile Weightfile::loadFromFile(const std::string& filename)
205  {
206  if (boost::ends_with(filename, ".root")) {
207  return loadFromROOTFile(filename);
208  } else if (boost::ends_with(filename, ".xml")) {
209  return loadFromXMLFile(filename);
210  } else {
211  throw std::runtime_error("Cannot load file " + filename + " because file extension is not supported");
212  }
213  }
214 
215  Weightfile Weightfile::loadFromROOTFile(const std::string& filename)
216  {
217 
218  if (not std::filesystem::exists(filename)) {
219  throw std::runtime_error("Given filename does not exist: " + filename);
220  }
221 
222  TFile file(filename.c_str(), "READ");
223  if (file.IsZombie() or not file.IsOpen()) {
224  throw std::runtime_error("Error during open of ROOT file named " + filename);
225  }
226 
227  DatabaseRepresentationOfWeightfile* database_representation_of_weightfile = nullptr;
228  file.GetObject("Weightfile", database_representation_of_weightfile);
229 
230  if (database_representation_of_weightfile == nullptr) {
231  throw std::runtime_error("The provided ROOT file " + filename + " does not contain a valid MVA weightfile.");
232  }
233  std::stringstream ss(database_representation_of_weightfile->m_data);
234  delete database_representation_of_weightfile;
235  return Weightfile::loadFromStream(ss);
236  }
237 
238  Weightfile Weightfile::loadFromXMLFile(const std::string& filename)
239  {
240  if (not std::filesystem::exists(filename)) {
241  throw std::runtime_error("Given filename does not exist: " + filename);
242  }
243 
244  Weightfile weightfile;
245  boost::property_tree::xml_parser::read_xml(filename, weightfile.m_pt);
246  return weightfile;
247  }
248 
250  {
251  Weightfile weightfile;
252  boost::property_tree::xml_parser::read_xml(stream, weightfile.m_pt);
253  return weightfile;
254  }
255 
256  void Weightfile::saveToDatabase(Weightfile& weightfile, const std::string& identifier, const Belle2::IntervalOfValidity& iov)
257  {
258  std::stringstream ss;
259  Weightfile::saveToStream(weightfile, ss);
260  DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
261  database_representation_of_weightfile.m_data = ss.str();
262  Belle2::Database::Instance().storeData(makeSaveForDatabase(identifier), &database_representation_of_weightfile, iov);
263  }
264 
265  void Weightfile::saveArrayToDatabase(const std::vector<Weightfile>& weightfiles, const std::string& identifier,
266  const Belle2::IntervalOfValidity& iov)
267  {
268  DBImportArray<DatabaseRepresentationOfWeightfile> dbArray(makeSaveForDatabase(identifier));
269 
270  for (auto weightfile : weightfiles) {
271  std::stringstream ss;
272  Weightfile::saveToStream(weightfile, ss);
273  dbArray.appendNew(ss.str());
274  }
275 
276  dbArray.import(iov);
277  }
278 
279  Weightfile Weightfile::loadFromDatabase(const std::string& identifier, const Belle2::EventMetaData& emd)
280  {
281  auto pair = Belle2::Database::Instance().getData(emd, makeSaveForDatabase(identifier));
282 
283  if (pair.first == 0) {
284  throw std::runtime_error("Given identifier cannot be loaded from the database: " + identifier);
285  }
286 
287  DatabaseRepresentationOfWeightfile database_representation_of_weightfile = *static_cast<DatabaseRepresentationOfWeightfile*>
288  (pair.first);
289  std::stringstream ss(database_representation_of_weightfile.m_data);
290  return Weightfile::loadFromStream(ss);
291  }
292 
293  }
295 }
Class for importing array of objects to the database.
Definition: DBImportArray.h:25
T * appendNew()
Construct a new T object at the end of the array.
Definition: DBImportArray.h:67
bool import(const IntervalOfValidity &iov)
Import the object to database.
Definition: DBImportBase.cc:36
Database representation of a Weightfile object.
Store event, run, and experiment numbers.
Definition: EventMetaData.h:33
A class that describes the interval of experiments/runs for which an object in the database is valid.
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:34
virtual void load(const boost::property_tree::ptree &pt)=0
Load mechanism (used by Weightfile) to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const =0
Save mechanism (used by Weightfile) to store Options in a xml tree.
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
std::string m_temporary_directory
temporary directory which is used to store temporary directories
Definition: Weightfile.h:298
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:121
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:113
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:81
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
Definition: Weightfile.cc:238
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:152
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
Definition: Weightfile.cc:50
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
Definition: Weightfile.h:297
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
Definition: Weightfile.cc:173
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:249
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
Definition: Weightfile.h:292
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:60
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:215
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:65
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:193
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:279
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:183
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:204
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:93
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
Definition: Weightfile.h:296
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:70
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:163
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:98
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
Definition: Weightfile.cc:265
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:103
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:142
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:256
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:136
std::pair< TObject *, IntervalOfValidity > getData(const EventMetaData &event, const std::string &name)
Request an object from the database.
Definition: Database.cc:72
static Database & Instance()
Instance of a singleton Database.
Definition: Database.cc:42
bool storeData(const std::string &name, TObject *object, const IntervalOfValidity &iov)
Store an object in the database.
Definition: Database.cc:141
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24