Belle II Software development
ReweighterTeacher Class Reference

Teacher for the Reweighter MVA method.

#include <Reweighter.h>

Inheritance diagram for ReweighterTeacher:
ReweighterTeacher inherits from Teacher.

Public Member Functions

 ReweighterTeacher (const GeneralOptions &general_options, const ReweighterOptions &specific_options)
 Constructs a new teacher using the GeneralOptions and specific options of this training.
 
virtual Weightfile train (Dataset &training_data) const override
 Train an MVA method using the given dataset, returning a Weightfile.
 

Protected Attributes

GeneralOptions m_general_options
 GeneralOptions containing all shared options.
 

Private Attributes

ReweighterOptions m_specific_options
 Method specific options.
 

Detailed Description

Teacher for the Reweighter MVA method.

Definition at line 61 of file Reweighter.h.

Constructor & Destructor Documentation

◆ ReweighterTeacher()

ReweighterTeacher ( const GeneralOptions & general_options,
const ReweighterOptions & specific_options
)

Constructs a new teacher using the GeneralOptions and specific options of this training.

Parameters
general_options   defining all shared options
specific_options  defining all method specific options

Definition at line 52 of file Reweighter.cc.

  : Teacher(general_options),
    m_specific_options(specific_options) { }
ReweighterOptions m_specific_options
Method specific options.
Definition: Reweighter.h:78
Teacher(const GeneralOptions &general_options)
Constructs a new teacher using the GeneralOptions for this training.
Definition: Teacher.cc:18
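
The constructor above only forwards the shared options to the Teacher base class and stores the method specific ones. A minimal sketch of that pattern follows. GeneralOptions, ReweighterOptions and Teacher are reduced, hypothetical stand-ins here (only the member names m_method, m_variables, m_weightfile, m_variable and m_general_options come from this page), and SketchReweighterTeacher with its two accessors exists only for illustration. That the base constructor copies the options into m_general_options is an assumption.

```cpp
#include <string>
#include <vector>

// Hypothetical, reduced stand-ins for the Belle II MVA option classes.
struct GeneralOptions {
  std::string m_method;
  std::vector<std::string> m_variables;
};

struct ReweighterOptions {
  std::string m_weightfile;  // weightfile of the reweighting expert
  std::string m_variable;    // variable which decides if the reweighter is applied
};

// Stand-in base class: assumed to keep a copy of the shared options.
class Teacher {
public:
  explicit Teacher(const GeneralOptions& general_options)
    : m_general_options(general_options) { }
  virtual ~Teacher() = default;

protected:
  GeneralOptions m_general_options;
};

// Illustrative derived teacher mirroring the initializer list shown above.
class SketchReweighterTeacher : public Teacher {
public:
  SketchReweighterTeacher(const GeneralOptions& general_options,
                          const ReweighterOptions& specific_options)
    : Teacher(general_options),
      m_specific_options(specific_options) { }

  // Accessors added for this sketch only, to make the stored state visible.
  const std::string& method() const { return m_general_options.m_method; }
  const std::string& expertWeightfile() const { return m_specific_options.m_weightfile; }

private:
  ReweighterOptions m_specific_options;
};
```

Both option objects are copied, so the caller may modify or discard its own instances after construction without affecting the teacher.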

Member Function Documentation

◆ train()

Weightfile train ( Dataset & training_data ) const
override virtual

Train an MVA method using the given dataset, returning a Weightfile.

Parameters
training_data  used to train the method

Implements Teacher.

Definition at line 56 of file Reweighter.cc.
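
The listing that follows reduces to a single normalization constant stored as "Reweighter_norm". As a reading aid, the computation can be summarized as below. The symbols are notation introduced here, not names from the source: p_i is the expert prediction for event i, B is the set of background (MC) events that enter the sum (restricted to events whose m_variable spectator equals 1.0 when that variable is set), N is the number of events in B, and f is the signal fraction read from the expert weightfile.

```latex
\[
  \hat{p}_i = \min\bigl(0.995,\ \max(0.005,\ p_i)\bigr),
  \qquad
  w_i = \frac{\hat{p}_i}{1 - \hat{p}_i}
\]
\[
  \mathrm{norm} = \frac{\dfrac{1}{N} \sum_{i \in B} w_i}{\dfrac{f}{1 - f}}
\]
```

The clipping keeps each odds ratio w_i between 0.005/0.995 and 0.995/0.005, so a single extreme prediction cannot dominate the average.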

{
  Weightfile weightfile;

  GeneralOptions general_options;
  auto expert_weightfile = Weightfile::load(m_specific_options.m_weightfile);
  expert_weightfile.getOptions(general_options);

  // Override our variables with the ones from the expert
  GeneralOptions mod_general_options = m_general_options;
  mod_general_options.m_variables = general_options.m_variables;
  mod_general_options.m_spectators = general_options.m_spectators;
  mod_general_options.m_target_variable = general_options.m_target_variable;
  mod_general_options.m_weight_variable = general_options.m_weight_variable;

  // Add reweighting variable if it is not already present somewhere
  if (m_specific_options.m_variable != "") {
    if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
                  m_specific_options.m_variable) == mod_general_options.m_variables.end() and
        std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
                  m_specific_options.m_variable) == mod_general_options.m_spectators.end() and
        mod_general_options.m_target_variable != m_specific_options.m_variable and
        mod_general_options.m_weight_variable != m_specific_options.m_variable) {
      mod_general_options.m_spectators.push_back(m_specific_options.m_variable);
    }
  }

  // Look up the expert for the method named in the expert weightfile
  AbstractInterface::initSupportedInterfaces();
  auto supported_interfaces = AbstractInterface::getSupportedInterfaces();
  if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
    B2ERROR("Couldn't find method named " + general_options.m_method);
    throw std::runtime_error("Couldn't find method named " + general_options.m_method);
  }
  auto expert = supported_interfaces[general_options.m_method]->getExpert();
  expert->load(expert_weightfile);

  auto prediction = expert->apply(training_data);

  double data_fraction = expert_weightfile.getSignalFraction();
  double data_over_mc_fraction = data_fraction / (1 - data_fraction);

  double sum_reweights = 0;
  unsigned long int count_reweights = 0;

  auto isSignal = training_data.getSignals();

  if (m_specific_options.m_variable != "") {
    auto variable = training_data.getSpectator(training_data.getSpectatorIndex(m_specific_options.m_variable));
    for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
      // We calculate the norm only on MC events (that is background), because
      // this is where we apply the weights in the end
      if (isSignal[iEvent]) {
        continue;
      }

      // Only events flagged by the reweighting variable enter the norm
      if (variable[iEvent] == 1.0) {
        // Clip the prediction so that the odds ratio below stays finite
        if (prediction[iEvent] > 0.995)
          prediction[iEvent] = 0.995;
        if (prediction[iEvent] < 0.005)
          prediction[iEvent] = 0.005;

        prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
        sum_reweights += prediction[iEvent];
        count_reweights++;
      }
    }
  } else {
    for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
      // We calculate the norm only on MC events (that is background), because
      // this is where we apply the weights in the end
      if (isSignal[iEvent]) {
        continue;
      }

      // Clip the prediction so that the odds ratio below stays finite
      if (prediction[iEvent] > 0.995)
        prediction[iEvent] = 0.995;
      if (prediction[iEvent] < 0.005)
        prediction[iEvent] = 0.005;

      prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
      sum_reweights += prediction[iEvent];
      count_reweights++;
    }
  }

  // Mean odds ratio on background, divided by the data/MC ratio
  double norm = sum_reweights / count_reweights / data_over_mc_fraction;

  weightfile.addOptions(mod_general_options);
  weightfile.addOptions(m_specific_options);
  weightfile.addFile("Reweighter_Weightfile", m_specific_options.m_weightfile);
  weightfile.addSignalFraction(data_fraction);
  weightfile.addElement("Reweighter_norm", norm);

  return weightfile;
}
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weightfile
Weightfile of the reweighting expert.
Definition: Reweighter.h:53
std::string m_variable
Variable which decides if the reweighter is applied or not.
Definition: Reweighter.h:54
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
Definition: Weightfile.cc:195
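
As a rough illustration of the normalization step in train(), the following self-contained sketch reproduces the clipping, odds-ratio and averaging steps on plain vectors, without any Belle II types. The function name reweighterNorm, its parameters and the explicit checks for mismatched or empty input are assumptions made for this example and are not part of the MVA package. The additional selection on m_variable is left out.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: mean clipped odds ratio over background events,
// divided by the data/MC ratio derived from the signal (data) fraction.
double reweighterNorm(const std::vector<double>& prediction,
                      const std::vector<bool>& isSignal,
                      double dataFraction)
{
  if (prediction.size() != isSignal.size()) {
    throw std::invalid_argument("prediction and isSignal differ in size");
  }

  double sumReweights = 0.0;
  std::size_t countReweights = 0;

  for (std::size_t i = 0; i < prediction.size(); ++i) {
    // Only MC (background) events contribute to the norm.
    if (isSignal[i]) {
      continue;
    }
    // Clip to [0.005, 0.995] so that p / (1 - p) stays finite.
    const double p = std::min(0.995, std::max(0.005, prediction[i]));
    sumReweights += p / (1.0 - p);
    ++countReweights;
  }

  // Assumption of this sketch: refuse to divide by zero events.
  if (countReweights == 0) {
    throw std::runtime_error("no background events to normalize on");
  }

  const double dataOverMcFraction = dataFraction / (1.0 - dataFraction);
  return sumReweights / countReweights / dataOverMcFraction;
}
```

For two background events with predictions 0.5 and 0.8 and a data fraction of 0.5, the odds ratios are 1 and about 4, so the sketch returns approximately 2.5. The listing in train() performs no empty-input check, so the guard here is a design choice of the example only.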

Member Data Documentation

◆ m_general_options

GeneralOptions m_general_options
protected inherited

GeneralOptions containing all shared options.

Definition at line 49 of file Teacher.h.

◆ m_specific_options

ReweighterOptions m_specific_options
private

Method specific options.

Definition at line 78 of file Reweighter.h.


The documentation for this class was generated from the following files: