Belle II Software  release-05-01-25
Reweighter.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2017 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #include <mva/methods/Reweighter.h>
12 #include <mva/interface/Interface.h>
13 #include <framework/logging/Logger.h>
14 
15 namespace Belle2 {
20  namespace MVA {
21 
22  void ReweighterOptions::load(const boost::property_tree::ptree& pt)
23  {
24  int version = pt.get<int>("Reweighter_version");
25  if (version != 1) {
26  B2ERROR("Unkown weightfile version " << std::to_string(version));
27  throw std::runtime_error("Unkown weightfile version " + std::to_string(version));
28  }
29 
30  m_weightfile = pt.get<std::string>(std::string("Reweighter_weightfile"));
31  m_variable = pt.get<std::string>(std::string("Reweighter_variable"));
32 
33  }
34 
35  void ReweighterOptions::save(boost::property_tree::ptree& pt) const
36  {
37  pt.put("Reweighter_version", 1);
38  pt.put(std::string("Reweighter_weightfile"), m_weightfile);
39  pt.put(std::string("Reweighter_variable"), m_variable);
40  }
41 
42  po::options_description ReweighterOptions::getDescription()
43  {
44  po::options_description description("Reweighter options");
45  description.add_options()
46  ("reweighter_weightfile", po::value<std::string>(&m_weightfile),
47  "Weightfile of the expert used to reweight")
48  ("reweighter_variable", po::value<std::string>(&m_variable),
49  "Variable which decides if the reweighter is applied or not");
50  return description;
51  }
52 
53 
55  const ReweighterOptions& specific_options) : Teacher(general_options),
56  m_specific_options(specific_options) { }
57 
58  Weightfile ReweighterTeacher::train(Dataset& training_data) const
59  {
60 
61  Weightfile weightfile;
62 
63  GeneralOptions general_options;
64  auto expert_weightfile = Weightfile::load(m_specific_options.m_weightfile);
65  expert_weightfile.getOptions(general_options);
66 
67  // Override our variables with the one from the expert
68  GeneralOptions mod_general_options = m_general_options;
69  mod_general_options.m_variables = general_options.m_variables;
70  mod_general_options.m_spectators = general_options.m_spectators;
71  mod_general_options.m_target_variable = general_options.m_target_variable;
72  mod_general_options.m_weight_variable = general_options.m_weight_variable;
73 
74  // Add reweighting variable if it is not already present somewhere
75  if (m_specific_options.m_variable != "") {
76  if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
77  m_specific_options.m_variable) == mod_general_options.m_variables.end() and
78  std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
79  m_specific_options.m_variable) == mod_general_options.m_spectators.end() and
80  mod_general_options.m_target_variable != m_specific_options.m_variable and
81  mod_general_options.m_weight_variable != m_specific_options.m_variable) {
82  mod_general_options.m_spectators.push_back(m_specific_options.m_variable);
83  }
84  }
85 
87  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
88  if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
89  B2ERROR("Couldn't find method named " + general_options.m_method);
90  throw std::runtime_error("Couldn't find method named " + general_options.m_method);
91  }
92  auto expert = supported_interfaces[general_options.m_method]->getExpert();
93  expert->load(expert_weightfile);
94 
95  auto prediction = expert->apply(training_data);
96 
97  double data_fraction = expert_weightfile.getSignalFraction();
98  double data_over_mc_fraction = data_fraction / (1 - data_fraction);
99 
100  double sum_reweights = 0;
101  unsigned long int count_reweights = 0;
102 
103  auto isSignal = training_data.getSignals();
104 
105  if (m_specific_options.m_variable != "") {
106  auto variable = training_data.getSpectator(training_data.getSpectatorIndex(m_specific_options.m_variable));
107  for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
108  // We calculate the norm only on MC events (that is background), because
109  // this is were we apply the weights in the end
110  if (isSignal[iEvent]) {
111  continue;
112  }
113 
114  if (variable[iEvent] == 1.0) {
115  if (prediction[iEvent] > 0.995)
116  prediction[iEvent] = 0.995;
117  if (prediction[iEvent] < 0.005)
118  prediction[iEvent] = 0.005;
119 
120  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
121  sum_reweights += prediction[iEvent];
122  count_reweights++;
123  }
124  }
125  } else {
126  for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
127  // We calculate the norm only on MC events (that is background), because
128  // this is were we apply the weights in the end
129  if (isSignal[iEvent]) {
130  continue;
131  }
132 
133  if (prediction[iEvent] > 0.995)
134  prediction[iEvent] = 0.995;
135  if (prediction[iEvent] < 0.005)
136  prediction[iEvent] = 0.005;
137 
138  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
139  sum_reweights += prediction[iEvent];
140  count_reweights++;
141  }
142  }
143 
144  double norm = sum_reweights / count_reweights / data_over_mc_fraction;
145 
146  weightfile.addOptions(mod_general_options);
147  weightfile.addOptions(m_specific_options);
148  weightfile.addFile("Reweighter_Weightfile", m_specific_options.m_weightfile);
149  weightfile.addSignalFraction(data_fraction);
150  weightfile.addElement("Reweighter_norm", norm);
151 
152  return weightfile;
153 
154  }
155 
156  void ReweighterExpert::load(Weightfile& weightfile)
157  {
158 
159  weightfile.getOptions(m_specific_options);
160 
161  std::string sub_weightfile_name = weightfile.generateFileName(".xml");
162  weightfile.getFile("Reweighter_Weightfile", sub_weightfile_name);
163  auto sub_weightfile = Weightfile::load(sub_weightfile_name);
164  sub_weightfile.getOptions(m_expert_options);
165 
167  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
168  if (supported_interfaces.find(m_expert_options.m_method) == supported_interfaces.end()) {
169  B2ERROR("Couldn't find method named " + m_expert_options.m_method);
170  throw std::runtime_error("Couldn't find method named " + m_expert_options.m_method);
171  }
172  m_expert = supported_interfaces[m_expert_options.m_method]->getExpert();
173  m_expert->load(sub_weightfile);
174 
175  m_norm = weightfile.getElement<float>("Reweighter_norm");
176  }
177 
178  std::vector<float> ReweighterExpert::apply(Dataset& test_data) const
179  {
180  auto prediction = m_expert->apply(test_data);
181 
182  if (m_specific_options.m_variable != "") {
183  auto variable = test_data.getSpectator(test_data.getSpectatorIndex(m_specific_options.m_variable));
184 
185  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
186  if (variable[iEvent] != 1.0) {
187  prediction[iEvent] = 1.0;
188  } else {
189  if (prediction[iEvent] > 0.995)
190  prediction[iEvent] = 0.995;
191  if (prediction[iEvent] < 0.005)
192  prediction[iEvent] = 0.005;
193 
194  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
195  prediction[iEvent] /= m_norm;
196  }
197  }
198  } else {
199  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
200  if (prediction[iEvent] > 0.995)
201  prediction[iEvent] = 0.995;
202  if (prediction[iEvent] < 0.005)
203  prediction[iEvent] = 0.005;
204 
205  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
206  prediction[iEvent] /= m_norm;
207  }
208  }
209 
210  return prediction;
211 
212  }
213  }
215 }
Belle2::MVA::ReweighterOptions
Options for the Reweighter MVA method.
Definition: Reweighter.h:30
Belle2::MVA::AbstractInterface::getSupportedInterfaces
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:55
Belle2::MVA::ReweighterOptions::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Reweighter.cc:30
Belle2::MVA::ReweighterOptions::m_variable
std::string m_variable
Variable which decides if the reweighter is applied or not.
Definition: Reweighter.h:56
Belle2::MVA::GeneralOptions::m_weight_variable
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:92
Belle2::MVA::ReweighterTeacher::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Reweighter.cc:66
Belle2::MVA::ReweighterOptions::m_weightfile
std::string m_weightfile
Weightfile of the reweighting expert.
Definition: Reweighter.h:55
Belle2::MVA::GeneralOptions::m_spectators
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:89
Belle2::MVA::GeneralOptions::m_method
std::string m_method
Name of the MVA method to use.
Definition: Options.h:84
Belle2::MVA::Teacher::m_general_options
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:51
Belle2::MVA::Weightfile::load
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:204
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::Teacher
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:31
Belle2::MVA::ReweighterOptions::save
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Reweighter.cc:43
Belle2::MVA::GeneralOptions::m_target_variable
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:91
Belle2::MVA::ReweighterExpert::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Reweighter.cc:164
Belle2::MVA::ReweighterExpert::m_norm
double m_norm
Norm for the weights.
Definition: Reweighter.h:105
Belle2::MVA::GeneralOptions::m_variables
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:88
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::ReweighterExpert::m_expert_options
GeneralOptions m_expert_options
Method general options of the expert.
Definition: Reweighter.h:103
Belle2::MVA::AbstractInterface::initSupportedInterfaces
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedInterfaces().
Definition: Interface.cc:55
Belle2::MVA::ReweighterExpert::m_specific_options
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:102
Belle2::MVA::ReweighterExpert::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Reweighter.cc:186
Belle2::MVA::ReweighterExpert::m_expert
std::unique_ptr< Expert > m_expert
Expert used to reweight.
Definition: Reweighter.h:104
Belle2::MVA::ReweighterOptions::getDescription
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Reweighter.cc:50
Belle2::MVA::ReweighterTeacher::m_specific_options
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:80
Belle2::MVA::ReweighterTeacher::ReweighterTeacher
ReweighterTeacher(const GeneralOptions &general_options, const ReweighterOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Reweighter.cc:62