11 #include <mva/utility/Utility.h>
12 #include <mva/utility/DataDriven.h>
13 #include <mva/methods/PDF.h>
14 #include <mva/methods/Reweighter.h>
15 #include <mva/methods/Trivial.h>
16 #include <mva/methods/Combination.h>
18 #include <framework/logging/Logger.h>
20 #include <framework/utilities/MakeROOTCompatible.h>
22 #include <boost/algorithm/string/predicate.hpp>
23 #include <boost/property_tree/xml_parser.hpp>
// loadRootDictionary: empty function; its body contains no statements.
// NOTE(review): presumably exists only so that referencing it pulls in the ROOT dictionary of this
// library — confirm against the header documentation.
38 void loadRootDictionary() { }
// download: dispatches on the extension of `filename`.
// Visible logic: one branch for names ending in ".root", one for names ending in ".xml", and a
// fallback that reports the unrecognised extension on std::cerr and announces an xml fallback.
// NOTE(review): the statements that obtain the weightfile (presumably from `identifier` together with
// experiment/run/event) and the save call inside each branch are not visible in this listing —
// confirm against the complete file before relying on this description.
40 void download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
44 if (boost::ends_with(filename,
".root")) {
46 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the message below misspells "Unknown". It is runtime text, so it is left unchanged here.
49 std::cerr <<
"Unkown file extension, fallback to xml" << std::endl;
// upload: dispatches on the extension of `filename`.
// Visible logic: one branch for names ending in ".root", one for names ending in ".xml", and a
// fallback that reports the unrecognised extension on std::cerr and announces an xml fallback.
// NOTE(review): the load call inside each branch and whatever is done with `identifier` and
// exp1/run1/exp2/run2 (presumably a validity range for the stored payload) are not visible in this
// listing — confirm against the complete file.
54 void upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
58 if (boost::ends_with(filename,
".root")) {
60 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the message below misspells "Unknown". It is runtime text, so it is left unchanged here.
63 std::cerr <<
"Unkown file extension, fallback to xml" << std::endl;
// upload_array: collects one Belle2::MVA::Weightfile per entry of `filenames` into a vector.
// For every filename the same extension dispatch as in upload() is applied (".root", ".xml",
// otherwise a std::cerr message announcing an xml fallback), and the resulting `weightfile` is
// appended to `weightfiles`.
// NOTE(review): the per-branch load calls and the statement that consumes `weightfiles`,
// `identifier` and exp1/run1/exp2/run2 are not visible in this listing — confirm.
// NOTE(review): `filenames` is a non-const reference although the visible loop only reads it
// through `const auto&`; a const reference looks sufficient — verify against callers and the header.
69 void upload_array(std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
73 std::vector<Belle2::MVA::Weightfile> weightfiles;
74 for (
const auto& filename : filenames) {
77 if (boost::ends_with(filename,
".root")) {
79 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the message below misspells "Unknown". It is runtime text, so it is left unchanged here.
82 std::cerr <<
"Unkown file extension, fallback to xml" << std::endl;
85 weightfiles.push_back(weightfile);
// extract: obtains the expert registered for the method named in the general options and lets it
// load `weightfile`.
// NOTE(review): how `weightfile` is created from `filename`, how `directory` is used, and how
// `general_options` and `supported_interfaces` are populated is not visible in this listing —
// confirm against the complete file.
90 void extract(
const std::string& filename,
const std::string& directory)
98 GeneralOptions general_options;
// Look up the interface by method name and ask it for an expert instance.
100 auto expert = supported_interfaces[general_options.m_method]->getExpert();
101 expert->load(weightfile);
// info: serialises the general and the method-specific options of a weightfile as XML text.
// Visible steps:
//   1. request the method-specific options object from the interface selected by
//      general_options.m_method and load it from the weightfile's XML tree;
//   2. save both option objects into one temporary boost property tree;
//   3. write that tree as XML into a std::ostringstream, using '\t' with an indent count of 1.
// NOTE(review): loading of `weightfile` from `filename`, the filling of `general_options`, and the
// return statement (presumably the stream's string) are not visible in this listing — confirm.
105 std::string info(
const std::string& filename)
111 GeneralOptions general_options;
114 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
115 specific_options->load(weightfile.
getXMLTree());
117 boost::property_tree::ptree temp_tree;
118 general_options.save(temp_tree);
119 specific_options->save(temp_tree);
120 std::ostringstream oss;
// The template argument of xml_writer_settings depends on the Boost version: `char` for versions
// below 105600, `std::string` otherwise. The #else/#endif lines are not shown in this listing.
122 #if BOOST_VERSION < 105600
123 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
125 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
// NOTE(review): the statement below ends in a doubled semicolon (a harmless empty statement).
127 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
// available: takes a weightfile name plus an experiment/run/event triple and returns a bool.
// NOTE(review): the body is not visible in this listing. Presumably it reports whether the
// weightfile can be loaded for the given event — confirm against the complete file.
133 bool available(
const std::string& filename,
int experiment,
int run,
int event)
// expert: applies several weightfiles to the same data files and stores the responses in a ROOT
// tree named "variables" inside `outputfile`.
// Visible steps:
//   1. open `outputfile` in "RECREATE" mode and create the tree;
//   2. for every entry of `filenames`, remember its weightfile and create one float branch
//      (leaf type "/F") bound to the shared variable `result`;
//   3. for every weightfile, configure GeneralOptions (tree name, data files, event limit, target),
//      load the matching expert, apply it to a ROOTDataset and log the elapsed time;
//   4. if a target variable is still set, add a second float branch for the targets and fill it;
//   5. write the file.
// `copy_target == false` clears the target variable, so step 4 is skipped for that weightfile.
// NOTE(review): the declarations of `result`, `target`, `i`, `branchname` (first loop),
// `supported_interfaces`, the weightfile loading from experiment/run/event, and the body of the
// results loop are not visible in this listing — confirm against the complete file.
145 void expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
const std::string& treename,
146 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
149 std::vector<Weightfile> weightfiles;
150 std::vector<TBranch*> branches;
152 TFile file(outputfile.c_str(),
"RECREATE");
154 TTree tree(
"variables",
"variables");
// One weightfile and one output branch per input filename; the two vectors stay index-aligned.
157 for (
auto& filename : filenames) {
160 weightfiles.push_back(weightfile);
163 auto branch = tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str());
164 branches.push_back(branch);
171 for (
auto& weightfile : weightfiles) {
172 GeneralOptions general_options;
174 general_options.m_treename = treename;
// NOTE(review): presumably 0 means "no event limit", so a limit stored at training time does not
// restrict the application — confirm the meaning of m_max_events.
177 general_options.m_max_events = 0;
179 auto expert = supported_interfaces[general_options.m_method]->getExpert();
180 expert->load(weightfile);
// An empty target variable disables the target branch created further down.
182 if (not copy_target) {
183 general_options.m_target_variable = std::string();
186 general_options.m_datafiles = datafiles;
// `i` is declared outside the visible lines; it selects the branch created for this weightfile.
187 auto& branch = branches[i];
188 ROOTDataset data(general_options);
// NOTE(review): the duration is called `training_time` although it measures expert->apply().
189 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
190 auto results = expert->apply(data);
191 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
192 std::chrono::duration<double, std::milli> training_time = stop - start;
193 B2INFO(
"Elapsed application time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
194 for (
auto& r : results) {
// Target branch name: "<response branch name>_<target variable>", passed through makeROOTCompatible.
200 if (not general_options.m_target_variable.empty()) {
201 std::string branchname =
Belle2::makeROOTCompatible(std::string(branch->GetName()) +
"_" + general_options.m_target_variable);
203 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
204 auto targets = data.getTargets();
205 for (
auto& t : targets) {
207 target_branch->Fill();
215 file.Write(
"variables");
// save_custom_weightfile: wraps an externally produced weight file into a Weightfile object.
// Visible steps:
//   1. open `custom_weightfile` with std::ifstream and raise B2FATAL with a "does not exist"
//      message (the guarding condition is not visible in this listing);
//   2. add `specific_options` and attach the custom file under the key
//      "<general_options.m_identifier>_Weightfile";
//   3. derive the output name: it starts as `custom_weightfile`; when `output_identifier` is not
//      empty, "_<output_identifier>" is inserted in front of the part matched by the regex.
// NOTE(review): how `general_options` is added and where `output_weightfile` is finally written is
// not visible in this listing — confirm against the complete file.
219 void save_custom_weightfile(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
220 const std::string& custom_weightfile,
const std::string& output_identifier)
222 std::ifstream ifile(custom_weightfile);
224 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
227 Weightfile weightfile;
229 weightfile.addOptions(specific_options);
230 weightfile.addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
231 std::string output_weightfile(custom_weightfile);
232 if (!output_identifier.empty()) {
// The pattern matches a '.' followed by non-whitespace characters up to the end of the string.
// "$0" in the replacement stands for the whole match, so the matched suffix is kept.
233 std::regex to_replace(
"(\\.\\S+$)");
234 std::string replacement =
"_" + output_identifier +
"$0";
235 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
// teacher: entry point for a training; selects the training flavour from `meta_options`.
// At most one of the meta trainings (sPlot, sideband subtraction, reweighting) is supported at a
// time: the enabled flags are counted and B2ERROR is raised when more than one is set.
// Dispatch order: sPlot, then sideband subtraction, then reweighting. The last two visible lines
// build a ROOTDataset from `general_options` and call teacher_dataset() on it.
// NOTE(review): the control flow after the B2ERROR and the `else` introducing the plain training
// are not visible in this listing — presumably the function returns early on error and the plain
// training runs only when no meta training is enabled; confirm.
240 void teacher(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
const MetaOptions& meta_options)
242 unsigned int number_of_enabled_meta_trainings = 0;
243 if (meta_options.m_use_splot)
244 number_of_enabled_meta_trainings++;
245 if (meta_options.m_use_sideband_substraction)
246 number_of_enabled_meta_trainings++;
247 if (meta_options.m_use_reweighting)
248 number_of_enabled_meta_trainings++;
250 if (number_of_enabled_meta_trainings > 1) {
251 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
255 if (meta_options.m_use_splot) {
256 teacher_splot(general_options, specific_options, meta_options);
257 }
else if (meta_options.m_use_sideband_substraction) {
258 teacher_sideband_substraction(general_options, specific_options, meta_options);
259 }
else if (meta_options.m_use_reweighting) {
260 teacher_reweighting(general_options, specific_options, meta_options);
262 ROOTDataset data(general_options);
263 teacher_dataset(general_options, specific_options, data);
// teacher_dataset: trains the method named in the options on `data` and prepares an expert that
// has loaded the freshly trained weightfile.
// `general_options` is taken by value, so filling in an empty m_method below does not affect the
// caller's object.
// Visible behaviour:
//   - an empty general_options.m_method is replaced by specific_options.getMethod();
//   - a mismatch between the two method names raises B2ERROR;
//   - if the method is registered in `supported_interfaces`, a teacher is created, train(data) is
//     timed and logged, and an expert loads the resulting weightfile;
//   - otherwise B2ERROR is raised and std::runtime_error is thrown.
// NOTE(review): the third parameter (the dataset used as `data`), the origin of
// `supported_interfaces`, and the return statement are not visible in this listing — confirm.
268 std::unique_ptr<Belle2::MVA::Expert> teacher_dataset(GeneralOptions general_options,
const SpecificOptions& specific_options,
271 if (general_options.m_method.empty()) {
272 general_options.m_method = specific_options.getMethod();
274 if (general_options.m_method != specific_options.getMethod()) {
275 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
276 <<
" " << specific_options.getMethod());
281 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
282 auto teacher = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
// Wall-clock timing of the training call only.
283 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
284 auto weightfile = teacher->train(data);
285 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
286 std::chrono::duration<double, std::milli> training_time = stop - start;
287 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
289 auto expert = supported_interfaces[general_options.m_method]->getExpert();
290 expert->load(weightfile);
// NOTE(review): both messages below append the method name directly after "method", with no
// separator. They are runtime text, so they are left unchanged here.
293 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
294 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
// teacher_splot: sPlot meta training.
// Visible outline:
//   1. build three datasets from copies of `general_options`:
//        - data_dataset: the training features, with an empty target variable;
//        - discriminant_dataset: only meta_options.m_splot_variable, with an empty target variable;
//        - mc_dataset: only meta_options.m_splot_variable, read from meta_options.m_splot_mc_files;
//      the target variable of the two option copies is restored right after dataset construction;
//   2. derive the signal fraction from the yields stored in `binning`;
//   3. histogram the weighted discriminating variable of the data into 100 bins;
//   4. scan the signal yield from 0 up to the total data weight in steps of 1 and keep the yield
//      with the smallest chi2 between the data histogram and the signal + background model;
//   5. overwrite the yields in `binning` with the fitted values;
//   6. train on SPlotDataset objects (boosted variant: boost weights first, then aPlot weights
//      computed from the boost expert's response; otherwise plain sPlot weights);
//   7. additionally train a "PDF" expert on MC and a "Combination" expert over both weightfiles,
//      and return the combination expert.
// NOTE(review): the construction of `binning`, the initialisation of `chi2`, the update of
// best_chi2/best_yield, the guard around the empty-bin warning, the statements following the two
// `if (not meta_options.m_splot_combined)` lines (presumably `return splot_expert;`) and the
// `else` separating the boosted from the plain variant are not visible in this listing — confirm
// against the complete file.
298 std::unique_ptr<Belle2::MVA::Expert> teacher_splot(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
299 const MetaOptions& meta_options)
302 GeneralOptions data_general_options = general_options;
303 data_general_options.m_target_variable =
"";
// In combined mode the sPlot training gets its own identifier; the plain identifier is used later
// for the combination weightfile.
304 if (meta_options.m_splot_combined)
305 data_general_options.m_identifier = general_options.m_identifier +
"_splot.xml";
306 ROOTDataset data_dataset(data_general_options);
308 data_general_options.m_target_variable = general_options.m_target_variable;
310 GeneralOptions discriminant_general_options = general_options;
311 discriminant_general_options.m_target_variable =
"";
312 discriminant_general_options.m_variables = {meta_options.m_splot_variable};
313 ROOTDataset discriminant_dataset(discriminant_general_options);
315 discriminant_general_options.m_target_variable = general_options.m_target_variable;
317 GeneralOptions mc_general_options = general_options;
318 mc_general_options.m_datafiles = meta_options.m_splot_mc_files;
319 mc_general_options.m_variables = {meta_options.m_splot_variable};
320 ROOTDataset mc_dataset(mc_general_options);
// Feature 0 is the only feature of the MC and discriminant datasets: the sPlot variable.
322 auto mc_signals = mc_dataset.getSignals();
323 auto mc_weights = mc_dataset.getWeights();
324 auto mc_feature = mc_dataset.getFeature(0);
325 auto data_feature = discriminant_dataset.getFeature(0);
326 auto data_weights = discriminant_dataset.getWeights();
330 float signalFraction = binning.m_signal_yield / (binning.m_signal_yield + binning.m_bckgrd_yield);
// Weighted histogram of the discriminating variable on data, using the bin edges of `binning`.
// NOTE(review): the loop bound comes from data_dataset while the indexed arrays come from
// discriminant_dataset; presumably both hold the same events because they read the same files —
// confirm.
332 std::vector<double> data(100, 0);
333 double total_data = 0.0;
334 for (
unsigned int iEvent = 0; iEvent < data_dataset.getNumberOfEvents(); ++iEvent) {
335 data[binning.getBin(data_feature[iEvent])] += data_weights[iEvent];
336 total_data += data_weights[iEvent];
// Brute-force one-parameter fit: try every signal yield in unit steps and keep the best chi2.
// NOTE(review): best_yield is a float although the scanned `yield` is a double.
343 float best_yield = 0.0;
344 double best_chi2 = 1000000000.0;
345 bool empty_bin =
false;
346 for (
double yield = 0; yield < total_data; yield += 1) {
348 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
// deviation = observed content - expected content, where the expectation scales the signal and
// background pdf values by the relative bin width.
349 double deviation = (data[iBin] - (yield * binning.m_signal_pdf[iBin] + (total_data - yield) * binning.m_bckgrd_pdf[iBin]) *
350 (binning.m_boundaries[iBin + 1] - binning.m_boundaries[iBin]) / (binning.m_boundaries[100] - binning.m_boundaries[0]));
// NOTE(review): this divides by the bin content; a guard against empty bins is presumably on the
// lines missing from this listing (see `empty_bin`) — confirm.
352 chi2 += deviation * deviation / data[iBin];
356 if (chi2 < best_chi2) {
363 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
366 B2INFO(
"sPlot best yield " << best_yield);
367 B2INFO(
"sPlot Yields On MC " << binning.m_signal_yield <<
" " << binning.m_bckgrd_yield);
// Replace the MC yields with the yields fitted on data.
369 binning.m_signal_yield = best_yield;
370 binning.m_bckgrd_yield = (total_data - best_yield);
372 B2INFO(
"sPlot Yields Fitted On Data " << binning.m_signal_yield <<
" " << binning.m_bckgrd_yield);
// Boosted variant: a first training with boost weights feeds the aPlot weights of the second one.
374 if (meta_options.m_splot_boosted) {
375 GeneralOptions boost_general_options = data_general_options;
376 boost_general_options.m_identifier = general_options.m_identifier +
"_boost.xml";
377 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
378 auto boost_expert = teacher_dataset(boost_general_options, specific_options, splot_dataset);
380 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
381 boost_expert->apply(data_dataset)), signalFraction);
382 auto splot_expert = teacher_dataset(data_general_options, specific_options, aplot_dataset);
383 if (not meta_options.m_splot_combined)
// Plain variant: a single training with sPlot weights.
386 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
387 auto splot_expert = teacher_dataset(data_general_options, specific_options, splot_dataset);
388 if (not meta_options.m_splot_combined)
// Combined mode: train a PDF expert on the MC discriminating variable ...
392 mc_general_options.m_identifier = general_options.m_identifier +
"_pdf.xml";
393 mc_general_options.m_method =
"PDF";
394 PDFOptions pdf_options;
395 auto pdf_expert = teacher_dataset(mc_general_options, pdf_options, mc_dataset);
// ... and a Combination expert that references the sPlot and the PDF weightfile by identifier.
397 GeneralOptions combination_general_options = general_options;
398 combination_general_options.m_method =
"Combination";
399 combination_general_options.m_variables.push_back(meta_options.m_splot_variable);
400 CombinationOptions combination_options;
401 combination_options.m_weightfiles = {data_general_options.m_identifier, mc_general_options.m_identifier};
402 auto combination_expert = teacher_dataset(combination_general_options, combination_options, data_dataset);
404 return combination_expert;
// teacher_reweighting: reweighting meta training.
// Visible outline:
//   1. raise B2ERROR if the reweighting variable is also listed as a training feature;
//   2. build a data dataset (empty target variable, meta_options.m_reweighting_data_files) and an
//      MC dataset (meta_options.m_reweighting_mc_files) and merge them into a CombinedDataset;
//   3. train a "boost" expert on the combined dataset (identifier suffix "_boost.xml");
//   4. train a "Reweighter" expert that references the boost weightfile and the reweighting
//      variable, and apply it to obtain per-event weights;
//   5. train the requested method on a ReweightingDataset that carries those weights.
// NOTE(review): what follows the B2ERROR (presumably an early return) and the final return
// statement (presumably `expert`) are not visible in this listing — confirm.
407 std::unique_ptr<Belle2::MVA::Expert> teacher_reweighting(
const GeneralOptions& general_options,
408 const SpecificOptions& specific_options,
409 const MetaOptions& meta_options)
411 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
412 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
413 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
417 GeneralOptions data_general_options = general_options;
418 data_general_options.m_target_variable =
"";
419 data_general_options.m_datafiles = meta_options.m_reweighting_data_files;
420 ROOTDataset data_dataset(data_general_options);
422 GeneralOptions mc_general_options = general_options;
423 mc_general_options.m_datafiles = meta_options.m_reweighting_mc_files;
424 ROOTDataset mc_dataset(mc_general_options);
426 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
428 GeneralOptions boost_general_options = general_options;
429 boost_general_options.m_identifier = general_options.m_identifier +
"_boost.xml";
430 auto boost_expert = teacher_dataset(boost_general_options, specific_options, boost_dataset);
432 GeneralOptions reweighter_general_options = general_options;
433 reweighter_general_options.m_identifier = meta_options.m_reweighting_identifier;
434 reweighter_general_options.m_method =
"Reweighter";
435 ReweighterOptions reweighter_specific_options;
436 reweighter_specific_options.m_weightfile = boost_general_options.m_identifier;
437 reweighter_specific_options.m_variable = meta_options.m_reweighting_variable;
// If a reweighting variable is given and is not yet a spectator, feature, target or weight
// variable, add it as a spectator of the reweighter's options.
439 if (meta_options.m_reweighting_variable !=
"") {
440 if (std::find(reweighter_general_options.m_spectators.begin(), reweighter_general_options.m_spectators.end(),
441 meta_options.m_reweighting_variable) == reweighter_general_options.m_spectators.end() and
442 std::find(reweighter_general_options.m_variables.begin(), reweighter_general_options.m_variables.end(),
443 meta_options.m_reweighting_variable) == reweighter_general_options.m_variables.end() and
444 reweighter_general_options.m_target_variable != meta_options.m_reweighting_variable and
445 reweighter_general_options.m_weight_variable != meta_options.m_reweighting_variable) {
446 reweighter_general_options.m_spectators.push_back(meta_options.m_reweighting_variable);
// The reweighter is trained and applied on the same dataset; its output is used as event weights
// for the final training.
450 ROOTDataset dataset(reweighter_general_options);
451 auto reweight_expert = teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
452 auto weights = reweight_expert->apply(dataset);
453 ReweightingDataset reweighted_dataset(general_options, dataset, weights);
454 auto expert = teacher_dataset(general_options, specific_options, reweighted_dataset);
// teacher_sideband_substraction: sideband-subtraction meta training.
// Visible outline:
//   1. raise B2ERROR if the sideband variable is also listed as a training feature;
//   2. build a data dataset and an MC dataset (meta_options.m_sideband_mc_files); in both option
//      copies the sideband variable is added as a spectator unless it is already one;
//   3. combine both into a SidebandDataset keyed on the sideband variable;
//   4. train the requested method on the SidebandDataset with the unmodified `general_options`.
// NOTE(review): what follows the B2ERROR (presumably an early return) and the final return
// statement (presumably `expert`) are not visible in this listing — confirm.
459 std::unique_ptr<Belle2::MVA::Expert> teacher_sideband_substraction(
const GeneralOptions& general_options,
460 const SpecificOptions& specific_options,
461 const MetaOptions& meta_options)
464 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
465 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
466 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
470 GeneralOptions data_general_options = general_options;
471 if (std::find(data_general_options.m_spectators.begin(), data_general_options.m_spectators.end(),
472 meta_options.m_sideband_variable) == data_general_options.m_spectators.end()) {
473 data_general_options.m_spectators.push_back(meta_options.m_sideband_variable);
475 ROOTDataset data_dataset(data_general_options);
477 GeneralOptions mc_general_options = general_options;
478 mc_general_options.m_datafiles = meta_options.m_sideband_mc_files;
479 if (std::find(mc_general_options.m_spectators.begin(), mc_general_options.m_spectators.end(),
480 meta_options.m_sideband_variable) == mc_general_options.m_spectators.end()) {
481 mc_general_options.m_spectators.push_back(meta_options.m_sideband_variable);
483 ROOTDataset mc_dataset(mc_general_options);
// `sideband_general_options` is an unmodified copy of `general_options`.
485 GeneralOptions sideband_general_options = general_options;
486 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
487 auto expert = teacher_dataset(general_options, specific_options, sideband_dataset);