Belle II Software  release-08-01-10
Reweighter.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/Reweighter.h>
10 #include <mva/interface/Interface.h>
11 #include <framework/logging/Logger.h>
12 
13 namespace Belle2 {
18  namespace MVA {
19 
20  void ReweighterOptions::load(const boost::property_tree::ptree& pt)
21  {
22  int version = pt.get<int>("Reweighter_version");
23  if (version != 1) {
24  B2ERROR("Unknown weightfile version " << std::to_string(version));
25  throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
26  }
27 
28  m_weightfile = pt.get<std::string>(std::string("Reweighter_weightfile"));
29  m_variable = pt.get<std::string>(std::string("Reweighter_variable"));
30 
31  }
32 
33  void ReweighterOptions::save(boost::property_tree::ptree& pt) const
34  {
35  pt.put("Reweighter_version", 1);
36  pt.put(std::string("Reweighter_weightfile"), m_weightfile);
37  pt.put(std::string("Reweighter_variable"), m_variable);
38  }
39 
40  po::options_description ReweighterOptions::getDescription()
41  {
42  po::options_description description("Reweighter options");
43  description.add_options()
44  ("reweighter_weightfile", po::value<std::string>(&m_weightfile),
45  "Weightfile of the expert used to reweight")
46  ("reweighter_variable", po::value<std::string>(&m_variable),
47  "Variable which decides if the reweighter is applied or not");
48  return description;
49  }
50 
51 
53  const ReweighterOptions& specific_options) : Teacher(general_options),
54  m_specific_options(specific_options) { }
55 
57  {
58 
59  Weightfile weightfile;
60 
61  GeneralOptions general_options;
62  auto expert_weightfile = Weightfile::load(m_specific_options.m_weightfile);
63  expert_weightfile.getOptions(general_options);
64 
65  // Override our variables with the one from the expert
66  GeneralOptions mod_general_options = m_general_options;
67  mod_general_options.m_variables = general_options.m_variables;
68  mod_general_options.m_spectators = general_options.m_spectators;
69  mod_general_options.m_target_variable = general_options.m_target_variable;
70  mod_general_options.m_weight_variable = general_options.m_weight_variable;
71 
72  // Add reweighting variable if it is not already present somewhere
73  if (m_specific_options.m_variable != "") {
74  if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
75  m_specific_options.m_variable) == mod_general_options.m_variables.end() and
76  std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
77  m_specific_options.m_variable) == mod_general_options.m_spectators.end() and
78  mod_general_options.m_target_variable != m_specific_options.m_variable and
79  mod_general_options.m_weight_variable != m_specific_options.m_variable) {
80  mod_general_options.m_spectators.push_back(m_specific_options.m_variable);
81  }
82  }
83 
85  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
86  if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
87  B2ERROR("Couldn't find method named " + general_options.m_method);
88  throw std::runtime_error("Couldn't find method named " + general_options.m_method);
89  }
90  auto expert = supported_interfaces[general_options.m_method]->getExpert();
91  expert->load(expert_weightfile);
92 
93  auto prediction = expert->apply(training_data);
94 
95  double data_fraction = expert_weightfile.getSignalFraction();
96  double data_over_mc_fraction = data_fraction / (1 - data_fraction);
97 
98  double sum_reweights = 0;
99  unsigned long int count_reweights = 0;
100 
101  auto isSignal = training_data.getSignals();
102 
103  if (m_specific_options.m_variable != "") {
104  auto variable = training_data.getSpectator(training_data.getSpectatorIndex(m_specific_options.m_variable));
105  for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
106  // We calculate the norm only on MC events (that is background), because
107  // this is were we apply the weights in the end
108  if (isSignal[iEvent]) {
109  continue;
110  }
111 
112  if (variable[iEvent] == 1.0) {
113  if (prediction[iEvent] > 0.995)
114  prediction[iEvent] = 0.995;
115  if (prediction[iEvent] < 0.005)
116  prediction[iEvent] = 0.005;
117 
118  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
119  sum_reweights += prediction[iEvent];
120  count_reweights++;
121  }
122  }
123  } else {
124  for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
125  // We calculate the norm only on MC events (that is background), because
126  // this is were we apply the weights in the end
127  if (isSignal[iEvent]) {
128  continue;
129  }
130 
131  if (prediction[iEvent] > 0.995)
132  prediction[iEvent] = 0.995;
133  if (prediction[iEvent] < 0.005)
134  prediction[iEvent] = 0.005;
135 
136  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
137  sum_reweights += prediction[iEvent];
138  count_reweights++;
139  }
140  }
141 
142  double norm = sum_reweights / count_reweights / data_over_mc_fraction;
143 
144  weightfile.addOptions(mod_general_options);
145  weightfile.addOptions(m_specific_options);
146  weightfile.addFile("Reweighter_Weightfile", m_specific_options.m_weightfile);
147  weightfile.addSignalFraction(data_fraction);
148  weightfile.addElement("Reweighter_norm", norm);
149 
150  return weightfile;
151 
152  }
153 
155  {
156 
157  weightfile.getOptions(m_specific_options);
158 
159  std::string sub_weightfile_name = weightfile.generateFileName(".xml");
160  weightfile.getFile("Reweighter_Weightfile", sub_weightfile_name);
161  auto sub_weightfile = Weightfile::load(sub_weightfile_name);
162  sub_weightfile.getOptions(m_expert_options);
163 
165  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
166  if (supported_interfaces.find(m_expert_options.m_method) == supported_interfaces.end()) {
167  B2ERROR("Couldn't find method named " + m_expert_options.m_method);
168  throw std::runtime_error("Couldn't find method named " + m_expert_options.m_method);
169  }
170  m_expert = supported_interfaces[m_expert_options.m_method]->getExpert();
171  m_expert->load(sub_weightfile);
172 
173  m_norm = weightfile.getElement<float>("Reweighter_norm");
174  }
175 
176  std::vector<float> ReweighterExpert::apply(Dataset& test_data) const
177  {
178  auto prediction = m_expert->apply(test_data);
179 
180  if (m_specific_options.m_variable != "") {
181  auto variable = test_data.getSpectator(test_data.getSpectatorIndex(m_specific_options.m_variable));
182 
183  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
184  if (variable[iEvent] != 1.0) {
185  prediction[iEvent] = 1.0;
186  } else {
187  if (prediction[iEvent] > 0.995)
188  prediction[iEvent] = 0.995;
189  if (prediction[iEvent] < 0.005)
190  prediction[iEvent] = 0.005;
191 
192  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
193  prediction[iEvent] /= m_norm;
194  }
195  }
196  } else {
197  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
198  if (prediction[iEvent] > 0.995)
199  prediction[iEvent] = 0.995;
200  if (prediction[iEvent] < 0.005)
201  prediction[iEvent] = 0.005;
202 
203  prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
204  prediction[iEvent] /= m_norm;
205  }
206  }
207 
208  return prediction;
209 
210  }
211  }
213 }
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_method
Name of the MVA method to use.
Definition: Options.h:82
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Reweighter.cc:176
std::unique_ptr< Expert > m_expert
Expert used to reweight.
Definition: Reweighter.h:102
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Reweighter.cc:154
GeneralOptions m_expert_options
Method general options of the expert.
Definition: Reweighter.h:101
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:100
double m_norm
Norm for the weights.
Definition: Reweighter.h:103
Options for the Reweighter MVA method.
Definition: Reweighter.h:28
std::string m_weightfile
Weightfile of the reweighting expert.
Definition: Reweighter.h:53
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Reweighter.cc:40
std::string m_variable
Variable which decides if the reweighter is applied or not.
Definition: Reweighter.h:54
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Reweighter.cc:20
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Reweighter.cc:33
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:78
ReweighterTeacher(const GeneralOptions &general_options, const ReweighterOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Reweighter.cc:52
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Reweighter.cc:56
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.