11 #include <mva/methods/Reweighter.h>
12 #include <mva/interface/Interface.h>
13 #include <framework/logging/Logger.h>
// Reads the Reweighter options back from the property tree `pt`.
// The format version stored under "Reweighter_version" is read first; on the
// error path the offending version is both logged (B2ERROR) and thrown as a
// std::runtime_error carrying the same message.
// NOTE(review): the condition that selects the error path is not visible in
// this excerpt — presumably `version != 1`, the only version the save code
// writes; confirm.
24 int version = pt.get<
int>(
"Reweighter_version");
26 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
27 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
// Path of the expert weightfile used to reweight, and the name of the variable
// that decides whether the reweighter is applied.
30 m_weightfile = pt.get<std::string>(std::string(
"Reweighter_weightfile"));
31 m_variable = pt.get<std::string>(std::string(
"Reweighter_variable"));
// Writes the Reweighter options into the property tree `pt`: a fixed format
// version of 1, the expert weightfile path and the variable name, under the
// same keys ("Reweighter_version", "Reweighter_weightfile",
// "Reweighter_variable") that the loading code reads.
37 pt.put(
"Reweighter_version", 1);
38 pt.put(std::string(
"Reweighter_weightfile"),
m_weightfile);
39 pt.put(std::string(
"Reweighter_variable"),
m_variable);
// Describes the two Reweighter options, "reweighter_weightfile" and
// "reweighter_variable", under the caption "Reweighter options".
// Each option is bound via po::value<std::string>(&member), so parsing writes
// directly into m_weightfile / m_variable.
// NOTE(review): `po` is presumably an alias for boost::program_options, and the
// return of `description` is not visible in this excerpt; confirm.
44 po::options_description description(
"Reweighter options");
45 description.add_options()
46 (
"reweighter_weightfile", po::value<std::string>(&
m_weightfile),
47 "Weightfile of the expert used to reweight")
48 (
"reweighter_variable", po::value<std::string>(&
m_variable),
49 "Variable which decides if the reweighter is applied or not");
// Constructor tail: stores the given specific options in m_specific_options;
// the body is empty.
// NOTE(review): the constructor signature and any preceding initialisers are
// not visible in this excerpt.
56 m_specific_options(specific_options) { }
// Training fragment: applies the expert loaded from `expert_weightfile` to
// `training_data`, turns its predictions into odds, averages them into a
// normalisation `norm`, and stores options, signal fraction and norm in
// `weightfile`.
// NOTE(review): this excerpt omits many lines (declarations of
// expert_weightfile, general_options, mod_general_options,
// supported_interfaces and variable; the bodies of the `isSignal` branches;
// the condition choosing between the two loops; any increment of
// count_reweights). The comments below describe only what is visible.
61 Weightfile weightfile;
65 expert_weightfile.getOptions(general_options);
// Take the variable and spectator lists from the options read out of the
// expert weightfile.
69 mod_general_options.
m_variables = general_options.m_variables;
70 mod_general_options.
m_spectators = general_options.m_spectators;
// The expert's method must be one of the supported interfaces; otherwise log
// and throw.
88 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
89 B2ERROR(
"Couldn't find method named " + general_options.m_method);
90 throw std::runtime_error(
"Couldn't find method named " + general_options.m_method);
92 auto expert = supported_interfaces[general_options.m_method]->getExpert();
93 expert->load(expert_weightfile);
// One prediction per training event; the loops below overwrite entries in
// place with odds values.
95 auto prediction = expert->apply(training_data);
// Signal fraction recorded in the expert weightfile, and its odds f / (1 - f).
// NOTE(review): this divides by zero if the stored fraction is exactly 1.
97 double data_fraction = expert_weightfile.getSignalFraction();
98 double data_over_mc_fraction = data_fraction / (1 - data_fraction);
100 double sum_reweights = 0;
101 unsigned long int count_reweights = 0;
103 auto isSignal = training_data.getSignals();
// First loop: only events with variable[iEvent] == 1.0 contribute to the sum.
// NOTE(review): exact floating-point comparison; presumably `variable` holds a
// 0/1 flag, so this is intended — confirm.
107 for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
// NOTE(review): the body of this signal branch is not visible here.
110 if (isSignal[iEvent]) {
114 if (variable[iEvent] == 1.0) {
// Clamp to [0.005, 0.995] so the odds p / (1 - p) stay finite.
115 if (prediction[iEvent] > 0.995)
116 prediction[iEvent] = 0.995;
117 if (prediction[iEvent] < 0.005)
118 prediction[iEvent] = 0.005;
120 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
// NOTE(review): no increment of count_reweights is visible in this excerpt.
121 sum_reweights += prediction[iEvent];
// Second loop: the same clamp-and-odds transformation, without the variable cut.
126 for (
unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
// NOTE(review): the body of this signal branch is not visible here.
129 if (isSignal[iEvent]) {
133 if (prediction[iEvent] > 0.995)
134 prediction[iEvent] = 0.995;
135 if (prediction[iEvent] < 0.005)
136 prediction[iEvent] = 0.005;
138 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
139 sum_reweights += prediction[iEvent];
// Mean odds over the contributing events, divided by data_over_mc_fraction.
// NOTE(review): if count_reweights is 0 (no event contributed), this
// floating-point division yields inf/NaN instead of failing; consider guarding.
144 double norm = sum_reweights / count_reweights / data_over_mc_fraction;
// Persist the modified general options, the signal fraction and the norm.
146 weightfile.addOptions(mod_general_options);
149 weightfile.addSignalFraction(data_fraction);
150 weightfile.addElement(
"Reweighter_norm", norm);
// Expert loading fragment: requests the entry "Reweighter_Weightfile" from
// `weightfile`, passing a freshly generated ".xml" file name (presumably the
// stored sub-weightfile is written to that path — confirm). It then reads the
// normalisation stored under "Reweighter_norm" into m_norm.
// NOTE(review): the lines between these statements are not visible in this
// excerpt.
161 std::string sub_weightfile_name = weightfile.generateFileName(
".xml");
162 weightfile.getFile(
"Reweighter_Weightfile", sub_weightfile_name);
// NOTE(review): the norm is read back as float although the training code
// computes it as double, so any precision beyond float is discarded here.
175 m_norm = weightfile.getElement<
float>(
"Reweighter_norm");
// Apply fragment: converts the wrapped expert's predictions on `test_data`
// into per-event weights. Each prediction is clamped to [0.005, 0.995],
// turned into the odds p / (1 - p), and divided by m_norm.
// NOTE(review): the condition choosing between the two loops and the return
// statement are not visible in this excerpt.
180 auto prediction =
m_expert->apply(test_data);
// First loop: events whose variable value is not 1.0 have their prediction set
// to 1.0.
// NOTE(review): the rest of that branch is not visible (presumably a
// `continue`, so such events keep weight 1.0 — confirm).
185 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
186 if (variable[iEvent] != 1.0) {
187 prediction[iEvent] = 1.0;
189 if (prediction[iEvent] > 0.995)
190 prediction[iEvent] = 0.995;
191 if (prediction[iEvent] < 0.005)
192 prediction[iEvent] = 0.005;
194 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
195 prediction[iEvent] /=
m_norm;
// Second loop: the same clamp, odds and normalisation applied to every event.
199 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
200 if (prediction[iEvent] > 0.995)
201 prediction[iEvent] = 0.995;
202 if (prediction[iEvent] < 0.005)
203 prediction[iEvent] = 0.005;
205 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
206 prediction[iEvent] /=
m_norm;