9 #include <mva/utility/Utility.h>
10 #include <mva/utility/DataDriven.h>
11 #include <mva/methods/PDF.h>
12 #include <mva/methods/Reweighter.h>
13 #include <mva/methods/Trivial.h>
14 #include <mva/methods/Combination.h>
16 #include <framework/logging/Logger.h>
18 #include <framework/utilities/MakeROOTCompatible.h>
20 #include <boost/algorithm/string/predicate.hpp>
21 #include <boost/property_tree/xml_parser.hpp>
29 using namespace Belle2::MVA;
// Utility::download: convenience function that downloads the weightfile stored in the
// database under `identifier` and writes it to the local file `filename`.
// experiment/run/event presumably select the conditions-database entry to read
// (the lines using them are not visible in this listing -- TODO confirm).
// NOTE(review): this listing is incomplete; the opening brace, the database load and the
// actual save calls (original lines 32-34, 36, 38-39, 41-44) are elided.
31 void Utility::download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
// The output format is chosen from the file extension: ".root" or ".xml".
// Any other extension prints a warning and, per the message text, falls back to XML.
35 if (boost::ends_with(filename,
".root")) {
37 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): this warning goes to std::cerr, while the rest of the file reports
// through B2INFO/B2WARNING/B2ERROR; B2WARNING would be more consistent.
40 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
// Utility::upload: convenience function that reads the weightfile `filename` and uploads
// it to the database under `identifier`.
// exp1/run1/exp2/run2 presumably define the interval of validity of the payload
// (their use is not visible here -- TODO confirm).
// NOTE(review): the opening brace, the load/save calls and the closing brace
// (original lines 46-48, 50, 52-53, 55-59) are elided in this listing.
45 void Utility::upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
// How the file is read depends on its extension: ".root" or ".xml".
// Any other extension prints a warning and, per the message text, is treated as XML.
49 if (boost::ends_with(filename,
".root")) {
51 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): std::cerr here, B2 logging macros elsewhere in the file.
54 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
// Utility::upload_array: convenience function that loads every file in `filenames` and
// uploads the resulting array of Weightfile objects under a single `identifier`.
// NOTE(review): the final parameter line (run2), the opening brace, the per-file load
// calls and the database upload call are elided in this listing.
60 void Utility::upload_array(
const std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
// Accumulates one Weightfile per input file.
65 std::vector<Belle2::MVA::Weightfile> weightfiles;
66 for (
const auto& filename : filenames) {
// The format of each file is chosen independently from its extension.
69 if (boost::ends_with(filename,
".root")) {
71 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): std::cerr here, B2 logging macros elsewhere in the file.
74 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
77 weightfiles.push_back(weightfile);
// Fragment, presumably of Utility::extract(filename, directory). The function header and
// the weightfile load are not visible in this listing -- TODO confirm.
// Keep the temporary directory and point it at `directory`; presumably this makes the
// files written while loading the expert remain on disk after the call.
88 weightfile.setRemoveTemporaryDirectories(
false);
89 weightfile.setTemporaryDirectory(directory);
91 weightfile.getOptions(general_options);
// NOTE(review): if supported_interfaces is the std::map returned by
// getSupportedInterfaces(), operator[] inserts a null entry for an unknown method,
// and the following ->getExpert() then dereferences nullptr.
// teacher_dataset checks with find() first; this site does not.
92 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
93 expertLocal->load(weightfile);
// Fragment, presumably of Utility::info(filename). The function header, the weightfile load
// and the return are not visible in this listing.
// Builds an XML text dump of the general options plus the method-specific options.
104 weightfile.getOptions(general_options);
// NOTE(review): same unchecked supported_interfaces[...] dereference as in extract;
// an unknown m_method yields a null pointer here.
106 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
107 specific_options->load(weightfile.getXMLTree());
// Serialize both option sets into a fresh property tree.
109 boost::property_tree::ptree temp_tree;
110 general_options.save(temp_tree);
111 specific_options->save(temp_tree);
112 std::ostringstream oss;
// The template argument of xml_writer_settings changed in Boost 1.56 (char -> std::string).
// NOTE(review): the matching #else / #endif lines (original 116 and 118) are elided here.
114 #if BOOST_VERSION < 105600
115 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
117 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
// Tab-indented XML is written into oss.
// NOTE(review): the trailing ";;" is a harmless stray empty statement.
119 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
// Utility::expert: convenience function that applies the experts in `filenames` to the
// data in `datafiles` (tree `treename`) and writes the results into a new ROOT file
// `outputfile`, in a TTree called "variables".
// One float branch is created per weightfile. If copy_target is true and a weightfile
// defines a target variable, a float branch with the targets is also written.
// NOTE(review): many lines (braces, branch-name computation, weightfile load,
// dataset construction, per-event fills) are elided in this listing.
137 void Utility::expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
138 const std::string& treename,
139 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
142 std::vector<Weightfile> weightfiles;
143 std::vector<TBranch*> branches;
// "RECREATE" overwrites any existing output file.
145 TFile file(outputfile.c_str(),
"RECREATE");
147 TTree tree(
"variables",
"variables");
// First pass: load every weightfile and create one "<branchname>/F" branch for it.
150 for (
auto& filename : filenames) {
153 weightfiles.push_back(weightfile);
156 auto branch = tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str());
157 branches.push_back(branch);
// Second pass: for each weightfile, point its options at the user-supplied data and apply it.
164 for (
auto& weightfile : weightfiles) {
166 weightfile.getOptions(general_options);
167 general_options.m_treename = treename;
170 general_options.m_max_events = 0;
// NOTE(review): unchecked supported_interfaces[...] dereference (see teacher_dataset,
// which checks with find() first).
172 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
173 expertLocal->load(weightfile);
// Clearing the target variable suppresses the target branch written below.
175 if (not copy_target) {
176 general_options.m_target_variable = std::string();
179 general_options.m_datafiles = datafiles;
// NOTE(review): `i` is an index declared in elided lines; presumably it counts
// weightfiles so that it matches the order of `branches`.
180 auto& branch = branches[i];
// NOTE(review): the variable is named training_time, but it measures application time
// (see the log message).
182 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
183 auto results = expertLocal->apply(data);
184 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
185 std::chrono::duration<double, std::milli> training_time = stop - start;
186 B2INFO(
"Elapsed application time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
187 for (
auto& r : results) {
// Optionally write the targets of the dataset as an extra float branch.
193 if (not general_options.m_target_variable.empty()) {
195 general_options.m_target_variable);
197 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
198 auto targets = data.getTargets();
199 for (
auto& t : targets) {
201 target_branch->Fill();
209 file.Write(
"variables");
// Fragment of Utility::save_custom_weightfile: wraps a pre-existing (external) weightfile
// into an MVA-package weightfile. The function name and first parameters are elided.
214 const std::string& custom_weightfile,
const std::string& output_identifier)
// Existence check: opening the file with std::ifstream; a missing file is fatal.
216 std::ifstream ifile(custom_weightfile);
218 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
// The external file is embedded under the key "<identifier>_Weightfile".
224 weightfile.
addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
// If output_identifier is given, "_<output_identifier>" is inserted into the output name
// in front of the text matched by the regex.
225 std::string output_weightfile(custom_weightfile);
226 if (!output_identifier.empty()) {
// NOTE(review): regex_replace uses the leftmost match, so (\.\S+$) matches from the FIRST
// '.' after which the path has no whitespace, not from the final extension.
// Examples: "./w/f.xml" -> "_ID./w/f.xml", and "a.tar.gz" -> "a_ID.tar.gz".
// A pattern such as (\.[^./\s]+$) would anchor on the last extension.
// Filenames without a dot are left unchanged.
// "$0" (whole match) is accepted by libstdc++ and libc++. "$1" or "$&" is the standard
// ECMAScript spelling and is equivalent here.
227 std::regex to_replace(
"(\\.\\S+$)");
228 std::string replacement =
"_" + output_identifier +
"$0";
229 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
// Fragment of Utility::teacher: dispatches to one of the meta trainings (sPlot, sideband
// subtraction, reweighting). Enabling more than one is reported as an error.
// The function header and the branch bodies other than sPlot are elided in this listing.
237 unsigned int number_of_enabled_meta_trainings = 0;
238 if (meta_options.m_use_splot)
239 number_of_enabled_meta_trainings++;
240 if (meta_options.m_use_sideband_subtraction)
241 number_of_enabled_meta_trainings++;
242 if (meta_options.m_use_reweighting)
243 number_of_enabled_meta_trainings++;
// NOTE(review): whether execution stops after this error is not visible here (original
// lines 247-249 are elided).
// NOTE(review): the runtime message misspells "SidebandSubtraction" as "SidebandSubstraction".
245 if (number_of_enabled_meta_trainings > 1) {
246 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
// sPlot takes precedence, then sideband subtraction, then reweighting.
250 if (meta_options.m_use_splot) {
251 teacher_splot(general_options, specific_options, meta_options);
252 }
else if (meta_options.m_use_sideband_subtraction) {
254 }
else if (meta_options.m_use_reweighting) {
// Fragment of Utility::teacher_dataset: trains the method selected by the options on
// `data` and loads the resulting weightfile into an expert. Header and return are elided.
// An empty m_method is filled from the specific options.
267 if (general_options.m_method.empty()) {
268 general_options.m_method = specific_options.getMethod();
// A mismatch between general and specific method is reported. Whether training continues
// afterwards is not visible (original lines 273-276 are elided).
270 if (general_options.m_method != specific_options.getMethod()) {
271 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
272 <<
" " << specific_options.getMethod());
// Only methods registered in supported_interfaces can be trained; others throw below.
277 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
278 auto teacherLocal = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
// The wall-clock training time is logged in milliseconds.
279 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
280 auto weightfile = teacherLocal->train(data);
281 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
282 std::chrono::duration<double, std::milli> training_time = stop - start;
283 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
285 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
286 expertLocal->load(weightfile);
// NOTE(review): both messages concatenate the method name with no separator, which
// produces e.g. "...chosen methodFoo". Consider adding ": " before the name.
289 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
290 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
// Fragment of Utility::teacher_splot: performs an sPlot training. The signal yield is
// fitted on a binned discriminant distribution, then a classifier is trained on
// sPlot (or boosted aPlot) weights.
// With m_splot_combined, the classifier is additionally combined with a PDF expert
// trained on MC. Many lines (option copies, binning creation, fit bookkeeping,
// signalFraction) are elided in this listing.
// In combined mode the data-trained expert gets the identifier suffix "_splot.xml".
301 if (meta_options.m_splot_combined)
302 data_general_options.
m_identifier = general_options.m_identifier +
"_splot.xml";
// The discriminant dataset contains only the sPlot variable.
309 discriminant_general_options.
m_variables = {meta_options.m_splot_variable};
310 ROOTDataset discriminant_dataset(discriminant_general_options);
// NOTE(review): m_target_variable is assigned AFTER discriminant_dataset was constructed
// from these options (line above). If ROOTDataset reads its options at construction, this
// assignment does not affect discriminant_dataset -- confirm the intent.
312 discriminant_general_options.
m_target_variable = general_options.m_target_variable;
// The MC options read the sPlot MC files and only the sPlot variable.
315 mc_general_options.
m_datafiles = meta_options.m_splot_mc_files;
316 mc_general_options.
m_variables = {meta_options.m_splot_variable};
322 auto data_feature = discriminant_dataset.
getFeature(0);
323 auto data_weights = discriminant_dataset.
getWeights();
// Build a weighted 100-bin histogram of the discriminant on data.
// NOTE(review): 100 is hard-coded here and in the chi2 loop. It must match the nBins used
// to create `binning` (creation elided). A shared named constant would keep them in sync.
// NOTE(review): the loop bound comes from data_dataset, while the indices are used on arrays
// from discriminant_dataset. Presumably both datasets hold the same events -- confirm.
329 std::vector<double> data(100, 0);
330 double total_data = 0.0;
331 for (
unsigned int iEvent = 0; iEvent < data_dataset.
getNumberOfEvents(); ++iEvent) {
332 data[binning.
getBin(data_feature[iEvent])] += data_weights[iEvent];
333 total_data += data_weights[iEvent];
// Brute-force chi2 scan over the signal yield, in steps of 1, from 0 to total_data.
// Cost is O(total_data * 100), and the yield resolution is one (weighted) event.
// NOTE(review): best_yield is float while yield is double, so the stored best value is
// narrowed.
340 float best_yield = 0.0;
341 double best_chi2 = 1000000000.0;
342 bool empty_bin =
false;
343 for (
double yield = 0; yield < total_data; yield += 1) {
345 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
346 double deviation = (data[iBin] - (yield * binning.
m_signal_pdf[iBin] + (total_data - yield) * binning.
m_bckgrd_pdf[iBin]) *
// Chi2 term weighted by the data content. Empty bins presumably set empty_bin instead
// of dividing by zero (the guarding lines are elided).
349 chi2 += deviation * deviation / data[iBin];
353 if (chi2 < best_chi2) {
360 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
363 B2INFO(
"sPlot best yield " << best_yield);
// Boosted variant: first train on sPlot weights under identifier suffix "_boost.xml", then
// retrain on aPlot weights derived from that boost expert's output on data.
371 if (meta_options.m_splot_boosted) {
373 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
374 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
375 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, splot_dataset);
377 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
378 boost_expert->apply(data_dataset)), signalFraction);
379 auto splot_expert =
teacher_dataset(data_general_options, specific_options, aplot_dataset);
380 if (not meta_options.m_splot_combined)
// Plain variant: train directly on sPlot weights.
383 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
384 auto splot_expert =
teacher_dataset(data_general_options, specific_options, splot_dataset);
385 if (not meta_options.m_splot_combined)
// Combined mode: train a "PDF" expert on MC for the sPlot variable (suffix "_pdf.xml").
389 mc_general_options.
m_identifier = general_options.m_identifier +
"_pdf.xml";
390 mc_general_options.
m_method =
"PDF";
393 auto pdf_expert =
teacher_dataset(mc_general_options, pdf_options, mc_dataset);
// Then train a "Combination" expert on data, with the sPlot variable added as an input.
396 combination_general_options.
m_method =
"Combination";
397 combination_general_options.
m_variables.push_back(meta_options.m_splot_variable);
400 auto combination_expert =
teacher_dataset(combination_general_options, combination_options, data_dataset);
402 return combination_expert;
// Fragment of Utility::teacher_reweighting: trains a data-vs-MC classifier, uses it as a
// "Reweighter" expert to compute per-event weights, and then trains the final expert on
// the reweighted dataset. Header, option copies and dataset constructions are partly elided.
// The reweighting variable must not also be a training feature. Whether this error aborts
// is not visible (the following lines are elided).
409 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
410 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
411 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
417 data_general_options.
m_datafiles = meta_options.m_reweighting_data_files;
421 mc_general_options.
m_datafiles = meta_options.m_reweighting_mc_files;
// Data and MC are combined into one dataset for the data-vs-MC ("boost") training.
// NOTE(review): it is constructed with general_options rather than one of the data/MC
// option copies -- confirm this is intended.
424 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
427 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
429 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, boost_dataset);
// The Reweighter is stored under the user-chosen m_reweighting_identifier.
432 reweighter_general_options.
m_identifier = meta_options.m_reweighting_identifier;
433 reweighter_general_options.
m_method =
"Reweighter";
436 reweighter_specific_options.
m_variable = meta_options.m_reweighting_variable;
// If a reweighting variable is set, add it as a spectator unless it is already used as
// a spectator, variable, target or weight.
438 if (meta_options.m_reweighting_variable !=
"") {
440 meta_options.m_reweighting_variable) == reweighter_general_options.
m_spectators.end() and
442 meta_options.m_reweighting_variable) == reweighter_general_options.
m_variables.end() and
443 reweighter_general_options.
m_target_variable != meta_options.m_reweighting_variable and
444 reweighter_general_options.
m_weight_variable != meta_options.m_reweighting_variable) {
445 reweighter_general_options.
m_spectators.push_back(meta_options.m_reweighting_variable);
// Train the Reweighter, compute the weights, then train the final expert on the
// reweighted dataset (constructed in the elided line 452).
450 auto reweight_expert =
teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
451 auto weights = reweight_expert->apply(dataset);
453 auto expertLocal =
teacher_dataset(general_options, specific_options, reweighted_dataset);
// Fragment of Utility::teacher_sideband_subtraction: trains the expert on a SidebandDataset
// that combines data and MC using the sideband variable. Header and several lines are elided.
// The sideband variable must not also be a training feature. Whether this error aborts
// is not visible.
463 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
464 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
465 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
// Ensure the sideband variable is read as a spectator from data (added once only).
471 meta_options.m_sideband_variable) == data_general_options.
m_spectators.end()) {
472 data_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
// The MC options read the sideband MC files, again with the sideband variable as spectator.
477 mc_general_options.
m_datafiles = meta_options.m_sideband_mc_files;
479 meta_options.m_sideband_variable) == mc_general_options.
m_spectators.end()) {
480 mc_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
485 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
// The final expert is trained with the caller's general options on the sideband dataset.
486 auto expertLocal =
teacher_dataset(general_options, specific_options, sideband_dataset);
A class that describes the interval of experiments/runs for which an object in the database is valid.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces() is used.
Binning of a data distribution. Provides PDF and CDF values of the distribution per bin.
std::vector< float > m_bckgrd_pdf
Background pdf of data distribution per bin.
std::vector< float > m_signal_pdf
Signal pdf of data distribution per bin.
std::vector< float > m_boundaries
Boundaries of data distribution, including minimum and maximum value as first and last boundary.
double m_bckgrd_yield
Background yield in data distribution.
double m_signal_yield
Signal yield in data distribution.
static Binning CreateEqualFrequency(const std::vector< float > &data, const std::vector< float > &weights, const std::vector< bool > &isSignal, unsigned int nBins)
Create an equal frequency (aka equal-statistics) binning.
unsigned int getBin(float datapoint) const
Gets the bin corresponding to the given datapoint.
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual std::vector< bool > getSignals()
Returns the isSignal flags of all events.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_method
Name of the MVA method to use.
std::string m_target_variable
Target variable (branch name) defining the target.
std::string m_identifier
Identifier containing the finished training.
Options for the PDF MVA method.
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>.
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
std::string m_variable
Variable which decides if the reweighter is applied or not.
Dataset for Reweighting. Wraps a dataset and provides each data-point with a new weight.
Dataset for sPlot. Wraps a dataset and provides each data-point twice, once as signal and once as back...
Dataset for Sideband Subtraction. Wraps a dataset and provides each data-point with a new weight.
Specific Options, all method Options have to inherit from this class.
static void expert(const std::vector< std::string > &filenames, const std::vector< std::string > &datafiles, const std::string &treename, const std::string &outputfile, int experiment=0, int run=0, int event=0, bool copy_target=true)
Convenience function applies experts on given data.
static void upload_array(const std::vector< std::string > &filenames, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads an array of weightfiles to the database.
static void upload(const std::string &filename, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads a given weightfile to the database.
static void download(const std::string &identifier, const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which downloads a given weightfile from the database.
static std::unique_ptr< Belle2::MVA::Expert > teacher_sideband_subtraction(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a sideband subtraction training, convenience function.
static std::unique_ptr< Belle2::MVA::Expert > teacher_reweighting(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a MC vs data pre-training and afterwards reweighted training, convenience function.
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
static void extract(const std::string &filename, const std::string &directory)
Convenience function which extracts the expertise in a given weightfile into a temporary directory.
static std::unique_ptr< Belle2::MVA::Expert > teacher_dataset(GeneralOptions general_options, const SpecificOptions &specific_options, Dataset &data)
Convenience function which performs a training on a dataset.
static std::string info(const std::string &filename)
Print information about the classifier stored in the given weightfile.
static void save_custom_weightfile(const GeneralOptions &general_options, const SpecificOptions &specific_options, const std::string &custom_weightfile, const std::string &output_identifier="")
Convenience function which saves a pre-existing weightfile in a mva package-compliant format.
static std::unique_ptr< Belle2::MVA::Expert > teacher_splot(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs an sPlot training, convenience function.
static bool available(const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which checks if an expertise is available.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.