Belle II Software light-2406-ragdoll
Weightfile.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10#ifndef INCLUDE_GUARD_BELLE2_MVA_WEIGHTFILE_HEADER
11#define INCLUDE_GUARD_BELLE2_MVA_WEIGHTFILE_HEADER
12
#include <mva/interface/Options.h>

#include <framework/database/IntervalOfValidity.h>
#include <framework/dataobjects/EventMetaData.h>

#include <boost/property_tree/ptree.hpp>

#include <cstddef>
#include <fstream>
#include <map>
#include <string>
#include <vector>
23
24namespace Belle2 {
30 namespace MVA {
31
32 std::string makeSaveForDatabase(std::string str);
33
38 class Weightfile {
39
40 public:
45
50
55 void addFeatureImportance(const std::map<std::string, float>& importance);
56
60 std::map<std::string, float> getFeatureImportance() const;
61
66 void addOptions(const Options& options);
67
72 void getOptions(Options& options) const;
73
78 void addSignalFraction(float signal_fraction);
79
84 float getSignalFraction() const;
85
92 std::string generateFileName(const std::string& suffix = "");
93
99 void addFile(const std::string& identifier, const std::string& custom_weightfile);
100
106 void addStream(const std::string& identifier, std::istream& in);
107
113 template<class T>
114 void addElement(const std::string& identifier, const T& element)
115 {
116 m_pt.put(identifier, element);
117 }
118
124 template<class T>
125 void addVector(const std::string& identifier, const std::vector<T>& vector)
126 {
127 m_pt.put(identifier + "_size", vector.size());
128 for (unsigned int i = 0; i < vector.size(); ++i) {
129 m_pt.put(identifier + std::to_string(i), vector[i]);
130 }
131 }
132
138 void getFile(const std::string& identifier, const std::string& custom_weightfile);
139
144 std::string getStream(const std::string& identifier) const;
145
150 template<class T>
151 T getElement(const std::string& identifier) const
152 {
153 return m_pt.get<T>(identifier);
154 }
155
160 bool containsElement(const std::string& identifier) const
161 {
162 return m_pt.count(identifier) > 0;
163 }
164
170 template<class T>
171 T getElement(const std::string& identifier, const T& default_value) const
172 {
173 return m_pt.get<T>(identifier, default_value);
174 }
175
180 template<class T>
181 std::vector<T> getVector(const std::string& identifier) const
182 {
183 std::vector<T> vector;
184 vector.resize(m_pt.get<size_t>(identifier + "_size"));
185 for (unsigned int i = 0; i < vector.size(); ++i) {
186 vector[i] = m_pt.get<T>(identifier + std::to_string(i));
187 }
188 return vector;
189 }
190
197 static void save(Weightfile& weightfile, const std::string& filename,
198 const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
199
205 static void saveToROOTFile(Weightfile& weightfile, const std::string& filename);
206
212 static void saveToXMLFile(Weightfile& weightfile, const std::string& filename);
213
219 static void saveToStream(Weightfile& weightfile, std::ostream& stream);
220
226 static Weightfile load(const std::string& filename, const Belle2::EventMetaData& emd = Belle2::EventMetaData(0, 0, 0));
227
232 static Weightfile loadFromFile(const std::string& filename);
233
238 static Weightfile loadFromROOTFile(const std::string& filename);
239
244 static Weightfile loadFromXMLFile(const std::string& filename);
245
250 static Weightfile loadFromStream(std::istream& stream);
251
258 static void saveToDatabase(Weightfile& weightfile, const std::string& identifier,
259 const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
260
267 static void saveArrayToDatabase(const std::vector<Weightfile>& weightfiles, const std::string& identifier,
268 const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
269
275 static Weightfile loadFromDatabase(const std::string& identifier, const Belle2::EventMetaData& emd = Belle2::EventMetaData(0, 0,
276 0));
277
282 void setRemoveTemporaryDirectories(bool remove_temporary_directories) { m_remove_temporary_directories = remove_temporary_directories; }
283
287 const boost::property_tree::ptree& getXMLTree() const { return m_pt; };
288
289 private:
290 boost::property_tree::ptree m_pt;
291 std::vector<std::string> m_filenames;
293 };
294
295 }
297}
298#endif
Store event, run, and experiment numbers.
Definition: EventMetaData.h:33
A class that describes the interval of experiments/runs for which an object in the database is valid.
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:34
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:123
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
Weightfile()
Construct an empty weightfile.
Definition: Weightfile.h:44
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
Definition: Weightfile.cc:115
const boost::property_tree::ptree & getXMLTree() const
Get xml tree.
Definition: Weightfile.h:287
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:83
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
Definition: Weightfile.cc:240
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:154
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
Definition: Weightfile.cc:52
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
Definition: Weightfile.h:282
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
Definition: Weightfile.h:292
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Definition: Weightfile.cc:175
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
Definition: Weightfile.h:290
std::vector< T > getVector(const std::string &identifier) const
Returns a stored vector from the xml tree.
Definition: Weightfile.h:181
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:281
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
Definition: Weightfile.h:291
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:165
void addVector(const std::string &identifier, const std::vector< T > &vector)
Add a vector to the xml tree.
Definition: Weightfile.h:125
T getElement(const std::string &identifier, const T &default_value) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:171
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
Definition: Weightfile.cc:267
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:144
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:258
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24