Belle II Software light-2406-ragdoll
Combination.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Combination.h>
10#include <mva/interface/Interface.h>
11
12#include <framework/logging/Logger.h>
13
14namespace Belle2 {
19 namespace MVA {
20
21 void CombinationOptions::load(const boost::property_tree::ptree& pt)
22 {
23 int version = pt.get<int>("Combination_version");
24 if (version != 1) {
25 B2ERROR("Unknown weightfile version " << std::to_string(version));
26 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
27 }
28
29 auto numberOfWeightfiles = pt.get<unsigned int>("Combination_number_of_weightfiles");
30 m_weightfiles.resize(numberOfWeightfiles);
31 for (unsigned int i = 0; i < numberOfWeightfiles; ++i) {
32 m_weightfiles[i] = pt.get<std::string>(std::string("Combination_weightfile") + std::to_string(i));
33 }
34
35 }
36
37 void CombinationOptions::save(boost::property_tree::ptree& pt) const
38 {
39 pt.put("Combination_version", 1);
40 pt.put("Combination_number_of_weightfiles", m_weightfiles.size());
41 for (unsigned int i = 0; i < m_weightfiles.size(); ++i) {
42 pt.put(std::string("Combination_weightfile") + std::to_string(i), m_weightfiles[i]);
43 }
44 }
45
46 po::options_description CombinationOptions::getDescription()
47 {
48 po::options_description description("PDF options");
49 description.add_options()
50 ("weightfiles", po::value<std::vector<std::string>>(&m_weightfiles)->multitoken(),
51 "Weightfiles of other experts we want to combine together");
52 return description;
53 }
54
55
57 const CombinationOptions& specific_options) : Teacher(general_options),
58 m_specific_options(specific_options) { }
59
61 {
62
63 Weightfile weightfile;
64 weightfile.addOptions(m_general_options);
66
67 for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
68 weightfile.addFile("Combination_Weightfile" + std::to_string(i), m_specific_options.m_weightfiles[i]);
69 }
70
71 weightfile.addSignalFraction(training_data.getSignalFraction());
72
73 return weightfile;
74
75 }
76
78 {
79
81
82 m_experts.clear();
83 m_expert_variables.clear();
84
85 for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
86 std::string sub_weightfile_name = weightfile.generateFileName(".xml");
87 weightfile.getFile("Combination_Weightfile" + std::to_string(i), sub_weightfile_name);
88 auto sub_weightfile = Weightfile::load(sub_weightfile_name);
89 GeneralOptions general_options;
90 sub_weightfile.getOptions(general_options);
91
93 auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
94 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
95 B2ERROR("Couldn't find method named " + general_options.m_method);
96 throw std::runtime_error("Couldn't find method named " + general_options.m_method);
97 }
98 auto expert = supported_interfaces[general_options.m_method]->getExpert();
99 expert->load(sub_weightfile);
100 m_experts.emplace_back(std::move(expert));
101 m_expert_variables.push_back(general_options.m_variables);
102 }
103 }
104
105 std::vector<float> CombinationExpert::apply(Dataset& test_data) const
106 {
107 std::vector<float> probabilities(test_data.getNumberOfEvents(), 0);
108 std::vector<std::vector<float>> expert_probabilities;
109
110 for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
111 GeneralOptions sub_general_options = m_general_options;
112 sub_general_options.m_variables = m_expert_variables[iExpert];
113 SubDataset sub_dataset(sub_general_options, {}, test_data);
114 expert_probabilities.push_back(m_experts[iExpert]->apply(sub_dataset));
115 }
116
117 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
118 double a = 1.0;
119 double b = 1.0;
120 for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
121 a *= expert_probabilities[iExpert][iEvent];
122 b *= (1.0 - expert_probabilities[iExpert][iEvent]);
123 }
124 probabilities[iEvent] = a / (a + b);
125 }
126 return probabilities;
127 }
128
129 }
131}
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
std::vector< std::unique_ptr< Expert > > m_experts
Experts of the methods to combine.
Definition: Combination.h:100
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Combination.cc:105
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:99
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Combination.cc:77
std::vector< std::vector< std::string > > m_expert_variables
Variables used by each of the experts to combine.
Definition: Combination.h:101
Options for the Combination MVA method.
Definition: Combination.h:28
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Combination.cc:46
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Definition: Combination.h:53
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Combination.cc:21
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Combination.cc:37
CombinationTeacher(const GeneralOptions &general_options, const CombinationOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Combination.cc:56
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:77
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Combination.cc:60
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition: Expert.h:70
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24