Belle II Software development
Weightfile.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Weightfile.h>
10
11#include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h>
12#include <framework/database/Database.h>
13#include <framework/database/DBImportArray.h>
14
15#include <boost/archive/iterators/base64_from_binary.hpp>
16#include <boost/archive/iterators/binary_from_base64.hpp>
17#include <boost/archive/iterators/transform_width.hpp>
18
19#include <boost/property_tree/xml_parser.hpp>
20#include <boost/algorithm/string.hpp>
21#include <boost/algorithm/string/replace.hpp>
22#include <boost/regex.hpp>
23
24#include <TFile.h>
25
26#include <sstream>
27#include <filesystem>
28
29namespace fs = std::filesystem;
30
31namespace Belle2 {
36 namespace MVA {
37
/**
 * Turn an identifier into a name that is usable as a database payload name.
 * Every blank is replaced by the token "__sp"; all other characters are kept.
 * @param str identifier to convert (taken by value and edited in place)
 * @return the converted identifier
 */
std::string makeSaveForDatabase(std::string str)
{
  const std::string blank = " ";
  const std::string token = "__sp";
  // The token contains no blank, so continuing the search right behind the
  // inserted token can neither skip a blank nor match the replacement again.
  for (auto position = str.find(blank); position != std::string::npos;
       position = str.find(blank, position + token.size())) {
    str.replace(position, blank.size(), token);
  }
  return str;
}
50
52 {
53 for (auto& filename : m_filenames) {
54 if (fs::exists(filename)) {
56 fs::remove_all(filename);
57 }
58 }
59 }
60
61 void Weightfile::addOptions(const Options& options)
62 {
63 options.save(m_pt);
64 }
65
66 void Weightfile::getOptions(Options& options) const
67 {
68 options.load(m_pt);
69 }
70
71 void Weightfile::addFeatureImportance(const std::map<std::string, float>& importance)
72 {
73 m_pt.put("number_of_importance_vars", importance.size());
74 unsigned int i = 0;
75 for (auto& pair : importance) {
76 m_pt.put(std::string("importance_key") + std::to_string(i), pair.first);
77 m_pt.put(std::string("importance_value") + std::to_string(i), pair.second);
78 ++i;
79 }
80 }
81
82 std::map<std::string, float> Weightfile::getFeatureImportance() const
83 {
84 std::map<std::string, float> importance;
85 unsigned int numberOfImportanceVars = m_pt.get<unsigned int>("number_of_importance_vars", 0);
86 for (unsigned int i = 0; i < numberOfImportanceVars; ++i) {
87 auto key = m_pt.get<std::string>(std::string("importance_key") + std::to_string(i));
88 auto value = m_pt.get<float>(std::string("importance_value") + std::to_string(i));
89 importance[key] = value;
90 }
91 return importance;
92 }
93
94 void Weightfile::addSignalFraction(float signal_fraction)
95 {
96 m_pt.put("signal_fraction", signal_fraction);
97 }
98
100 {
101 return m_pt.get<float>("signal_fraction");
102 }
103
104 std::string Weightfile::generateFileName(const std::string& suffix)
105 {
106 char* directory_template = strdup((fs::temp_directory_path() / "Basf2MVA.XXXXXX").c_str());
107 auto directory = mkdtemp(directory_template);
108 std::string tmpfile = std::string(directory) + std::string("/weightfile") + suffix;
109 m_filenames.emplace_back(directory);
110 free(directory_template);
111 return tmpfile;
112 }
113
114 void Weightfile::addFile(const std::string& identifier, const std::string& custom_weightfile)
115 {
116 // TODO Test if file is valid
117 std::ifstream in(custom_weightfile, std::ios::in | std::ios::binary);
118 addStream(identifier, in);
119 in.close();
120 }
121
122 void Weightfile::addStream(const std::string& identifier, std::istream& in)
123 {
124 using base64_t = boost::archive::iterators::base64_from_binary <
125 boost::archive::iterators::transform_width<std::string::const_iterator, 6, 8 >>;
126
127 std::string contents;
128 in.seekg(0, std::ios::end);
129 contents.resize(in.tellg());
130 in.seekg(0, std::ios::beg);
131 in.read(&contents[0], contents.size());
132 std::string enc(base64_t(contents.begin()), base64_t(contents.end()));
133
134 m_pt.put(identifier, enc);
135 }
136
137 void Weightfile::getFile(const std::string& identifier, const std::string& custom_weightfile)
138 {
139 std::ofstream out(custom_weightfile, std::ios::out | std::ios::binary);
140 out << getStream(identifier);
141 }
142
143 std::string Weightfile::getStream(const std::string& identifier) const
144 {
145 using binary_t = boost::archive::iterators::transform_width <
146 boost::archive::iterators::binary_from_base64<std::string::const_iterator>, 8, 6 >;
147
148 auto contents = m_pt.get<std::string>(identifier);
149 std::string dec(binary_t(contents.begin()), binary_t(contents.end()));
150 return dec;
151 }
152
153 void Weightfile::save(Weightfile& weightfile, const std::string& filename, const Belle2::IntervalOfValidity& iov)
154 {
155 if (filename.ends_with(".root")) {
156 return saveToROOTFile(weightfile, filename);
157 } else if (filename.ends_with(".xml")) {
158 return saveToXMLFile(weightfile, filename);
159 } else {
160 return saveToDatabase(weightfile, filename, iov);
161 }
162 }
163
164 void Weightfile::saveToROOTFile(Weightfile& weightfile, const std::string& filename)
165 {
166 std::stringstream ss;
167 Weightfile::saveToStream(weightfile, ss);
168 DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
169 database_representation_of_weightfile.m_data = ss.str();
170 TFile file(filename.c_str(), "RECREATE");
171 file.WriteObject(&database_representation_of_weightfile, "Weightfile");
172 }
173
174 void Weightfile::saveToXMLFile(Weightfile& weightfile, const std::string& filename)
175 {
176#if BOOST_VERSION < 105600
177 boost::property_tree::xml_writer_settings<char> settings('\t', 1);
178#else
179 boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
180#endif
181 boost::property_tree::xml_parser::write_xml(filename, weightfile.m_pt, std::locale(), settings);
182 }
183
184 void Weightfile::saveToStream(Weightfile& weightfile, std::ostream& stream)
185 {
186#if BOOST_VERSION < 105600
187 boost::property_tree::xml_writer_settings<char> settings('\t', 1);
188#else
189 boost::property_tree::xml_writer_settings<std::string> settings('\t', 1);
190#endif
191 boost::property_tree::xml_parser::write_xml(stream, weightfile.m_pt, settings);
192 }
193
194 Weightfile Weightfile::load(const std::string& filename, const Belle2::EventMetaData& emd)
195 {
196 if (filename.ends_with(".root")) {
197 return loadFromROOTFile(filename);
198 } else if (filename.ends_with(".xml")) {
199 return loadFromXMLFile(filename);
200 } else {
201 return loadFromDatabase(filename, emd);
202 }
203 }
204
205 Weightfile Weightfile::loadFromFile(const std::string& filename)
206 {
207 if (filename.ends_with(".root")) {
208 return loadFromROOTFile(filename);
209 } else if (filename.ends_with(".xml")) {
210 return loadFromXMLFile(filename);
211 } else {
212 throw std::runtime_error("Cannot load file " + filename + " because file extension is not supported");
213 }
214 }
215
216 Weightfile Weightfile::loadFromROOTFile(const std::string& filename)
217 {
218
219 if (not fs::exists(filename)) {
220 throw std::runtime_error("Given filename does not exist: " + filename);
221 }
222
223 TFile file(filename.c_str(), "READ");
224 if (file.IsZombie() or not file.IsOpen()) {
225 throw std::runtime_error("Error during open of ROOT file named " + filename);
226 }
227
228 DatabaseRepresentationOfWeightfile* database_representation_of_weightfile = nullptr;
229 file.GetObject("Weightfile", database_representation_of_weightfile);
230
231 if (database_representation_of_weightfile == nullptr) {
232 throw std::runtime_error("The provided ROOT file " + filename + " does not contain a valid MVA weightfile.");
233 }
234 std::stringstream ss(database_representation_of_weightfile->m_data);
235 delete database_representation_of_weightfile;
237 }
238
239 Weightfile Weightfile::loadFromXMLFile(const std::string& filename)
240 {
241 if (not fs::exists(filename)) {
242 throw std::runtime_error("Given filename does not exist: " + filename);
243 }
244
245 Weightfile weightfile;
246 boost::property_tree::xml_parser::read_xml(filename, weightfile.m_pt);
247 return weightfile;
248 }
249
251 {
252 Weightfile weightfile;
253 boost::property_tree::xml_parser::read_xml(stream, weightfile.m_pt);
254 return weightfile;
255 }
256
257 void Weightfile::saveToDatabase(Weightfile& weightfile, const std::string& identifier, const Belle2::IntervalOfValidity& iov)
258 {
259 std::stringstream ss;
260 Weightfile::saveToStream(weightfile, ss);
261 DatabaseRepresentationOfWeightfile database_representation_of_weightfile;
262 database_representation_of_weightfile.m_data = ss.str();
263 Belle2::Database::Instance().storeData(makeSaveForDatabase(identifier), &database_representation_of_weightfile, iov);
264 }
265
266 void Weightfile::saveArrayToDatabase(const std::vector<Weightfile>& weightfiles, const std::string& identifier,
268 {
269 DBImportArray<DatabaseRepresentationOfWeightfile> dbArray(makeSaveForDatabase(identifier));
270
271 for (auto weightfile : weightfiles) {
272 std::stringstream ss;
273 Weightfile::saveToStream(weightfile, ss);
274 dbArray.appendNew(ss.str());
275 }
276
277 dbArray.import(iov);
278 }
279
280 Weightfile Weightfile::loadFromDatabase(const std::string& identifier, const Belle2::EventMetaData& emd)
281 {
282 auto pair = Belle2::Database::Instance().getData(emd, makeSaveForDatabase(identifier));
283
284 if (pair.first == 0) {
285 throw std::runtime_error("Given identifier cannot be loaded from the database: " + identifier);
286 }
287
288 DatabaseRepresentationOfWeightfile database_representation_of_weightfile = *static_cast<DatabaseRepresentationOfWeightfile*>
289 (pair.first);
290 std::stringstream ss(database_representation_of_weightfile.m_data);
292 }
293
294 }
296}
Class for importing array of objects to the database.
T * appendNew()
Construct a new T object at the end of the array.
bool import(const IntervalOfValidity &iov)
Import the object to database.
Database representation of a Weightfile object.
Store event, run, and experiment numbers.
A class that describes the interval of experiments/runs for which an object in the database is valid.
Abstract base class of all Options given to the MVA interface.
Definition Options.h:34
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Weightfile()
Construct an empty weightfile.
Definition Weightfile.h:44
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition Weightfile.cc:82
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
Definition Weightfile.cc:51
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
Definition Weightfile.h:292
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
Definition Weightfile.h:290
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition Weightfile.cc:61
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition Weightfile.cc:66
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition Weightfile.cc:94
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
Definition Weightfile.h:291
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition Weightfile.cc:71
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition Weightfile.cc:99
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
STL class.
std::pair< TObject *, IntervalOfValidity > getData(const EventMetaData &event, const std::string &name)
Request an object from the database.
Definition Database.cc:71
static Database & Instance()
Instance of a singleton Database.
Definition Database.cc:41
bool storeData(const std::string &name, TObject *object, const IntervalOfValidity &iov)
Store an object in the database.
Definition Database.cc:140
Abstract base class for different kinds of events.