 |
Belle II Software
release-05-01-25
|
11 #include <mva/methods/Combination.h>
12 #include <mva/interface/Interface.h>
14 #include <framework/logging/Logger.h>
25 int version = pt.get<
int>(
"Combination_version");
27 B2ERROR(
"Unkown weightfile version " << std::to_string(version));
28 throw std::runtime_error(
"Unkown weightfile version " + std::to_string(version));
31 auto numberOfWeightfiles = pt.get<
unsigned int>(
"Combination_number_of_weightfiles");
33 for (
unsigned int i = 0; i < numberOfWeightfiles; ++i) {
34 m_weightfiles[i] = pt.get<std::string>(std::string(
"Combination_weightfile") + std::to_string(i));
41 pt.put(
"Combination_version", 1);
42 pt.put(
"Combination_number_of_weightfiles",
m_weightfiles.size());
44 pt.put(std::string(
"Combination_weightfile") + std::to_string(i),
m_weightfiles[i]);
50 po::options_description description(
"PDF options");
51 description.add_options()
52 (
"weightfiles", po::value<std::vector<std::string>>(&
m_weightfiles)->multitoken(),
53 "Weightfiles of other experts we want to combine together");
60 m_specific_options(specific_options) { }
88 std::string sub_weightfile_name = weightfile.generateFileName(
".xml");
89 weightfile.getFile(
"Combination_Weightfile" + std::to_string(i), sub_weightfile_name);
92 sub_weightfile.getOptions(general_options);
96 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
97 B2ERROR(
"Couldn't find method named " + general_options.m_method);
98 throw std::runtime_error(
"Couldn't find method named " + general_options.m_method);
100 auto expert = supported_interfaces[general_options.m_method]->getExpert();
101 expert->load(sub_weightfile);
102 m_experts.emplace_back(std::move(expert));
109 std::vector<float> probabilities(test_data.getNumberOfEvents(), 0);
110 std::vector<std::vector<float>> expert_probabilities;
112 for (
unsigned int iExpert = 0; iExpert <
m_experts.size(); ++iExpert) {
116 expert_probabilities.push_back(
m_experts[iExpert]->
apply(sub_dataset));
119 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
122 for (
unsigned int iExpert = 0; iExpert <
m_experts.size(); ++iExpert) {
123 a *= expert_probabilities[iExpert][iEvent];
124 b *= (1.0 - expert_probabilities[iExpert][iEvent]);
126 probabilities[iEvent] = a / (a + b);
128 return probabilities;
virtual po::options_description getDescription() override
Returns a program options description for all available options.
CombinationTeacher(const GeneralOptions &general_options, const CombinationOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
CombinationOptions m_specific_options
Method specific options.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
The Weightfile class serializes all information about a training into an xml tree.
std::vector< std::unique_ptr< Expert > > m_experts
Experts of the methods to combine.
std::vector< std::vector< std::string > > m_expert_variables
Variables (branch names) used by each of the experts to combine.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
CombinationOptions m_specific_options
Method specific options.
GeneralOptions m_general_options
GeneralOptions containing all shared options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Abstract base class for different kinds of events.
Wraps another Dataset and provides a view to a subset of its features and events.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class, ...
void addOptions(const Options &options)
Add an Option object to the xml tree.
Options for the Combination MVA method.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
General options which are shared by all MVA trainings.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces() can be used.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
GeneralOptions m_general_options
General options loaded from the weightfile.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.