57 {
58
59 Weightfile weightfile;
60
61 GeneralOptions general_options;
63 expert_weightfile.getOptions(general_options);
64
65
67 mod_general_options.
m_variables = general_options.m_variables;
68 mod_general_options.m_spectators = general_options.m_spectators;
69 mod_general_options.m_target_variable = general_options.m_target_variable;
70 mod_general_options.m_weight_variable = general_options.m_weight_variable;
71
72
74 if (std::find(mod_general_options.m_variables.begin(), mod_general_options.m_variables.end(),
76 std::find(mod_general_options.m_spectators.begin(), mod_general_options.m_spectators.end(),
81 }
82 }
83
86 if (supported_interfaces.find(general_options.m_method) == supported_interfaces.end()) {
87 B2ERROR("Couldn't find method named " + general_options.m_method);
88 throw std::runtime_error("Couldn't find method named " + general_options.m_method);
89 }
90 auto expert = supported_interfaces[general_options.m_method]->getExpert();
91 expert->load(expert_weightfile);
92
93 auto prediction = expert->apply(training_data);
94
95 double data_fraction = expert_weightfile.getSignalFraction();
96 double data_over_mc_fraction = data_fraction / (1 - data_fraction);
97
98 double sum_reweights = 0;
99 unsigned long int count_reweights = 0;
100
101 auto isSignal = training_data.getSignals();
102
105 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
106
107
108 if (isSignal[iEvent]) {
109 continue;
110 }
111
112 if (variable[iEvent] == 1.0) {
113 if (prediction[iEvent] > 0.995)
114 prediction[iEvent] = 0.995;
115 if (prediction[iEvent] < 0.005)
116 prediction[iEvent] = 0.005;
117
118 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
119 sum_reweights += prediction[iEvent];
120 count_reweights++;
121 }
122 }
123 } else {
124 for (unsigned int iEvent = 0; iEvent < training_data.getNumberOfEvents(); ++iEvent) {
125
126
127 if (isSignal[iEvent]) {
128 continue;
129 }
130
131 if (prediction[iEvent] > 0.995)
132 prediction[iEvent] = 0.995;
133 if (prediction[iEvent] < 0.005)
134 prediction[iEvent] = 0.005;
135
136 prediction[iEvent] = (prediction[iEvent]) / (1 - prediction[iEvent]);
137 sum_reweights += prediction[iEvent];
138 count_reweights++;
139 }
140 }
141
142 double norm = sum_reweights / count_reweights / data_over_mc_fraction;
143
144 weightfile.addOptions(mod_general_options);
147 weightfile.addSignalFraction(data_fraction);
148 weightfile.addElement("Reweighter_norm", norm);
149
150 return weightfile;
151
152 }
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces() is used.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weightfile
Weightfile of the reweighting expert.
std::string m_variable
Variable which decides if the reweighter is applied or not.
GeneralOptions m_general_options
GeneralOptions containing all shared options.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.