9#include <mva/utility/Utility.h>
10#include <mva/utility/DataDriven.h>
11#include <mva/methods/PDF.h>
12#include <mva/methods/Reweighter.h>
13#include <mva/methods/Trivial.h>
14#include <mva/methods/Combination.h>
16#include <framework/logging/Logger.h>
18#include <framework/utilities/MakeROOTCompatible.h>
20#include <boost/algorithm/string/predicate.hpp>
21#include <boost/property_tree/xml_parser.hpp>
30using namespace Belle2::MVA;
32void Utility::download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
36 if (boost::ends_with(filename,
".root")) {
38 }
else if (boost::ends_with(filename,
".xml")) {
41 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
46void Utility::upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
50 if (boost::ends_with(filename,
".root")) {
52 }
else if (boost::ends_with(filename,
".xml")) {
55 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
61void Utility::upload_array(
const std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
66 std::vector<Belle2::MVA::Weightfile> weightfiles;
67 for (
const auto& filename : filenames) {
70 if (boost::ends_with(filename,
".root")) {
72 }
else if (boost::ends_with(filename,
".xml")) {
75 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
78 weightfiles.push_back(weightfile);
89 weightfile.setRemoveTemporaryDirectories(
false);
90 setenv(
"TMPDIR", directory.c_str(), 1);
92 weightfile.getOptions(general_options);
93 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
94 expertLocal->load(weightfile);
105 weightfile.getOptions(general_options);
107 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
108 specific_options->load(weightfile.getXMLTree());
110 boost::property_tree::ptree temp_tree;
111 general_options.save(temp_tree);
112 specific_options->save(temp_tree);
113 std::ostringstream oss;
115#if BOOST_VERSION < 105600
116 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
118 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
120 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
138void Utility::expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
139 const std::string& treename,
140 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
143 TFile file(outputfile.c_str(),
"RECREATE");
145 TTree tree(
"variables",
"variables");
150 for (
auto& filename : filenames) {
156 weightfile.getOptions(general_options);
158 general_options.m_treename = treename;
161 general_options.m_max_events = 0;
163 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
164 expertLocal->load(weightfile);
166 bool isMulticlass = general_options.m_nClasses > 2;
169 if (not copy_target) {
170 general_options.m_target_variable = std::string();
173 general_options.m_datafiles = datafiles;
176 std::vector<TBranch*> branches;
178 if (not isMulticlass) {
181 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
182 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
184 auto results = expertLocal->apply(data);
185 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
186 std::chrono::duration<double, std::milli> application_time = stop - start;
187 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
188 for (
auto& r : results) {
195 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
197 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
199 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
200 auto results = expertLocal->applyMulticlass(data);
201 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
202 std::chrono::duration<double, std::milli> application_time = stop - start;
203 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
204 for (
auto& r : results) {
205 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
207 branches[iClass]->Fill();
214 if (not general_options.m_target_variable.empty()) {
216 general_options.m_target_variable);
218 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
219 auto targets = data.getTargets();
220 for (
auto& t : targets) {
222 target_branch->Fill();
228 file.Write(
"variables");
233 const std::string& custom_weightfile,
const std::string& output_identifier)
235 std::ifstream ifile(custom_weightfile);
237 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
243 weightfile.
addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
244 std::string output_weightfile(custom_weightfile);
245 if (!output_identifier.empty()) {
246 std::regex to_replace(
"(\\.\\S+$)");
247 std::string replacement =
"_" + output_identifier +
"$0";
248 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
256 unsigned int number_of_enabled_meta_trainings = 0;
257 if (meta_options.m_use_splot)
258 number_of_enabled_meta_trainings++;
259 if (meta_options.m_use_sideband_subtraction)
260 number_of_enabled_meta_trainings++;
261 if (meta_options.m_use_reweighting)
262 number_of_enabled_meta_trainings++;
264 if (number_of_enabled_meta_trainings > 1) {
265 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
269 if (meta_options.m_use_splot) {
270 teacher_splot(general_options, specific_options, meta_options);
271 }
else if (meta_options.m_use_sideband_subtraction) {
273 }
else if (meta_options.m_use_reweighting) {
286 if (general_options.m_method.empty()) {
287 general_options.m_method = specific_options.getMethod();
289 if (general_options.m_method != specific_options.getMethod()) {
290 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
291 <<
" " << specific_options.getMethod());
296 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
297 auto teacherLocal = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
298 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
299 auto weightfile = teacherLocal->train(data);
300 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
301 std::chrono::duration<double, std::milli> training_time = stop - start;
302 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
304 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
305 expertLocal->load(weightfile);
308 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
309 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
320 if (meta_options.m_splot_combined)
321 data_general_options.
m_identifier = general_options.m_identifier +
"_splot.xml";
328 discriminant_general_options.
m_variables = {meta_options.m_splot_variable};
329 ROOTDataset discriminant_dataset(discriminant_general_options);
331 discriminant_general_options.
m_target_variable = general_options.m_target_variable;
334 mc_general_options.
m_datafiles = meta_options.m_splot_mc_files;
335 mc_general_options.
m_variables = {meta_options.m_splot_variable};
341 auto data_feature = discriminant_dataset.
getFeature(0);
342 auto data_weights = discriminant_dataset.
getWeights();
344 Binning binning = Binning::CreateEqualFrequency(mc_feature, mc_weights, mc_signals, 100);
348 std::vector<double> data(100, 0);
349 double total_data = 0.0;
350 for (
unsigned int iEvent = 0; iEvent < data_dataset.
getNumberOfEvents(); ++iEvent) {
351 data[binning.
getBin(data_feature[iEvent])] += data_weights[iEvent];
352 total_data += data_weights[iEvent];
359 float best_yield = 0.0;
360 double best_chi2 = 1000000000.0;
361 bool empty_bin =
false;
362 for (
double yield = 0; yield < total_data; yield += 1) {
364 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
365 double deviation = (data[iBin] - (yield * binning.
m_signal_pdf[iBin] + (total_data - yield) * binning.
m_bckgrd_pdf[iBin]) *
368 chi2 += deviation * deviation / data[iBin];
372 if (chi2 < best_chi2) {
379 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
382 B2INFO(
"sPlot best yield " << best_yield);
390 if (meta_options.m_splot_boosted) {
392 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
393 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
394 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, splot_dataset);
396 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
397 boost_expert->apply(data_dataset)), signalFraction);
398 auto splot_expert =
teacher_dataset(data_general_options, specific_options, aplot_dataset);
399 if (not meta_options.m_splot_combined)
402 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
403 auto splot_expert =
teacher_dataset(data_general_options, specific_options, splot_dataset);
404 if (not meta_options.m_splot_combined)
408 mc_general_options.
m_identifier = general_options.m_identifier +
"_pdf.xml";
409 mc_general_options.
m_method =
"PDF";
412 auto pdf_expert =
teacher_dataset(mc_general_options, pdf_options, mc_dataset);
415 combination_general_options.
m_method =
"Combination";
416 combination_general_options.
m_variables.push_back(meta_options.m_splot_variable);
419 auto combination_expert =
teacher_dataset(combination_general_options, combination_options, data_dataset);
421 return combination_expert;
428 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
429 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
430 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
436 data_general_options.
m_datafiles = meta_options.m_reweighting_data_files;
440 mc_general_options.
m_datafiles = meta_options.m_reweighting_mc_files;
443 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
446 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
448 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, boost_dataset);
451 reweighter_general_options.
m_identifier = meta_options.m_reweighting_identifier;
452 reweighter_general_options.
m_method =
"Reweighter";
455 reweighter_specific_options.
m_variable = meta_options.m_reweighting_variable;
457 if (meta_options.m_reweighting_variable !=
"") {
459 meta_options.m_reweighting_variable) == reweighter_general_options.
m_spectators.end() and
461 meta_options.m_reweighting_variable) == reweighter_general_options.
m_variables.end() and
462 reweighter_general_options.
m_target_variable != meta_options.m_reweighting_variable and
463 reweighter_general_options.
m_weight_variable != meta_options.m_reweighting_variable) {
464 reweighter_general_options.
m_spectators.push_back(meta_options.m_reweighting_variable);
469 auto reweight_expert =
teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
470 auto weights = reweight_expert->apply(dataset);
472 auto expertLocal =
teacher_dataset(general_options, specific_options, reweighted_dataset);
482 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
483 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
484 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
490 meta_options.m_sideband_variable) == data_general_options.
m_spectators.end()) {
491 data_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
496 mc_general_options.
m_datafiles = meta_options.m_sideband_mc_files;
498 meta_options.m_sideband_variable) == mc_general_options.
m_spectators.end()) {
499 mc_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
504 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
505 auto expertLocal =
teacher_dataset(general_options, specific_options, sideband_dataset);
A class that describes the interval of experiments/runs for which an object in the database is valid.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Binning of a data distribution. Provides PDF and CDF values of the distribution per bin.
std::vector< float > m_bckgrd_pdf
Background pdf of data distribution per bin.
std::vector< float > m_signal_pdf
Signal pdf of data distribution per bin.
std::vector< float > m_boundaries
Boundaries of data distribution, including minimum and maximum value as first and last boundary.
double m_bckgrd_yield
Background yield in data distribution.
double m_signal_yield
Signal yield in data distribution.
unsigned int getBin(float datapoint) const
Gets the bin corresponding to the given datapoint.
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Wraps two other Datasets, one containing signal events and the other background events. Used by the reweighting ...
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed ...
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_method
Name of the MVA method to use.
std::string m_target_variable
Target variable (branch name) defining the target.
std::string m_identifier
Identifier containing the finished training.
Options for the PDF MVA method.
Provides a dataset from a ROOT file. This is the dataset usually used to provide training data to the m...
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>.
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>.
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
std::string m_variable
Variable which decides if the reweighter is applied or not.
Dataset for reweighting. Wraps a dataset and provides each data point with a new weight.
Dataset for sPlot. Wraps a dataset and provides each data point twice, once as signal and once as background ...
Dataset for sideband subtraction. Wraps a dataset and provides each data point with a new weight.
Specific options; all method-specific options have to inherit from this class.
static void expert(const std::vector< std::string > &filenames, const std::vector< std::string > &datafiles, const std::string &treename, const std::string &outputfile, int experiment=0, int run=0, int event=0, bool copy_target=true)
Convenience function which applies experts to the given data.
static void upload_array(const std::vector< std::string > &filenames, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads an array of weightfiles to the database.
static void upload(const std::string &filename, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads a given weightfile to the database.
static void download(const std::string &identifier, const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which downloads a given weightfile from the database.
static std::unique_ptr< Belle2::MVA::Expert > teacher_sideband_subtraction(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a sideband subtraction training, convenience function.
static std::unique_ptr< Belle2::MVA::Expert > teacher_reweighting(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs an MC vs. data pre-training and afterwards a reweighted training, convenience function.
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
static void extract(const std::string &filename, const std::string &directory)
Convenience function which extracts the expertise in a given weightfile into a temporary directory.
static std::unique_ptr< Belle2::MVA::Expert > teacher_dataset(GeneralOptions general_options, const SpecificOptions &specific_options, Dataset &data)
Convenience function which performs a training on a dataset.
static std::string info(const std::string &filename)
Print information about the classifier stored in the given weightfile.
static void save_custom_weightfile(const GeneralOptions &general_options, const SpecificOptions &specific_options, const std::string &custom_weightfile, const std::string &output_identifier="")
Convenience function which saves a pre-existing weightfile in an MVA-package-compliant format.
static std::unique_ptr< Belle2::MVA::Expert > teacher_splot(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs an splot training, convenience function.
static bool available(const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which checks if an expertise is available.
The Weightfile class serializes all information about a training into an XML tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from an XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to an XML file.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.