9 #include <mva/utility/Utility.h>
10 #include <mva/utility/DataDriven.h>
11 #include <mva/methods/PDF.h>
12 #include <mva/methods/Reweighter.h>
13 #include <mva/methods/Trivial.h>
14 #include <mva/methods/Combination.h>
16 #include <framework/logging/Logger.h>
18 #include <framework/utilities/MakeROOTCompatible.h>
20 #include <boost/algorithm/string/predicate.hpp>
21 #include <boost/property_tree/xml_parser.hpp>
// Intentionally empty: the body does nothing.
// NOTE(review): presumably this exists only so that callers can force this library (and
// the ROOT dictionary linked with it) to be loaded — TODO confirm.
36 void loadRootDictionary() { }
// Presumably fetches the weightfile stored under `identifier` (for the given
// experiment/run/event) and writes it to the local file `filename` — TODO confirm.
// NOTE(review): this is an extracted listing; the embedded line numbers skip the opening
// brace and the load/save calls inside each branch, so only the format dispatch is visible.
38 void download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
// The file format is chosen by the extension of `filename`: ".root", ".xml", or anything
// else, which prints a warning to std::cerr and (per that message) falls back to XML.
42 if (boost::ends_with(filename,
".root")) {
44 }
else if (boost::ends_with(filename,
".xml")) {
47 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
// Presumably reads the weightfile `filename` and stores it under `identifier` in the
// condition database, valid from (exp1, run1) to (exp2, run2) — TODO confirm.
// NOTE(review): this is an extracted listing; the load calls inside each branch and the
// database save are not shown, so only the format dispatch is visible.
52 void upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
// The input format is chosen by the extension of `filename`: ".root", ".xml", or anything
// else, which prints a warning to std::cerr and (per that message) falls back to XML.
56 if (boost::ends_with(filename,
".root")) {
58 }
else if (boost::ends_with(filename,
".xml")) {
61 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
// Reads every file in `filenames` into a Weightfile and collects them in `weightfiles`.
// NOTE(review): the per-branch load calls and the final upload are not part of this
// listing; presumably the array is stored under `identifier` with the IOV
// (exp1, run1)-(exp2, run2) — TODO confirm.
67 void upload_array(
const std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
71 std::vector<Belle2::MVA::Weightfile> weightfiles;
72 for (
const auto& filename : filenames) {
// Format is chosen per file by extension: ".root", ".xml", or anything else, which warns
// on std::cerr and (per that message) falls back to XML.
75 if (boost::ends_with(filename,
".root")) {
77 }
else if (boost::ends_with(filename,
".xml")) {
80 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
83 weightfiles.push_back(weightfile);
// Creates the Expert for the method recorded in the weightfile and loads the weightfile
// into it.
// NOTE(review): the lines that load `weightfile` from `filename`, fill `general_options`
// and use `directory` are not part of this listing. Presumably the weightfile's contents
// are extracted into `directory` — TODO confirm.
88 void extract(
const std::string& filename,
const std::string& directory)
96 GeneralOptions general_options;
// NOTE(review): if supported_interfaces is the std::map returned by
// getSupportedInterfaces(), operator[] inserts a null pointer for an unknown method and
// ->getExpert() then dereferences it. teacher_dataset checks with find() first; consider
// doing the same here.
98 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
99 expertLocal->load(weightfile);
// Serialises the general options and the method-specific options of a weightfile into an
// XML string (tab-indented).
// NOTE(review): the lines that load `weightfile` from `filename`, fill `general_options`,
// the #else/#endif of the version check and the return statement are not part of this
// listing. Presumably oss.str() is returned — TODO confirm.
103 std::string info(
const std::string& filename)
109 GeneralOptions general_options;
// The specific options come from the interface registered for the method name and are
// filled from the weightfile's XML tree.
112 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
113 specific_options->load(weightfile.
getXMLTree());
// Write both option sets into one fresh property tree.
115 boost::property_tree::ptree temp_tree;
116 general_options.save(temp_tree);
117 specific_options->save(temp_tree);
118 std::ostringstream oss;
// xml_writer_settings is templated on the character type before Boost 1.56 and on the
// string type from 1.56 on. Both variants indent with one tab character per level.
120 #if BOOST_VERSION < 105600
121 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
123 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
// NOTE(review): trailing ";;" is a harmless stray empty statement.
125 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
// NOTE(review): the body of this function is not part of this listing. Presumably it
// reports whether the weightfile `filename` can be loaded for the given
// experiment/run/event — TODO confirm.
131 bool available(
const std::string& filename,
int experiment,
int run,
int event)
// Applies the experts stored in `filenames` to the tree `treename` of `datafiles`. It
// writes one float branch per expert into a tree named "variables" in `outputfile`, which
// is recreated (overwritten) by TFile "RECREATE".
// If `copy_target` is true and an expert's target variable is non-empty, the target values
// are also written, to a branch named makeROOTCompatible("<expert branch>_<target>").
// NOTE(review): several lines are not part of this listing: weightfile loading (presumably
// where experiment/run/event are used), the branch-name construction, the declarations of
// `result`, `target` and the index `i`, and the per-entry Fill calls.
143 void expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
const std::string& treename,
144 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
147 std::vector<Weightfile> weightfiles;
148 std::vector<TBranch*> branches;
150 TFile file(outputfile.c_str(),
"RECREATE");
152 TTree tree(
"variables",
"variables");
// First pass: collect every weightfile and create its output branch (single float leaf,
// "/F").
155 for (
auto& filename : filenames) {
158 weightfiles.push_back(weightfile);
161 auto branch = tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str());
162 branches.push_back(branch);
// Second pass: apply each expert to the data files and fill its branch.
169 for (
auto& weightfile : weightfiles) {
170 GeneralOptions general_options;
172 general_options.m_treename = treename;
// NOTE(review): presumably 0 means "no limit on the number of events" — TODO confirm.
175 general_options.m_max_events = 0;
177 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
178 expertLocal->load(weightfile);
// Without copy_target the target variable is cleared, so no target branch is written below.
180 if (not copy_target) {
181 general_options.m_target_variable = std::string();
184 general_options.m_datafiles = datafiles;
185 auto& branch = branches[i];
186 ROOTDataset data(general_options);
// Only the apply() call is timed. NOTE(review): the duration is named training_time, but it
// measures application time (the log message says "application").
187 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
188 auto results = expertLocal->apply(data);
189 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
190 std::chrono::duration<double, std::milli> training_time = stop - start;
191 B2INFO(
"Elapsed application time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
192 for (
auto& r : results) {
// Optionally copy the target into its own branch, named after the expert's branch.
198 if (not general_options.m_target_variable.empty()) {
199 std::string branchname =
Belle2::makeROOTCompatible(std::string(branch->GetName()) +
"_" + general_options.m_target_variable);
201 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
202 auto targets = data.getTargets();
203 for (
auto& t : targets) {
205 target_branch->Fill();
213 file.Write(
"variables");
// Packs an externally produced weightfile `custom_weightfile` together with the given
// options into a Weightfile. The file is stored under the key "<m_identifier>_Weightfile".
// If `custom_weightfile` cannot be found, B2FATAL aborts with a message.
// NOTE(review): the existence check before B2FATAL, the handling of `general_options`
// (presumably addOptions(general_options)) and the final save of `output_weightfile` are
// not part of this listing.
217 void save_custom_weightfile(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
218 const std::string& custom_weightfile,
const std::string& output_identifier)
220 std::ifstream ifile(custom_weightfile);
222 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
225 Weightfile weightfile;
227 weightfile.addOptions(specific_options);
228 weightfile.addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
229 std::string output_weightfile(custom_weightfile);
// With an output identifier, "_<output_identifier>" is inserted in front of the matched
// extension ("$0" re-inserts the whole match), e.g. "model.xml" -> "model_<id>.xml".
// Names without '.' are left unchanged.
// NOTE(review): \S also matches '/', so the match begins at the first '.' that is followed
// only by non-whitespace up to the end. Examples: "a.tar.gz" -> "a_<id>.tar.gz", and
// "dir.v2/model.xml" -> "dir_<id>.v2/model.xml" (the directory name is modified).
230 if (!output_identifier.empty()) {
231 std::regex to_replace(
"(\\.\\S+$)");
232 std::string replacement =
"_" + output_identifier +
"$0";
233 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
// Entry point for training. If one of the meta trainings in `meta_options` is enabled
// (checked in the order sPlot, sideband subtraction, reweighting), training is delegated
// to it. Otherwise the expert is trained directly on a ROOTDataset built from
// `general_options`.
// The experts returned by the delegated functions are discarded here.
// NOTE(review): the lines following the B2ERROR (presumably an early return) are not part
// of this listing. The error message spells "SidebandSubstraction" (sic).
238 void teacher(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
const MetaOptions& meta_options)
// Count the enabled meta trainings; at most one is allowed.
240 unsigned int number_of_enabled_meta_trainings = 0;
241 if (meta_options.m_use_splot)
242 number_of_enabled_meta_trainings++;
243 if (meta_options.m_use_sideband_subtraction)
244 number_of_enabled_meta_trainings++;
245 if (meta_options.m_use_reweighting)
246 number_of_enabled_meta_trainings++;
248 if (number_of_enabled_meta_trainings > 1) {
249 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
253 if (meta_options.m_use_splot) {
254 teacher_splot(general_options, specific_options, meta_options);
255 }
else if (meta_options.m_use_sideband_subtraction) {
256 teacher_sideband_subtraction(general_options, specific_options, meta_options);
257 }
else if (meta_options.m_use_reweighting) {
258 teacher_reweighting(general_options, specific_options, meta_options);
260 ROOTDataset data(general_options);
261 teacher_dataset(general_options, specific_options, data);
// Trains on `data` with the Teacher of the interface named by general_options.m_method,
// logs the elapsed training time, and creates an Expert loaded with the trained weightfile.
// `general_options` is taken by value, so an empty m_method can be filled in from
// `specific_options` without affecting the caller's copy. A mismatch between the two
// method names is reported via B2ERROR.
// If no interface supports the method, B2ERROR is logged and std::runtime_error is thrown.
// NOTE(review): the third parameter (presumably the Dataset `data`), the lines after the
// conflict B2ERROR, any weightfile save, and the return statement (presumably expertLocal)
// are not part of this listing.
266 std::unique_ptr<Belle2::MVA::Expert> teacher_dataset(GeneralOptions general_options,
const SpecificOptions& specific_options,
269 if (general_options.m_method.empty()) {
270 general_options.m_method = specific_options.getMethod();
272 if (general_options.m_method != specific_options.getMethod()) {
273 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
274 <<
" " << specific_options.getMethod());
// Look the method up with find() first, so operator[] below never inserts a missing key.
279 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
280 auto teacherLocal = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
281 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
282 auto weightfile = teacherLocal->train(data);
283 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
284 std::chrono::duration<double, std::milli> training_time = stop - start;
285 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
287 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
288 expertLocal->load(weightfile);
// NOTE(review): both messages concatenate the method name without a separator
// (e.g. "...chosen methodFoo").
291 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
292 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
// sPlot meta training. The expert is trained on the data files without a target, using
// per-event weights from the discriminating variable meta_options.m_splot_variable. The
// signal/background shapes come from the MC files meta_options.m_splot_mc_files, and the
// yields are fitted to data.
// Options: a boosted (aPlot) variant (m_splot_boosted), and/or combining the result with
// a PDF expert of the discriminating variable (m_splot_combined).
// NOTE(review): the construction of `binning`, the chi2 accumulator and best-yield update,
// the empty-bin handling, the statements guarded by `if (not meta_options.m_splot_combined)`
// (presumably `return splot_expert;`) and the `} else {` / closing braces are not part of
// this listing.
296 std::unique_ptr<Belle2::MVA::Expert> teacher_splot(
const GeneralOptions& general_options,
const SpecificOptions& specific_options,
297 const MetaOptions& meta_options)
// Training sample on data. The target is cleared only while the dataset is constructed and
// restored afterwards, so later uses of data_general_options keep the original target.
// In combined mode this expert gets its own identifier "<id>_splot.xml".
300 GeneralOptions data_general_options = general_options;
301 data_general_options.m_target_variable =
"";
302 if (meta_options.m_splot_combined)
303 data_general_options.m_identifier = general_options.m_identifier +
"_splot.xml";
304 ROOTDataset data_dataset(data_general_options);
306 data_general_options.m_target_variable = general_options.m_target_variable;
// Data sample containing only the discriminating variable (no target).
308 GeneralOptions discriminant_general_options = general_options;
309 discriminant_general_options.m_target_variable =
"";
310 discriminant_general_options.m_variables = {meta_options.m_splot_variable};
311 ROOTDataset discriminant_dataset(discriminant_general_options);
313 discriminant_general_options.m_target_variable = general_options.m_target_variable;
// MC sample containing only the discriminating variable; it keeps the target, which is
// presumably used for the signal/background split via getSignals().
315 GeneralOptions mc_general_options = general_options;
316 mc_general_options.m_datafiles = meta_options.m_splot_mc_files;
317 mc_general_options.m_variables = {meta_options.m_splot_variable};
318 ROOTDataset mc_dataset(mc_general_options);
320 auto mc_signals = mc_dataset.getSignals();
321 auto mc_weights = mc_dataset.getWeights();
322 auto mc_feature = mc_dataset.getFeature(0);
323 auto data_feature = discriminant_dataset.getFeature(0);
324 auto data_weights = discriminant_dataset.getWeights();
// NOTE(review): `binning` is built in lines not shown here (presumably with
// Binning::CreateEqualFrequency on the MC feature, weights and signals, 100 bins).
// signalFraction uses the MC yields: it is computed before the yields are replaced by the
// fit to data below — confirm this is intended.
328 float signalFraction = binning.m_signal_yield / (binning.m_signal_yield + binning.m_bckgrd_yield);
// Weighted histogram of the data discriminant in the 100 bins of `binning`.
// The loop bound comes from data_dataset while the vectors come from discriminant_dataset.
// Both read the same data files, so presumably they have the same event count.
330 std::vector<double> data(100, 0);
331 double total_data = 0.0;
332 for (
unsigned int iEvent = 0; iEvent < data_dataset.getNumberOfEvents(); ++iEvent) {
333 data[binning.getBin(data_feature[iEvent])] += data_weights[iEvent];
334 total_data += data_weights[iEvent];
// Brute-force chi2 scan of the signal yield over 0, 1, 2, ... < total_data; the background
// yield is total_data - yield. In each bin, the model
// (yield * signal_pdf + (total_data - yield) * bckgrd_pdf) is scaled by the bin width
// relative to the full range and compared to the data.
// NOTE(review): the cost grows linearly with total_data (times 100 bins). chi2 divides by
// data[iBin], which is zero for an empty bin; the corresponding handling (empty_bin /
// B2WARNING) is in lines not shown here.
341 float best_yield = 0.0;
342 double best_chi2 = 1000000000.0;
343 bool empty_bin =
false;
344 for (
double yield = 0; yield < total_data; yield += 1) {
346 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
347 double deviation = (data[iBin] - (yield * binning.m_signal_pdf[iBin] + (total_data - yield) * binning.m_bckgrd_pdf[iBin]) *
348 (binning.m_boundaries[iBin + 1] - binning.m_boundaries[iBin]) / (binning.m_boundaries[100] - binning.m_boundaries[0]));
350 chi2 += deviation * deviation / data[iBin];
354 if (chi2 < best_chi2) {
361 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
364 B2INFO(
"sPlot best yield " << best_yield);
365 B2INFO(
"sPlot Yields On MC " << binning.m_signal_yield <<
" " << binning.m_bckgrd_yield);
// Replace the MC yields in `binning` by the yields fitted to data.
367 binning.m_signal_yield = best_yield;
368 binning.m_bckgrd_yield = (total_data - best_yield);
370 B2INFO(
"sPlot Yields Fitted On Data " << binning.m_signal_yield <<
" " << binning.m_bckgrd_yield);
// Boosted variant:
//   1. train a "<id>_boost.xml" expert on getBoostWeights;
//   2. train the final expert on aPlot weights, which depend on the boost expert's output
//      on data.
// Otherwise the final expert is trained directly on getSPlotWeights.
372 if (meta_options.m_splot_boosted) {
373 GeneralOptions boost_general_options = data_general_options;
374 boost_general_options.m_identifier = general_options.m_identifier +
"_boost.xml";
375 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
376 auto boost_expert = teacher_dataset(boost_general_options, specific_options, splot_dataset);
378 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
379 boost_expert->apply(data_dataset)), signalFraction);
380 auto splot_expert = teacher_dataset(data_general_options, specific_options, aplot_dataset);
381 if (not meta_options.m_splot_combined)
384 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
385 auto splot_expert = teacher_dataset(data_general_options, specific_options, splot_dataset);
386 if (not meta_options.m_splot_combined)
// Combined mode:
//   1. train a PDF expert of the discriminating variable on MC ("<id>_pdf.xml");
//   2. train a Combination expert over the sPlot expert and the PDF expert, with the
//      discriminating variable appended to the variables.
390 mc_general_options.m_identifier = general_options.m_identifier +
"_pdf.xml";
391 mc_general_options.m_method =
"PDF";
392 PDFOptions pdf_options;
394 auto pdf_expert = teacher_dataset(mc_general_options, pdf_options, mc_dataset);
396 GeneralOptions combination_general_options = general_options;
397 combination_general_options.m_method =
"Combination";
398 combination_general_options.m_variables.push_back(meta_options.m_splot_variable);
399 CombinationOptions combination_options;
400 combination_options.m_weightfiles = {data_general_options.m_identifier, mc_general_options.m_identifier};
401 auto combination_expert = teacher_dataset(combination_general_options, combination_options, data_dataset);
403 return combination_expert;
// Reweighting meta training, in four steps:
//   1. Train a "<id>_boost.xml" expert on a CombinedDataset of the data files
//      (m_reweighting_data_files) and the MC files (m_reweighting_mc_files); presumably it
//      labels data vs. MC.
//   2. Wrap that expert in a "Reweighter" expert stored under
//      meta_options.m_reweighting_identifier.
//   3. Apply the reweighter to the training files of `general_options`.
//   4. Train the final expert on the resulting weights.
// Using the reweighting variable as a training feature is reported via B2ERROR.
// NOTE(review): the lines after that B2ERROR (presumably an early return) and the return
// statement (presumably expertLocal) are not part of this listing.
406 std::unique_ptr<Belle2::MVA::Expert> teacher_reweighting(
const GeneralOptions& general_options,
407 const SpecificOptions& specific_options,
408 const MetaOptions& meta_options)
410 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
411 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
412 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
// Data sample (no target) and MC sample for the data-vs-MC boost training.
416 GeneralOptions data_general_options = general_options;
417 data_general_options.m_target_variable =
"";
418 data_general_options.m_datafiles = meta_options.m_reweighting_data_files;
419 ROOTDataset data_dataset(data_general_options);
421 GeneralOptions mc_general_options = general_options;
422 mc_general_options.m_datafiles = meta_options.m_reweighting_mc_files;
423 ROOTDataset mc_dataset(mc_general_options);
425 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
427 GeneralOptions boost_general_options = general_options;
428 boost_general_options.m_identifier = general_options.m_identifier +
"_boost.xml";
430 auto boost_expert = teacher_dataset(boost_general_options, specific_options, boost_dataset);
// The Reweighter expert refers to the boost expert by identifier.
432 GeneralOptions reweighter_general_options = general_options;
433 reweighter_general_options.m_identifier = meta_options.m_reweighting_identifier;
434 reweighter_general_options.m_method =
"Reweighter";
435 ReweighterOptions reweighter_specific_options;
436 reweighter_specific_options.m_weightfile = boost_general_options.m_identifier;
437 reweighter_specific_options.m_variable = meta_options.m_reweighting_variable;
// Make sure a non-empty reweighting variable is read from the files: add it as a spectator
// unless it is already a spectator, feature, target or weight variable.
439 if (meta_options.m_reweighting_variable !=
"") {
440 if (std::find(reweighter_general_options.m_spectators.begin(), reweighter_general_options.m_spectators.end(),
441 meta_options.m_reweighting_variable) == reweighter_general_options.m_spectators.end() and
442 std::find(reweighter_general_options.m_variables.begin(), reweighter_general_options.m_variables.end(),
443 meta_options.m_reweighting_variable) == reweighter_general_options.m_variables.end() and
444 reweighter_general_options.m_target_variable != meta_options.m_reweighting_variable and
445 reweighter_general_options.m_weight_variable != meta_options.m_reweighting_variable) {
446 reweighter_general_options.m_spectators.push_back(meta_options.m_reweighting_variable);
// Apply the reweighter to the training files and train the final expert on the weights.
450 ROOTDataset dataset(reweighter_general_options);
451 auto reweight_expert = teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
452 auto weights = reweight_expert->apply(dataset);
453 ReweightingDataset reweighted_dataset(general_options, dataset, weights);
454 auto expertLocal = teacher_dataset(general_options, specific_options, reweighted_dataset);
// Sideband-subtraction meta training: trains the expert on a SidebandDataset built from
// the data files of `general_options` and the MC files meta_options.m_sideband_mc_files.
// It uses meta_options.m_sideband_variable, which is added as a spectator to both samples
// (if not already present) so that it is read.
// Using the sideband variable as a training feature is reported via B2ERROR.
// NOTE(review): the lines after that B2ERROR (presumably an early return) and the return
// statement (presumably expertLocal) are not part of this listing.
459 std::unique_ptr<Belle2::MVA::Expert> teacher_sideband_subtraction(
const GeneralOptions& general_options,
460 const SpecificOptions& specific_options,
461 const MetaOptions& meta_options)
464 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
465 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
466 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
// Data sample, with the sideband variable as a spectator.
470 GeneralOptions data_general_options = general_options;
471 if (std::find(data_general_options.m_spectators.begin(), data_general_options.m_spectators.end(),
472 meta_options.m_sideband_variable) == data_general_options.m_spectators.end()) {
473 data_general_options.m_spectators.push_back(meta_options.m_sideband_variable);
475 ROOTDataset data_dataset(data_general_options);
// MC sample from the sideband MC files, also with the sideband variable as a spectator.
477 GeneralOptions mc_general_options = general_options;
478 mc_general_options.m_datafiles = meta_options.m_sideband_mc_files;
479 if (std::find(mc_general_options.m_spectators.begin(), mc_general_options.m_spectators.end(),
480 meta_options.m_sideband_variable) == mc_general_options.m_spectators.end()) {
481 mc_general_options.m_spectators.push_back(meta_options.m_sideband_variable);
483 ROOTDataset mc_dataset(mc_general_options);
485 GeneralOptions sideband_general_options = general_options;
486 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
487 auto expertLocal = teacher_dataset(general_options, specific_options, sideband_dataset);
A class that describes the interval of experiments/runs for which an object in the database is valid.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
static Binning CreateEqualFrequency(const std::vector< float > &data, const std::vector< float > &weights, const std::vector< bool > &isSignal, unsigned int nBins)
Create an equal frequency (aka equal-statistics) binning.
The Weightfile class serializes all information about a training into an xml tree.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
void setRemoveTemporaryDirectories(bool remove_temporary_directories)
Set the deletion behaviour of the weightfile object for temporary directories. For debugging it can be...
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
const boost::property_tree::ptree & getXMLTree() const
Get xml tree.
void setTemporaryDirectory(const std::string &temporary_directory)
Set the temporary directory which is used to store temporary directories.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Abstract base class for different kinds of events.