Belle II Software  release-08-01-10
Weightfile.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 #ifndef INCLUDE_GUARD_BELLE2_MVA_WEIGHTFILE_HEADER
11 #define INCLUDE_GUARD_BELLE2_MVA_WEIGHTFILE_HEADER
12 
#include <mva/interface/Options.h>

#include <framework/database/IntervalOfValidity.h>
#include <framework/dataobjects/EventMetaData.h>

#include <boost/property_tree/ptree.hpp>

#include <cstddef>
#include <fstream>
#include <map>
#include <string>
#include <vector>
23 
24 namespace Belle2 {
30  namespace MVA {
31 
32  std::string makeSaveForDatabase(std::string str);
33 
38  class Weightfile {
39 
40  public:
44  Weightfile() {};
45 
49  ~Weightfile();
50 
55  void addFeatureImportance(const std::map<std::string, float>& importance);
56 
60  std::map<std::string, float> getFeatureImportance() const;
61 
66  void addOptions(const Options& options);
67 
72  void getOptions(Options& options) const;
73 
78  void addSignalFraction(float signal_fraction);
79 
84  float getSignalFraction() const;
85 
92  std::string generateFileName(const std::string& suffix = "");
93 
99  void addFile(const std::string& identifier, const std::string& custom_weightfile);
100 
106  void addStream(const std::string& identifier, std::istream& in);
107 
113  template<class T>
114  void addElement(const std::string& identifier, const T& element)
115  {
116  m_pt.put(identifier, element);
117  }
118 
124  template<class T>
125  void addVector(const std::string& identifier, const std::vector<T>& vector)
126  {
127  m_pt.put(identifier + "_size", vector.size());
128  for (unsigned int i = 0; i < vector.size(); ++i) {
129  m_pt.put(identifier + std::to_string(i), vector[i]);
130  }
131  }
132 
138  void getFile(const std::string& identifier, const std::string& custom_weightfile);
139 
144  std::string getStream(const std::string& identifier) const;
145 
150  template<class T>
151  T getElement(const std::string& identifier) const
152  {
153  return m_pt.get<T>(identifier);
154  }
155 
160  bool containsElement(const std::string& identifier) const
161  {
162  return m_pt.count(identifier) > 0;
163  }
164 
170  template<class T>
171  T getElement(const std::string& identifier, const T& default_value) const
172  {
173  return m_pt.get<T>(identifier, default_value);
174  }
175 
180  template<class T>
181  std::vector<T> getVector(const std::string& identifier) const
182  {
183  std::vector<T> vector;
184  vector.resize(m_pt.get<size_t>(identifier + "_size"));
185  for (unsigned int i = 0; i < vector.size(); ++i) {
186  vector[i] = m_pt.get<T>(identifier + std::to_string(i));
187  }
188  return vector;
189  }
190 
197  static void save(Weightfile& weightfile, const std::string& filename,
198  const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
199 
205  static void saveToROOTFile(Weightfile& weightfile, const std::string& filename);
206 
212  static void saveToXMLFile(Weightfile& weightfile, const std::string& filename);
213 
219  static void saveToStream(Weightfile& weightfile, std::ostream& stream);
220 
226  static Weightfile load(const std::string& filename, const Belle2::EventMetaData& emd = Belle2::EventMetaData(0, 0, 0));
227 
232  static Weightfile loadFromFile(const std::string& filename);
233 
238  static Weightfile loadFromROOTFile(const std::string& filename);
239 
244  static Weightfile loadFromXMLFile(const std::string& filename);
245 
250  static Weightfile loadFromStream(std::istream& stream);
251 
258  static void saveToDatabase(Weightfile& weightfile, const std::string& identifier,
259  const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
260 
267  static void saveArrayToDatabase(const std::vector<Weightfile>& weightfiles, const std::string& identifier,
268  const Belle2::IntervalOfValidity& iov = Belle2::IntervalOfValidity(0, 0, -1, -1));
269 
275  static Weightfile loadFromDatabase(const std::string& identifier, const Belle2::EventMetaData& emd = Belle2::EventMetaData(0, 0,
276  0));
277 
282  void setRemoveTemporaryDirectories(bool remove_temporary_directories) { m_remove_temporary_directories = remove_temporary_directories; }
283 
287  const boost::property_tree::ptree& getXMLTree() const { return m_pt; };
288 
289  private:
290  boost::property_tree::ptree m_pt;
291  std::vector<std::string> m_filenames;
293  };
294 
295  }
297 }
298 #endif
Store event, run, and experiment numbers.
Definition: EventMetaData.h:33
A class that describes the interval of experiments/runs for which an object in the database is valid.
Abstract base class of all Options given to the MVA interface.
Definition: Options.h:34
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
Definition: Weightfile.cc:123
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
Weightfile()
Construct an empty weightfile.
Definition: Weightfile.h:44
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
Definition: Weightfile.cc:115
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
Definition: Weightfile.cc:83
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
Definition: Weightfile.cc:240
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
Definition: Weightfile.cc:154
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
Definition: Weightfile.cc:52
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
Definition: Weightfile.h:282
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
Definition: Weightfile.h:292
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
Definition: Weightfile.cc:175
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
Definition: Weightfile.cc:251
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
const boost::property_tree::ptree & getXMLTree() const
Get xml tree.
Definition: Weightfile.h:287
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
Definition: Weightfile.h:290
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
Definition: Weightfile.cc:217
std::vector< T > getVector(const std::string &identifier) const
Returns a stored vector from the xml tree.
Definition: Weightfile.h:181
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
Definition: Weightfile.cc:281
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
Definition: Weightfile.cc:185
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
Definition: Weightfile.cc:206
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
Definition: Weightfile.h:291
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
Definition: Weightfile.cc:165
void addVector(const std::string &identifier, const std::vector< T > &vector)
Add a vector to the xml tree.
Definition: Weightfile.h:125
T getElement(const std::string &identifier, const T &default_value) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:171
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
Definition: Weightfile.cc:267
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
Definition: Weightfile.cc:144
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
Definition: Weightfile.cc:258
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.