9 #include <mva/interface/Weightfile.h> 
   11 #include <mva/dataobjects/DatabaseRepresentationOfWeightfile.h> 
   12 #include <framework/database/Database.h> 
   13 #include <framework/database/DBImportArray.h> 
   15 #include <boost/archive/iterators/base64_from_binary.hpp> 
   16 #include <boost/archive/iterators/binary_from_base64.hpp> 
   17 #include <boost/archive/iterators/transform_width.hpp> 
   19 #include <boost/property_tree/xml_parser.hpp> 
   20 #include <boost/filesystem/operations.hpp> 
   21 #include <boost/algorithm/string/predicate.hpp> 
   22 #include <boost/algorithm/string.hpp> 
   23 #include <boost/algorithm/string/replace.hpp> 
   24 #include <boost/regex.hpp> 
   38     std::string makeSaveForDatabase(std::string str)
 
   40       std::map<std::string, std::string> replace {
 
   43       for (
auto& pair : replace) {
 
   44         boost::replace_all(str, pair.first, pair.second);
 
   54         if (boost::filesystem::exists(filename)) {
 
   56             boost::filesystem::remove_all(filename);
 
   73       m_pt.put(
"number_of_importance_vars", importance.size());
 
   75       for (
auto& pair : importance) {
 
   76         m_pt.put(std::string(
"importance_key") + std::to_string(i), pair.first);
 
   77         m_pt.put(std::string(
"importance_value") + std::to_string(i), pair.second);
 
   84       std::map<std::string, float> importance;
 
   85       unsigned int numberOfImportanceVars = 
m_pt.get<
unsigned int>(
"number_of_importance_vars", 0);
 
   86       for (
unsigned int i = 0; i < numberOfImportanceVars; ++i) {
 
   87         auto key = 
m_pt.get<std::string>(std::string(
"importance_key") + std::to_string(i));
 
   88         auto value = 
m_pt.get<
float>(std::string(
"importance_value") + std::to_string(i));
 
   89         importance[key] = value;
 
   96       m_pt.put(
"signal_fraction", signal_fraction);
 
  101       return m_pt.get<
float>(
"signal_fraction");
 
  107       auto directory = mkdtemp(directory_template);
 
  108       std::string tmpfile = std::string(directory) + std::string(
"/weightfile") + suffix;
 
  110       free(directory_template);
 
  117       std::ifstream in(custom_weightfile, std::ios::in | std::ios::binary);
 
  124       using base64_t =  boost::archive::iterators::base64_from_binary <
 
  125                         boost::archive::iterators::transform_width<std::string::const_iterator, 6, 8 >>;
 
  127       std::string contents;
 
  128       in.seekg(0, std::ios::end);
 
  129       contents.resize(in.tellg());
 
  130       in.seekg(0, std::ios::beg);
 
  131       in.read(&contents[0], contents.size());
 
  132       std::string enc(base64_t(contents.begin()), base64_t(contents.end()));
 
  134       m_pt.put(identifier, enc);
 
  139       std::ofstream out(custom_weightfile, std::ios::out | std::ios::binary);
 
  145       using binary_t = boost::archive::iterators::transform_width <
 
  146                        boost::archive::iterators::binary_from_base64<std::string::const_iterator>, 8, 6 >;
 
  148       auto contents = 
m_pt.get<std::string>(identifier);
 
  149       std::string dec(binary_t(contents.begin()), binary_t(contents.end()));
 
  155       if (boost::ends_with(filename, 
".root")) {
 
  157       } 
else if (boost::ends_with(filename, 
".xml")) {
 
  166       std::stringstream ss;
 
  169       database_representation_of_weightfile.
m_data = ss.str();
 
  170       TFile file(filename.c_str(), 
"RECREATE");
 
  171       file.WriteObject(&database_representation_of_weightfile, 
"Weightfile");
 
  176 #if BOOST_VERSION < 105600 
  177       boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
 
  179       boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
 
  181       boost::property_tree::xml_parser::write_xml(filename, weightfile.
m_pt, std::locale(), settings);
 
  186 #if BOOST_VERSION < 105600 
  187       boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
 
  189       boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
 
  191       boost::property_tree::xml_parser::write_xml(stream, weightfile.
m_pt, settings);
 
  196       if (boost::ends_with(filename, 
".root")) {
 
  198       } 
else if (boost::ends_with(filename, 
".xml")) {
 
  207       if (boost::ends_with(filename, 
".root")) {
 
  209       } 
else if (boost::ends_with(filename, 
".xml")) {
 
  212         throw std::runtime_error(
"Cannot load file " + filename + 
" because file extension is not supported");
 
  219       if (not boost::filesystem::exists(filename)) {
 
  220         throw std::runtime_error(
"Given filename does not exist: " + filename);
 
  223       TFile file(filename.c_str(), 
"READ");
 
  224       if (file.IsZombie() or not file.IsOpen()) {
 
  225         throw std::runtime_error(
"Error during open of ROOT file named " + filename);
 
  229       file.GetObject(
"Weightfile", database_representation_of_weightfile);
 
  231       if (database_representation_of_weightfile == 
nullptr) {
 
  232         throw std::runtime_error(
"The provided ROOT file " + filename + 
" does not contain a valid MVA weightfile.");
 
  234       std::stringstream ss(database_representation_of_weightfile->
m_data);
 
  235       delete database_representation_of_weightfile;
 
  241       if (not boost::filesystem::exists(filename)) {
 
  242         throw std::runtime_error(
"Given filename does not exist: " + filename);
 
  246       boost::property_tree::xml_parser::read_xml(filename, weightfile.
m_pt);
 
  253       boost::property_tree::xml_parser::read_xml(stream, weightfile.
m_pt);
 
  259       std::stringstream ss;
 
  262       database_representation_of_weightfile.
m_data = ss.str();
 
  271       for (
auto weightfile : weightfiles) {
 
  272         std::stringstream ss;
 
  284       if (pair.first == 0) {
 
  285         throw std::runtime_error(
"Given identifier cannot be loaded from the database: " + identifier);
 
  290       std::stringstream ss(database_representation_of_weightfile.
m_data);
 
Class for importing an array of objects into the database.
T * appendNew()
Construct a new T object at the end of the array.
bool import(const IntervalOfValidity &iov)
Import the object to database.
Database representation of a Weightfile object.
std::string m_data
Serialized weightfile.
A class that describes the interval of experiments/runs for which an object in the database is valid.
Abstract base class of all Options given to the MVA interface.
virtual void load(const boost::property_tree::ptree &pt)=0
Load mechanism (used by Weightfile) to load Options from an XML tree.
virtual void save(boost::property_tree::ptree &pt) const =0
Save mechanism (used by Weightfile) to store Options in an XML tree.
The Weightfile class serializes all information about a training into an xml tree.
std::string m_temporary_directory
base directory in which the temporary directories of this weightfile are created
void addStream(const std::string &identifier, std::istream &in)
Add a stream to our weightfile.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
std::map< std::string, float > getFeatureImportance() const
Get feature importance.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
~Weightfile()
Destructor (removes temporary files associated with this weightfile)
bool m_remove_temporary_directories
remove all temporary directories in the destructor of this class
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
static Weightfile loadFromStream(std::istream &stream)
Static function which deserializes a Weightfile from a stream.
boost::property_tree::ptree m_pt
xml tree containing all the saved information of this weightfile
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToStream(Weightfile &weightfile, std::ostream &stream)
Static function which serializes a Weightfile to a stream.
static Weightfile loadFromFile(const std::string &filename)
Static function which loads a Weightfile from a file.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
std::vector< std::string > m_filenames
generated temporary filenames, which will be removed in the destructor of this class
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
float getSignalFraction() const
Loads the signal fraction from the XML tree.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
std::string getStream(const std::string &identifier) const
Returns the content of a stored stream as string.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
std::pair< TObject *, IntervalOfValidity > getData(const EventMetaData &event, const std::string &name)
Request an object from the database.
static Database & Instance()
Instance of a singleton Database.
bool storeData(const std::string &name, TObject *object, const IntervalOfValidity &iov)
Store an object in the database.
Abstract base class for different kinds of events.