9#include <mva/methods/Reweighter.h>
10#include <mva/interface/Interface.h>
11#include <framework/logging/Logger.h>
22 int version = pt.get<
int>(
"Reweighter_version");
24 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
25 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
28 m_weightfile = pt.get<std::string>(std::string(
"Reweighter_weightfile"));
29 m_variable = pt.get<std::string>(std::string(
"Reweighter_variable"));
35 pt.put(
"Reweighter_version", 1);
36 pt.put(std::string(
"Reweighter_weightfile"),
m_weightfile);
37 pt.put(std::string(
"Reweighter_variable"),
m_variable);
42 po::options_description description(
"Reweighter options");
43 description.add_options()
44 (
"reweighter_weightfile", po::value<std::string>(&
m_weightfile),
45 "Weightfile of the expert used to reweight")
46 (
"reweighter_variable", po::value<std::string>(&
m_variable),
47 "Variable which decides if the reweighter is applied or not");
54 m_specific_options(specific_options) { }
63 expert_weightfile.getOptions(general_options);
67 mod_general_options.
m_variables = general_options.m_variables;
68 mod_general_options.
m_spectators = general_options.m_spectators;
86 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
87 B2ERROR(
"Couldn't find method named " + general_options.m_method);
88 throw std::runtime_error(
"Couldn't find method named " + general_options.m_method);
90 auto expert = supported_interfaces[general_options.m_method]->getExpert();
91 expert->load(expert_weightfile);
93 auto prediction = expert->apply(training_data);
95 double data_fraction = expert_weightfile.getSignalFraction();
96 double data_over_mc_fraction = data_fraction / (1 - data_fraction);
98 double sum_reweights = 0;
99 unsigned long int count_reweights = 0;
101 auto isSignal = training_data.getSignals();
105 for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
108 if (isSignal[iEvent]) {
112 if (variable[iEvent] == 1.0) {
113 if (prediction[iEvent] > 0.995)
114 prediction[iEvent] = 0.995;
115 if (prediction[iEvent] < 0.005)
116 prediction[iEvent] = 0.005;
118 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
119 sum_reweights += prediction[iEvent];
124 for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
127 if (isSignal[iEvent]) {
131 if (prediction[iEvent] > 0.995)
132 prediction[iEvent] = 0.995;
133 if (prediction[iEvent] < 0.005)
134 prediction[iEvent] = 0.005;
136 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
137 sum_reweights += prediction[iEvent];
142 double norm = sum_reweights / count_reweights / data_over_mc_fraction;
148 weightfile.
addElement(
"Reweighter_norm", norm);
160 weightfile.
getFile(
"Reweighter_Weightfile", sub_weightfile_name);
178 auto prediction =
m_expert->apply(test_data);
183 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
184 if (variable[iEvent] != 1.0) {
185 prediction[iEvent] = 1.0;
187 if (prediction[iEvent] > 0.995)
188 prediction[iEvent] = 0.995;
189 if (prediction[iEvent] < 0.005)
190 prediction[iEvent] = 0.005;
192 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
193 prediction[iEvent] /=
m_norm;
197 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
198 if (prediction[iEvent] > 0.995)
199 prediction[iEvent] = 0.995;
200 if (prediction[iEvent] < 0.005)
201 prediction[iEvent] = 0.005;
203 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
204 prediction[iEvent] /=
m_norm;
static void initSupportedInterfaces()
Static function which initliazes all supported interfaces, has to be called once before getSupportedI...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_method
Name of the MVA method to use.
std::string m_target_variable
Target variable (branch name) defining the target.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
std::unique_ptr< Expert > m_expert
Experts used to reweight.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
GeneralOptions m_expert_options
Method general options of the expert.
ReweighterOptions m_specific_options
Method specific options.
double m_norm
Norm for the weights.
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::string m_variable
Variable which decides if the reweighter is applied or not.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
ReweighterOptions m_specific_options
Method specific options.
ReweighterTeacher(const GeneralOptions &general_options, const ReweighterOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Abstract base class for different kinds of events.