9 #include <mva/utility/Utility.h>
10 #include <mva/utility/DataDriven.h>
11 #include <mva/methods/PDF.h>
12 #include <mva/methods/Reweighter.h>
13 #include <mva/methods/Trivial.h>
14 #include <mva/methods/Combination.h>
16 #include <framework/logging/Logger.h>
18 #include <framework/utilities/MakeROOTCompatible.h>
20 #include <boost/algorithm/string/predicate.hpp>
21 #include <boost/property_tree/xml_parser.hpp>
29 using namespace Belle2::MVA;
31 void Utility::download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
35 if (boost::ends_with(filename,
".root")) {
37 }
else if (boost::ends_with(filename,
".xml")) {
40 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
45 void Utility::upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
49 if (boost::ends_with(filename,
".root")) {
51 }
else if (boost::ends_with(filename,
".xml")) {
54 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
60 void Utility::upload_array(
const std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
65 std::vector<Belle2::MVA::Weightfile> weightfiles;
66 for (
const auto& filename : filenames) {
69 if (boost::ends_with(filename,
".root")) {
71 }
else if (boost::ends_with(filename,
".xml")) {
74 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
77 weightfiles.push_back(weightfile);
88 weightfile.setRemoveTemporaryDirectories(
false);
89 weightfile.setTemporaryDirectory(directory);
91 weightfile.getOptions(general_options);
92 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
93 expertLocal->load(weightfile);
104 weightfile.getOptions(general_options);
106 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
107 specific_options->load(weightfile.getXMLTree());
109 boost::property_tree::ptree temp_tree;
110 general_options.save(temp_tree);
111 specific_options->save(temp_tree);
112 std::ostringstream oss;
114 #if BOOST_VERSION < 105600
115 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
117 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
119 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
137 void Utility::expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
138 const std::string& treename,
139 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
142 TFile file(outputfile.c_str(),
"RECREATE");
144 TTree tree(
"variables",
"variables");
149 for (
auto& filename : filenames) {
155 weightfile.getOptions(general_options);
157 general_options.m_treename = treename;
160 general_options.m_max_events = 0;
162 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
163 expertLocal->load(weightfile);
165 bool isMulticlass = general_options.m_nClasses > 2;
168 if (not copy_target) {
169 general_options.m_target_variable = std::string();
172 general_options.m_datafiles = datafiles;
175 std::vector<TBranch*> branches;
177 if (not isMulticlass) {
180 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
181 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
183 auto results = expertLocal->apply(data);
184 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
185 std::chrono::duration<double, std::milli> application_time = stop - start;
186 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
187 for (
auto& r : results) {
194 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
196 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
198 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
199 auto results = expertLocal->applyMulticlass(data);
200 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
201 std::chrono::duration<double, std::milli> application_time = stop - start;
202 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
203 for (
auto& r : results) {
204 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
206 branches[iClass]->Fill();
213 if (not general_options.m_target_variable.empty()) {
215 general_options.m_target_variable);
217 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
218 auto targets = data.getTargets();
219 for (
auto& t : targets) {
221 target_branch->Fill();
227 file.Write(
"variables");
232 const std::string& custom_weightfile,
const std::string& output_identifier)
234 std::ifstream ifile(custom_weightfile);
236 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
242 weightfile.
addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
243 std::string output_weightfile(custom_weightfile);
244 if (!output_identifier.empty()) {
245 std::regex to_replace(
"(\\.\\S+$)");
246 std::string replacement =
"_" + output_identifier +
"$0";
247 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
255 unsigned int number_of_enabled_meta_trainings = 0;
256 if (meta_options.m_use_splot)
257 number_of_enabled_meta_trainings++;
258 if (meta_options.m_use_sideband_subtraction)
259 number_of_enabled_meta_trainings++;
260 if (meta_options.m_use_reweighting)
261 number_of_enabled_meta_trainings++;
263 if (number_of_enabled_meta_trainings > 1) {
264 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
268 if (meta_options.m_use_splot) {
269 teacher_splot(general_options, specific_options, meta_options);
270 }
else if (meta_options.m_use_sideband_subtraction) {
272 }
else if (meta_options.m_use_reweighting) {
285 if (general_options.m_method.empty()) {
286 general_options.m_method = specific_options.getMethod();
288 if (general_options.m_method != specific_options.getMethod()) {
289 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
290 <<
" " << specific_options.getMethod());
295 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
296 auto teacherLocal = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
297 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
298 auto weightfile = teacherLocal->train(data);
299 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
300 std::chrono::duration<double, std::milli> training_time = stop - start;
301 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
303 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
304 expertLocal->load(weightfile);
307 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
308 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
319 if (meta_options.m_splot_combined)
320 data_general_options.
m_identifier = general_options.m_identifier +
"_splot.xml";
327 discriminant_general_options.
m_variables = {meta_options.m_splot_variable};
328 ROOTDataset discriminant_dataset(discriminant_general_options);
330 discriminant_general_options.
m_target_variable = general_options.m_target_variable;
333 mc_general_options.
m_datafiles = meta_options.m_splot_mc_files;
334 mc_general_options.
m_variables = {meta_options.m_splot_variable};
340 auto data_feature = discriminant_dataset.
getFeature(0);
341 auto data_weights = discriminant_dataset.
getWeights();
347 std::vector<double> data(100, 0);
348 double total_data = 0.0;
349 for (
unsigned int iEvent = 0; iEvent < data_dataset.
getNumberOfEvents(); ++iEvent) {
350 data[binning.
getBin(data_feature[iEvent])] += data_weights[iEvent];
351 total_data += data_weights[iEvent];
358 float best_yield = 0.0;
359 double best_chi2 = 1000000000.0;
360 bool empty_bin =
false;
361 for (
double yield = 0; yield < total_data; yield += 1) {
363 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
364 double deviation = (data[iBin] - (yield * binning.
m_signal_pdf[iBin] + (total_data - yield) * binning.
m_bckgrd_pdf[iBin]) *
367 chi2 += deviation * deviation / data[iBin];
371 if (chi2 < best_chi2) {
378 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
381 B2INFO(
"sPlot best yield " << best_yield);
389 if (meta_options.m_splot_boosted) {
391 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
392 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
393 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, splot_dataset);
395 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
396 boost_expert->apply(data_dataset)), signalFraction);
397 auto splot_expert =
teacher_dataset(data_general_options, specific_options, aplot_dataset);
398 if (not meta_options.m_splot_combined)
401 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
402 auto splot_expert =
teacher_dataset(data_general_options, specific_options, splot_dataset);
403 if (not meta_options.m_splot_combined)
407 mc_general_options.
m_identifier = general_options.m_identifier +
"_pdf.xml";
408 mc_general_options.
m_method =
"PDF";
411 auto pdf_expert =
teacher_dataset(mc_general_options, pdf_options, mc_dataset);
414 combination_general_options.
m_method =
"Combination";
415 combination_general_options.
m_variables.push_back(meta_options.m_splot_variable);
418 auto combination_expert =
teacher_dataset(combination_general_options, combination_options, data_dataset);
420 return combination_expert;
427 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
428 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
429 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
435 data_general_options.
m_datafiles = meta_options.m_reweighting_data_files;
439 mc_general_options.
m_datafiles = meta_options.m_reweighting_mc_files;
442 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
445 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
447 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, boost_dataset);
450 reweighter_general_options.
m_identifier = meta_options.m_reweighting_identifier;
451 reweighter_general_options.
m_method =
"Reweighter";
454 reweighter_specific_options.
m_variable = meta_options.m_reweighting_variable;
456 if (meta_options.m_reweighting_variable !=
"") {
458 meta_options.m_reweighting_variable) == reweighter_general_options.
m_spectators.end() and
460 meta_options.m_reweighting_variable) == reweighter_general_options.
m_variables.end() and
461 reweighter_general_options.
m_target_variable != meta_options.m_reweighting_variable and
462 reweighter_general_options.
m_weight_variable != meta_options.m_reweighting_variable) {
463 reweighter_general_options.
m_spectators.push_back(meta_options.m_reweighting_variable);
468 auto reweight_expert =
teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
469 auto weights = reweight_expert->apply(dataset);
471 auto expertLocal =
teacher_dataset(general_options, specific_options, reweighted_dataset);
481 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
482 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
483 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
489 meta_options.m_sideband_variable) == data_general_options.
m_spectators.end()) {
490 data_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
495 mc_general_options.
m_datafiles = meta_options.m_sideband_mc_files;
497 meta_options.m_sideband_variable) == mc_general_options.
m_spectators.end()) {
498 mc_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
503 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
504 auto expertLocal =
teacher_dataset(general_options, specific_options, sideband_dataset);
A class that describes the interval of experiments/runs for which an object in the database is valid.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Binning of a data distribution. Provides PDF and CDF values of the distribution per bin.
std::vector< float > m_bckgrd_pdf
Background pdf of data distribution per bin.
std::vector< float > m_signal_pdf
Signal pdf of data distribution per bin.
std::vector< float > m_boundaries
Boundaries of data distribution, including minimum and maximum value as first and last boundary.
double m_bckgrd_yield
Background yield in data distribution.
double m_signal_yield
Signal yield in data distribution.
static Binning CreateEqualFrequency(const std::vector< float > &data, const std::vector< float > &weights, const std::vector< bool > &isSignal, unsigned int nBins)
Create an equal frequency (aka equal-statistics) binning.
unsigned int getBin(float datapoint) const
Gets the bin corresponding to the given datapoint.
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual std::vector< bool > getSignals()
Returns all isSignal flags.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_method
Name of the MVA method to use.
std::string m_target_variable
Target variable (branch name) defining the target.
std::string m_identifier
Identifier containing the finished training.
Options for the PDF MVA method.
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
std::string m_variable
Variable which decides if the reweighter is applied or not.
Dataset for Reweighting. Wraps a dataset and provides each data-point with a new weight.
Dataset for sPlot. Wraps a dataset and provides each data-point twice, once as signal and once as back...
Dataset for Sideband Subtraction. Wraps a dataset and provides each data-point with a new weight.
Specific Options, all method Options have to inherit from this class.
static void expert(const std::vector< std::string > &filenames, const std::vector< std::string > &datafiles, const std::string &treename, const std::string &outputfile, int experiment=0, int run=0, int event=0, bool copy_target=true)
Convenience function applies experts on given data.
static void upload_array(const std::vector< std::string > &filenames, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads an array of weightfiles to the database.
static void upload(const std::string &filename, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads a given weightfile to the database.
static void download(const std::string &identifier, const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which downloads a given weightfile from the database.
static std::unique_ptr< Belle2::MVA::Expert > teacher_sideband_subtraction(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a sideband subtraction training, convenience function.
static std::unique_ptr< Belle2::MVA::Expert > teacher_reweighting(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a MC vs data pre-training and afterwards reweighted training, convenience function.
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
static void extract(const std::string &filename, const std::string &directory)
Convenience function which extracts the expertise in a given weightfile into a temporary directory.
static std::unique_ptr< Belle2::MVA::Expert > teacher_dataset(GeneralOptions general_options, const SpecificOptions &specific_options, Dataset &data)
Convenience function which performs a training on a dataset.
static std::string info(const std::string &filename)
Print information about the classifier stored in the given weightfile.
static void save_custom_weightfile(const GeneralOptions &general_options, const SpecificOptions &specific_options, const std::string &custom_weightfile, const std::string &output_identifier="")
Convenience function which saves a pre-existing weightfile in a mva package-compliant format.
static std::unique_ptr< Belle2::MVA::Expert > teacher_splot(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs an sPlot training, convenience function.
static bool available(const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which checks if an expertise is available.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.