9#include <mva/methods/Combination.h>
10#include <mva/interface/Interface.h>
12#include <framework/logging/Logger.h>
23 int version = pt.get<
int>(
"Combination_version");
25 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
26 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
29 auto numberOfWeightfiles = pt.get<
unsigned int>(
"Combination_number_of_weightfiles");
31 for (
unsigned int i = 0; i < numberOfWeightfiles; ++i) {
32 m_weightfiles[i] = pt.get<std::string>(std::string(
"Combination_weightfile") + std::to_string(i));
39 pt.put(
"Combination_version", 1);
40 pt.put(
"Combination_number_of_weightfiles",
m_weightfiles.size());
42 pt.put(std::string(
"Combination_weightfile") + std::to_string(i),
m_weightfiles[i]);
48 po::options_description description(
"PDF options");
49 description.add_options()
50 (
"weightfiles", po::value<std::vector<std::string>>(&
m_weightfiles)->multitoken(),
51 "Weightfiles of other experts we want to combine together");
58 m_specific_options(specific_options) { }
87 weightfile.
getFile(
"Combination_Weightfile" + std::to_string(i), sub_weightfile_name);
90 sub_weightfile.getOptions(general_options);
94 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
95 B2ERROR(
"Couldn't find method named " + general_options.m_method);
96 throw std::runtime_error(
"Couldn't find method named " + general_options.m_method);
98 auto expert = supported_interfaces[general_options.m_method]->getExpert();
99 expert->load(sub_weightfile);
100 m_experts.emplace_back(std::move(expert));
107 std::vector<float> probabilities(test_data.getNumberOfEvents(), 0);
108 std::vector<std::vector<float>> expert_probabilities;
110 for (
unsigned int iExpert = 0; iExpert <
m_experts.size(); ++iExpert) {
113 SubDataset sub_dataset(sub_general_options, {}, test_data);
114 expert_probabilities.push_back(
m_experts[iExpert]->
apply(sub_dataset));
117 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
120 for (
unsigned int iExpert = 0; iExpert <
m_experts.size(); ++iExpert) {
121 a *= expert_probabilities[iExpert][iEvent];
122 b *= (1.0 - expert_probabilities[iExpert][iEvent]);
124 probabilities[iEvent] = a / (a + b);
126 return probabilities;
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedI...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
std::vector< std::unique_ptr< Expert > > m_experts
Experts of the methods to combine.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
CombinationOptions m_specific_options
Method specific options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::vector< std::vector< std::string > > m_expert_variables
Results of the experts to combine.
Options for the Combination MVA method.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
CombinationTeacher(const GeneralOptions &general_options, const CombinationOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
CombinationOptions m_specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
GeneralOptions m_general_options
General options loaded from the weightfile.
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Wraps another Dataset and provides a view to a subset of its features and events.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Abstract base class for different kinds of events.