Belle II Software development
Reweighter.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/Reweighter.h>
10#include <mva/interface/Interface.h>
11#include <framework/logging/Logger.h>
12
13namespace Belle2 {
18 namespace MVA {
19
20 void ReweighterOptions::load(const boost::property_tree::ptree& pt)
21 {
22 int version = pt.get<int>("Reweighter_version");
23 if (version != 1) {
24 B2ERROR("Unknown weightfile version " << std::to_string(version));
25 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
26 }
27
28 m_weightfile = pt.get<std::string>(std::string("Reweighter_weightfile"));
29 m_variable = pt.get<std::string>(std::string("Reweighter_variable"));
30
31 }
32
33 void ReweighterOptions::save(boost::property_tree::ptree& pt) const
34 {
35 pt.put("Reweighter_version", 1);
36 pt.put(std::string("Reweighter_weightfile"), m_weightfile);
37 pt.put(std::string("Reweighter_variable"), m_variable);
38 }
39
40 po::options_description ReweighterOptions::getDescription()
41 {
42 po::options_description description("Reweighter options");
43 description.add_options()
44 ("reweighter_weightfile", po::value<std::string>(&m_weightfile),
45 "Weightfile of the expert used to reweight")
46 ("reweighter_variable", po::value<std::string>(&m_variable),
47 "Variable which decides if the reweighter is applied or not");
48 return description;
49 }
50
51
53 const ReweighterOptions& specific_options) : Teacher(general_options),
54 m_specific_options(specific_options) { }
55
57 {
58
59 Weightfile weightfile;
60
61 GeneralOptions general_options;
62 auto expert_weightfile = Weightfile::load(m_specific_options.m_weightfile);
63 expert_weightfile.getOptions(general_options);
64
65 // Override our variables with the one from the expert
66 GeneralOptions mod_general_options = m_general_options;
67 mod_general_options.m_variables = general_options.m_variables;
68 mod_general_options.m_spectators = general_options.m_spectators;
69 mod_general_options.m_target_variable = general_options.m_target_variable;
70 mod_general_options.m_weight_variable = general_options.m_weight_variable;
71
72 // Add reweighting variable if it is not already present somewhere
74 if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
75 m_specific_options.m_variable) == mod_general_options.m_variables.end() and
76 std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
77 m_specific_options.m_variable) == mod_general_options.m_spectators.end() and
78 mod_general_options.m_target_variable != m_specific_options.m_variable and
79 mod_general_options.m_weight_variable != m_specific_options.m_variable) {
80 mod_general_options.m_spectators.push_back(m_specific_options.m_variable);
81 }
82 }
83
85 auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
86 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
87 B2ERROR("Couldn't find method named " + general_options.m_method);
88 throw std::runtime_error("Couldn't find method named " + general_options.m_method);
89 }
90 auto expert = supported_interfaces[general_options.m_method]->getExpert();
91 expert->load(expert_weightfile);
92
93 auto prediction = expert->apply(training_data);
94
95 double data_fraction = expert_weightfile.getSignalFraction();
96 double data_over_mc_fraction = data_fraction / (1 - data_fraction);
97
98 double sum_reweights = 0;
99 unsigned long int count_reweights = 0;
100
101 auto isSignal = training_data.getSignals();
102
103 if (m_specific_options.m_variable != "") {
104 auto variable = training_data.getSpectator(training_data.getSpectatorIndex(m_specific_options.m_variable));
105 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
106 // We calculate the norm only on MC events (that is background), because
107 // this is were we apply the weights in the end
108 if (isSignal[iEvent]) {
109 continue;
110 }
111
112 if (variable[iEvent] == 1.0) {
113 if (prediction[iEvent] > 0.995)
114 prediction[iEvent] = 0.995;
115 if (prediction[iEvent] < 0.005)
116 prediction[iEvent] = 0.005;
117
118 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
119 sum_reweights += prediction[iEvent];
120 count_reweights++;
121 }
122 }
123 } else {
124 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
125 // We calculate the norm only on MC events (that is background), because
126 // this is were we apply the weights in the end
127 if (isSignal[iEvent]) {
128 continue;
129 }
130
131 if (prediction[iEvent] > 0.995)
132 prediction[iEvent] = 0.995;
133 if (prediction[iEvent] < 0.005)
134 prediction[iEvent] = 0.005;
135
136 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
137 sum_reweights += prediction[iEvent];
138 count_reweights++;
139 }
140 }
141
142 double norm = sum_reweights / count_reweights / data_over_mc_fraction;
143
144 weightfile.addOptions(mod_general_options);
145 weightfile.addOptions(m_specific_options);
146 weightfile.addFile("Reweighter_Weightfile", m_specific_options.m_weightfile);
147 weightfile.addSignalFraction(data_fraction);
148 weightfile.addElement("Reweighter_norm", norm);
149
150 return weightfile;
151
152 }
153
155 {
156
157 weightfile.getOptions(m_specific_options);
158
159 std::string sub_weightfile_name = weightfile.generateFileName(".xml");
160 weightfile.getFile("Reweighter_Weightfile", sub_weightfile_name);
161 auto sub_weightfile = Weightfile::load(sub_weightfile_name);
162 sub_weightfile.getOptions(m_expert_options);
163
165 auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
166 if (supported_interfaces.find(m_expert_options.m_method) == supported_interfaces.end()) {
167 B2ERROR("Couldn't find method named " + m_expert_options.m_method);
168 throw std::runtime_error("Couldn't find method named " + m_expert_options.m_method);
169 }
170 m_expert = supported_interfaces[m_expert_options.m_method]->getExpert();
171 m_expert->load(sub_weightfile);
172
173 m_norm = weightfile.getElement<float>("Reweighter_norm");
174 }
175
176 std::vector<float> ReweighterExpert::apply(Dataset& test_data) const
177 {
178 auto prediction = m_expert->apply(test_data);
179
180 if (m_specific_options.m_variable != "") {
181 auto variable = test_data.getSpectator(test_data.getSpectatorIndex(m_specific_options.m_variable));
182
183 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
184 if (variable[iEvent] != 1.0) {
185 prediction[iEvent] = 1.0;
186 } else {
187 if (prediction[iEvent] > 0.995)
188 prediction[iEvent] = 0.995;
189 if (prediction[iEvent] < 0.005)
190 prediction[iEvent] = 0.005;
191
192 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
193 prediction[iEvent] /= m_norm;
194 }
195 }
196 } else {
197 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
198 if (prediction[iEvent] > 0.995)
199 prediction[iEvent] = 0.995;
200 if (prediction[iEvent] < 0.005)
201 prediction[iEvent] = 0.005;
202
203 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
204 prediction[iEvent] /= m_norm;
205 }
206 }
207
208 return prediction;
209
210 }
211 }
213}
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; has to be called once before getSupportedInterfaces...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_method
Name of the MVA method to use.
Definition: Options.h:82
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: Reweighter.cc:176
std::unique_ptr< Expert > m_expert
Experts used to reweight.
Definition: Reweighter.h:102
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: Reweighter.cc:154
GeneralOptions m_expert_options
Method general options of the expert.
Definition: Reweighter.h:101
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:100
double m_norm
Norm for the weights.
Definition: Reweighter.h:103
Options for the Reweighter MVA method.
Definition: Reweighter.h:28
std::string m_weightfile
Weightfile of the reweighting expert.
Definition: Reweighter.h:53
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: Reweighter.cc:40
std::string m_variable
Variable which decides if the reweighter is applied or not.
Definition: Reweighter.h:54
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: Reweighter.cc:20
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: Reweighter.cc:33
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:78
ReweighterTeacher(const GeneralOptions &general_options, const ReweighterOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: Reweighter.cc:52
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: Reweighter.cc:56
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
Definition: Weightfile.h:114
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
Definition: Weightfile.h:151
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
Abstract base class for different kinds of events.