Belle II Software  release-05-01-25
Combination.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #include <mva/methods/Combination.h>
12 #include <mva/interface/Interface.h>
13 
14 #include <framework/logging/Logger.h>
15 
16 namespace Belle2 {
21  namespace MVA {
22 
23  void CombinationOptions::load(const boost::property_tree::ptree& pt)
24  {
25  int version = pt.get<int>("Combination_version");
26  if (version != 1) {
27  B2ERROR("Unkown weightfile version " << std::to_string(version));
28  throw std::runtime_error("Unkown weightfile version " + std::to_string(version));
29  }
30 
31  auto numberOfWeightfiles = pt.get<unsigned int>("Combination_number_of_weightfiles");
32  m_weightfiles.resize(numberOfWeightfiles);
33  for (unsigned int i = 0; i < numberOfWeightfiles; ++i) {
34  m_weightfiles[i] = pt.get<std::string>(std::string("Combination_weightfile") + std::to_string(i));
35  }
36 
37  }
38 
39  void CombinationOptions::save(boost::property_tree::ptree& pt) const
40  {
41  pt.put("Combination_version", 1);
42  pt.put("Combination_number_of_weightfiles", m_weightfiles.size());
43  for (unsigned int i = 0; i < m_weightfiles.size(); ++i) {
44  pt.put(std::string("Combination_weightfile") + std::to_string(i), m_weightfiles[i]);
45  }
46  }
47 
48  po::options_description CombinationOptions::getDescription()
49  {
50  po::options_description description("PDF options");
51  description.add_options()
52  ("weightfiles", po::value<std::vector<std::string>>(&m_weightfiles)->multitoken(),
53  "Weightfiles of other experts we want to combine together");
54  return description;
55  }
56 
57 
59  const CombinationOptions& specific_options) : Teacher(general_options),
60  m_specific_options(specific_options) { }
61 
62  Weightfile CombinationTeacher::train(Dataset& training_data) const
63  {
64 
65  Weightfile weightfile;
67  weightfile.addOptions(m_specific_options);
68 
69  for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
70  weightfile.addFile("Combination_Weightfile" + std::to_string(i), m_specific_options.m_weightfiles[i]);
71  }
72 
73  weightfile.addSignalFraction(training_data.getSignalFraction());
74 
75  return weightfile;
76 
77  }
78 
79  void CombinationExpert::load(Weightfile& weightfile)
80  {
81 
82  weightfile.getOptions(m_specific_options);
83 
84  m_experts.clear();
85  m_expert_variables.clear();
86 
87  for (unsigned int i = 0; i < m_specific_options.m_weightfiles.size(); ++i) {
88  std::string sub_weightfile_name = weightfile.generateFileName(".xml");
89  weightfile.getFile("Combination_Weightfile" + std::to_string(i), sub_weightfile_name);
90  auto sub_weightfile = Weightfile::load(sub_weightfile_name);
91  GeneralOptions general_options;
92  sub_weightfile.getOptions(general_options);
93 
95  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
96  if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
97  B2ERROR("Couldn't find method named " + general_options.m_method);
98  throw std::runtime_error("Couldn't find method named " + general_options.m_method);
99  }
100  auto expert = supported_interfaces[general_options.m_method]->getExpert();
101  expert->load(sub_weightfile);
102  m_experts.emplace_back(std::move(expert));
103  m_expert_variables.push_back(general_options.m_variables);
104  }
105  }
106 
107  std::vector<float> CombinationExpert::apply(Dataset& test_data) const
108  {
109  std::vector<float> probabilities(test_data.getNumberOfEvents(), 0);
110  std::vector<std::vector<float>> expert_probabilities;
111 
112  for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
113  GeneralOptions sub_general_options = m_general_options;
114  sub_general_options.m_variables = m_expert_variables[iExpert];
115  SubDataset sub_dataset(sub_general_options, {}, test_data);
116  expert_probabilities.push_back(m_experts[iExpert]->apply(sub_dataset));
117  }
118 
119  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
120  double a = 1.0;
121  double b = 1.0;
122  for (unsigned int iExpert = 0; iExpert < m_experts.size(); ++iExpert) {
123  a *= expert_probabilities[iExpert][iEvent];
124  b *= (1.0 - expert_probabilities[iExpert][iEvent]);
125  }
126  probabilities[iEvent] = a / (a + b);
127  }
128  return probabilities;
129  }
130 
131  }
133 }
Belle2::MVA::CombinationOptions::getDescription
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Combination.cc:56
Belle2::MVA::CombinationTeacher::CombinationTeacher
CombinationTeacher(const GeneralOptions &general_options, const CombinationOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Combination.cc:66
Belle2::MVA::CombinationTeacher::m_specific_options
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:79
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:34
Belle2::MVA::Weightfile::addSignalFraction
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:104
Belle2::MVA::Weightfile::addFile
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:124
Belle2::MVA::Weightfile
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:40
Belle2::MVA::CombinationExpert::m_experts
std::vector< std::unique_ptr< Expert > > m_experts
Experts of the methods to combine.
Definition: Combination.h:102
Belle2::MVA::CombinationExpert::m_expert_variables
std::vector< std::vector< std::string > > m_expert_variables
Input variables used by each of the experts to combine.
Definition: Combination.h:103
Belle2::MVA::CombinationOptions::save
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Combination.cc:47
Belle2::MVA::CombinationTeacher::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Combination.cc:70
Belle2::MVA::CombinationExpert::m_specific_options
CombinationOptions m_specific_options
Method specific options.
Definition: Combination.h:101
Belle2::MVA::Teacher::m_general_options
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:51
Belle2::MVA::CombinationExpert::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Combination.cc:87
Belle2::MVA::Weightfile::load
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:204
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::SubDataset
Wraps another Dataset and provides a view to a subset of its features and events.
Definition: Dataset.h:234
Belle2::MVA::Teacher
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:31
Belle2::MVA::Weightfile::addOptions
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:71
Belle2::MVA::CombinationOptions
Options for the Combination MVA method.
Definition: Combination.h:30
Belle2::MVA::CombinationExpert::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Combination.cc:115
Belle2::MVA::GeneralOptions::m_variables
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:88
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedI...
Definition: Interface.cc:55
Belle2::MVA::CombinationOptions::m_weightfiles
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Definition: Combination.h:55
Belle2::MVA::Expert::m_general_options
GeneralOptions m_general_options
General options loaded from the weightfile.
Definition: Expert.h:74
Belle2::MVA::CombinationOptions::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Combination.cc:31