9#include <mva/methods/Reweighter.h> 
   10#include <mva/interface/Interface.h> 
   11#include <framework/logging/Logger.h> 
   22      int version = pt.get<
int>(
"Reweighter_version");
 
   24        B2ERROR(
"Unknown weightfile version " << std::to_string(version));
 
   25        throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
 
   28      m_weightfile = pt.get<std::string>(std::string(
"Reweighter_weightfile"));
 
   29      m_variable = pt.get<std::string>(std::string(
"Reweighter_variable"));
 
 
   35      pt.put(
"Reweighter_version", 1);
 
   36      pt.put(std::string(
"Reweighter_weightfile"), 
m_weightfile);
 
   37      pt.put(std::string(
"Reweighter_variable"), 
m_variable);
 
 
   42      po::options_description description(
"Reweighter options");
 
   43      description.add_options()
 
   44      (
"reweighter_weightfile", po::value<std::string>(&
m_weightfile),
 
   45       "Weightfile of the expert used to reweight")
 
   46      (
"reweighter_variable", po::value<std::string>(&
m_variable),
 
   47       "Variable which decides if the reweighter is applied or not");
 
 
   63      expert_weightfile.getOptions(general_options);
 
   67      mod_general_options.
m_variables = general_options.m_variables;
 
   68      mod_general_options.
m_spectators = general_options.m_spectators;
 
   86      if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
 
   87        B2ERROR(
"Couldn't find method named " + general_options.m_method);
 
   88        throw std::runtime_error(
"Couldn't find method named " + general_options.m_method);
 
   90      auto expert = supported_interfaces[general_options.m_method]->getExpert();
 
   91      expert->load(expert_weightfile);
 
   93      auto prediction = expert->apply(training_data);
 
   95      double data_fraction = expert_weightfile.getSignalFraction();
 
   96      double data_over_mc_fraction = data_fraction / (1 - data_fraction);
 
   98      double sum_reweights = 0;
 
   99      unsigned long int count_reweights = 0;
 
  101      auto isSignal = training_data.getSignals();
 
  104        auto variable = training_data.getSpectator(training_data.getSpectatorIndex(
m_specific_options.m_variable));
 
  105        for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
 
  108          if (isSignal[iEvent]) {
 
  112          if (variable[iEvent] == 1.0) {
 
  113            if (prediction[iEvent] > 0.995)
 
  114              prediction[iEvent] = 0.995;
 
  115            if (prediction[iEvent] < 0.005)
 
  116              prediction[iEvent] = 0.005;
 
  118            prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
 
  119            sum_reweights += prediction[iEvent];
 
  124        for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
 
  127          if (isSignal[iEvent]) {
 
  131          if (prediction[iEvent] > 0.995)
 
  132            prediction[iEvent] = 0.995;
 
  133          if (prediction[iEvent] < 0.005)
 
  134            prediction[iEvent] = 0.005;
 
  136          prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
 
  137          sum_reweights += prediction[iEvent];
 
  142      double norm = sum_reweights / count_reweights / data_over_mc_fraction;
 
  148      weightfile.
addElement(
"Reweighter_norm", norm);
 
 
  160      weightfile.
getFile(
"Reweighter_Weightfile", sub_weightfile_name);
 
  166      if (supported_interfaces.find(
m_expert_options.m_method) == supported_interfaces.end()) {
 
  168        throw std::runtime_error(
"Couldn't find method named " + 
m_expert_options.m_method);
 
 
  178      auto prediction = 
m_expert->apply(test_data);
 
  181        auto variable = test_data.getSpectator(test_data.getSpectatorIndex(
m_specific_options.m_variable));
 
  183        for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
 
  184          if (variable[iEvent] != 1.0) {
 
  185            prediction[iEvent] = 1.0;
 
  187            if (prediction[iEvent] > 0.995)
 
  188              prediction[iEvent] = 0.995;
 
  189            if (prediction[iEvent] < 0.005)
 
  190              prediction[iEvent] = 0.005;
 
  192            prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
 
  193            prediction[iEvent] /= 
m_norm;
 
  197        for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
 
  198          if (prediction[iEvent] > 0.995)
 
  199            prediction[iEvent] = 0.995;
 
  200          if (prediction[iEvent] < 0.005)
 
  201            prediction[iEvent] = 0.005;
 
  203          prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
 
  204          prediction[iEvent] /= 
m_norm;
 
 
static void initSupportedInterfaces()
Static function which initliazes all supported interfaces, has to be called once before getSupportedI...
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_target_variable
Target variable (branch name) defining the target.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
std::unique_ptr< Expert > m_expert
Experts used to reweight.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
GeneralOptions m_expert_options
Method general options of the expert.
ReweighterOptions m_specific_options
Method specific options.
double m_norm
Norm for the weights.
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::string m_variable
Variable which decides if the reweighter is applied or not.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
ReweighterOptions m_specific_options
Method specific options.
ReweighterTeacher(const GeneralOptions &general_options, const ReweighterOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
The Weightfile class serializes all information about a training into an xml tree.
void addElement(const std::string &identifier, const T &element)
Add an element to the xml tree.
T getElement(const std::string &identifier) const
Returns a stored element from the xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Abstract base class for different kinds of events.