Belle II Software  release-08-01-10
Combination.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/Combination.h>
10 #include <mva/interface/Interface.h>
11 
12 #include <framework/logging/Logger.h>
13 
14 namespace Belle2 {
19  namespace MVA {
20 
21  void CombinationOptions::load(const boost::property_tree::ptree& pt)
22  {
23  int version = pt.get<int>("Combination_version");
24  if (version != 1) {
25  B2ERROR("Unknown weightfile version " << std::to_string(version));
26  throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
27  }
28 
29  auto numberOfWeightfiles = pt.get<unsigned int>("Combination_number_of_weightfiles");
30  m_weightfiles.resize(numberOfWeightfiles);
31  for (unsigned int i = 0; i < numberOfWeightfiles; ++i) {
32  m_weightfiles[i] = pt.get<std::string>(std::string("Combination_weightfile") + std::to_string(i));
33  }
34 
35  }
36 
37  void CombinationOptions::save(boost::property_tree::ptree& pt) const
38  {
39  pt.put("Combination_version", 1);
40  pt.put("Combination_number_of_weightfiles", m_weightfiles.size());
41  for (unsigned int i = 0; i < m_weightfiles.size(); ++i) {
42  pt.put(std::string("Combination_weightfile") + std::to_string(i), m_weightfiles[i]);
43  }
44  }
45 
46  po::options_description CombinationOptions::getDescription()
47  {
48  po::options_description description("PDF options");
49  description.add_options()
50  ("weightfiles", po::value<std::vector<std::string>>(&m_weightfiles)->multitoken(),
51  "Weightfiles of other experts we want to combine together");
52  return description;
53  }
54 
55 
57  const CombinationOptions& specific_options) : Teacher(general_options),
58  m_specific_options(specific_options) { }
59 
61  {
62 
63  Weightfile weightfile;
64  weightfile.addOptions(m_general_options);
65  weightfile.addOptions(m_specific_options);
66 
67  for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
68  weightfile.addFile("Combination_Weightfile" + std::to_string(i), m_specific_options.m_weightfiles[i]);
69  }
70 
71  weightfile.addSignalFraction(training_data.getSignalFraction());
72 
73  return weightfile;
74 
75  }
76 
78  {
79 
80  weightfile.getOptions(m_specific_options);
81 
82  m_experts.clear();
83  m_expert_variables.clear();
84 
85  for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
86  std::string sub_weightfile_name = weightfile.generateFileName(".xml");
87  weightfile.getFile("Combination_Weightfile" + std::to_string(i), sub_weightfile_name);
88  auto sub_weightfile = Weightfile::load(sub_weightfile_name);
89  GeneralOptions general_options;
90  sub_weightfile.getOptions(general_options);
91 
93  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
94  if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
95  B2ERROR("Couldn't find method named " + general_options.m_method);
96  throw std::runtime_error("Couldn't find method named " + general_options.m_method);
97  }
98  auto expert = supported_interfaces[general_options.m_method]->getExpert();
99  expert->load(sub_weightfile);
100  m_experts.emplace_back(std::move(expert));
101  m_expert_variables.push_back(general_options.m_variables);
102  }
103  }
104 
105  std::vector<float> CombinationExpert::apply(Dataset& test_data) const
106  {
107  std::vector<float> probabilities(test_data.getNumberOfEvents(), 0);
108  std::vector<std::vector<float>> expert_probabilities;
109 
110  for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
111  GeneralOptions sub_general_options = m_general_options;
112  sub_general_options.m_variables = m_expert_variables[iExpert];
113  SubDataset sub_dataset(sub_general_options, {}, test_data);
114  expert_probabilities.push_back(m_experts[iExpert]->apply(sub_dataset));
115  }
116 
117  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
118  double a = 1.0;
119  double b = 1.0;
120  for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
121  a *= expert_probabilities[iExpert][iEvent];
122  b *= (1.0 - expert_probabilities[iExpert][iEvent]);
123  }
124  probabilities[iEvent] = a / (a + b);
125  }
126  return probabilities;
127  }
128 
129  }
131 }
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
std::vector< std::unique_ptr< Expert > > m_experts
Experts of the methods to combine.
Definition: Combination.h:100
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Combination.cc:105
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:99
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Combination.cc:77
std::vector< std::vector< std::string > > m_expert_variables
Results of the experts to combine.
Definition: Combination.h:101
Options for the Combination MVA method.
Definition: Combination.h:28
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Combination.cc:46
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Definition: Combination.h:53
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Combination.cc:21
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Combination.cc:37
CombinationTeacher(const GeneralOptions &general_options, const CombinationOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Combination.cc:56
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:77
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Combination.cc:60
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition: Expert.h:70
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.