9 #include <mva/utility/Utility.h>
10 #include <mva/utility/DataDriven.h>
11 #include <mva/methods/PDF.h>
12 #include <mva/methods/Reweighter.h>
13 #include <mva/methods/Trivial.h>
14 #include <mva/methods/Combination.h>
16 #include <framework/logging/Logger.h>
18 #include <framework/utilities/MakeROOTCompatible.h>
20 #include <boost/algorithm/string/predicate.hpp>
21 #include <boost/property_tree/xml_parser.hpp>
29 using namespace Belle2::MVA;
/**
 * Convenience function which downloads a given weightfile from the database.
 *
 * NOTE(review): this listing omits several original source lines (the opening
 * brace, the bodies of the branches and the final `else`), so only the
 * visible control flow is described here.
 *
 * @param identifier identifier of the weightfile (presumably the database key — confirm against the omitted lines)
 * @param filename   file name whose extension (".root" / ".xml") selects the branch taken
 * @param experiment experiment number (declared with default 0)
 * @param run        run number (declared with default 0)
 * @param event      event number (declared with default 0)
 */
31 void Utility::download(
const std::string& identifier,
const std::string& filename,
int experiment,
int run,
int event)
// Dispatch on the extension of the file name.
35 if (boost::ends_with(filename,
".root")) {
37 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the enclosing `else` line is not visible in this listing; the
// message text indicates this is the fallback path for unrecognised extensions.
40 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
/**
 * Convenience function which uploads a given weightfile to the database.
 *
 * NOTE(review): this listing omits several original source lines (the opening
 * brace, the bodies of the branches and the final `else`), so only the
 * visible control flow is described here.
 *
 * @param filename   weightfile to upload; its extension (".root" / ".xml") selects the branch taken
 * @param identifier identifier of the weightfile (presumably the database key — confirm against the omitted lines)
 * @param exp1       declared with default 0
 * @param run1       declared with default 0
 * @param exp2       declared with default -1
 * @param run2       declared with default -1
 *
 * NOTE(review): exp1/run1/exp2/run2 presumably form the interval of validity
 * passed to the database — the call that consumes them is not visible here.
 */
45 void Utility::upload(
const std::string& filename,
const std::string& identifier,
int exp1,
int run1,
int exp2,
int run2)
// Dispatch on the extension of the file name.
49 if (boost::ends_with(filename,
".root")) {
51 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the enclosing `else` line is not visible in this listing; the
// message text indicates this is the fallback path for unrecognised extensions.
54 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
/**
 * Convenience function which uploads an array of weightfiles to the database.
 *
 * NOTE(review): this listing is incomplete — the last parameter of the
 * signature, the opening brace, the branch bodies and everything after the
 * loop are not visible. Only the visible statements are described.
 *
 * @param filenames  weightfiles to upload; each extension (".root" / ".xml") selects the branch taken
 * @param identifier identifier of the weightfiles (presumably the database key — confirm against the omitted lines)
 * @param exp1       declared with default 0
 * @param run1       declared with default 0
 * @param exp2       declared with default -1
 */
60 void Utility::upload_array(
const std::vector<std::string>& filenames,
const std::string& identifier,
int exp1,
int run1,
int exp2,
// Collects one Weightfile per entry of `filenames`.
65 std::vector<Belle2::MVA::Weightfile> weightfiles;
66 for (
const auto& filename : filenames) {
// Dispatch on the extension of the current file name.
69 if (boost::ends_with(filename,
".root")) {
71 }
else if (boost::ends_with(filename,
".xml")) {
// NOTE(review): the enclosing `else` line is not visible in this listing; the
// message text indicates this is the fallback path for unrecognised extensions.
74 std::cerr <<
"Unknown file extension, fallback to xml" << std::endl;
// `weightfile` is declared on a line that is not visible in this listing.
77 weightfiles.push_back(weightfile);
88 weightfile.setRemoveTemporaryDirectories(
false);
89 weightfile.setTemporaryDirectory(directory);
91 weightfile.getOptions(general_options);
92 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
93 expertLocal->load(weightfile);
104 weightfile.getOptions(general_options);
106 auto specific_options = supported_interfaces[general_options.m_method]->getOptions();
107 specific_options->load(weightfile.getXMLTree());
109 boost::property_tree::ptree temp_tree;
110 general_options.save(temp_tree);
111 specific_options->save(temp_tree);
112 std::ostringstream oss;
114 #if BOOST_VERSION < 105600
115 boost::property_tree::xml_writer_settings<char> settings(
'\t', 1);
117 boost::property_tree::xml_writer_settings<std::string> settings(
'\t', 1);
119 boost::property_tree::xml_parser::write_xml(oss, temp_tree, settings);;
/**
 * Convenience function applies experts on given data.
 *
 * For every weightfile in `filenames` the matching expert is loaded and
 * applied, and its output is stored in branches of a TTree named "variables"
 * that is written to `outputfile`.
 *
 * NOTE(review): this listing omits many original source lines (braces,
 * declarations of `weightfile`, `general_options`, `data`, `branchname`,
 * `result`, `target`, and the loop bodies that assign them). Comments below
 * describe only the visible statements.
 *
 * @param filenames   weightfiles of the experts to apply
 * @param datafiles   data files assigned to the general options of each expert
 * @param treename    tree name assigned to the general options of each expert
 * @param outputfile  ROOT file that is (re)created to hold the output tree
 * @param experiment  declared with default 0
 * @param run         declared with default 0
 * @param event       declared with default 0
 * @param copy_target if false the target variable is cleared, so no target branch is written (declared with default true)
 */
137 void Utility::expert(
const std::vector<std::string>& filenames,
const std::vector<std::string>& datafiles,
138 const std::string& treename,
139 const std::string& outputfile,
int experiment,
int run,
int event,
bool copy_target)
// "RECREATE" replaces an already existing output file.
142 TFile file(outputfile.c_str(),
"RECREATE");
144 TTree tree(
"variables",
"variables");
150 for (
auto& filename : filenames) {
// Start from the options stored in the weightfile, then override the
// data-related settings with the arguments of this call.
156 weightfile.getOptions(general_options);
158 general_options.m_treename = treename;
// NOTE(review): 0 presumably means "no event limit" — TODO confirm.
161 general_options.m_max_events = 0;
163 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
164 expertLocal->load(weightfile);
// More than two classes selects the multiclass application path below.
166 bool isMulticlass = general_options.m_nClasses > 2;
169 if (not copy_target) {
// An empty target variable disables the target branch further down.
170 general_options.m_target_variable = std::string();
173 general_options.m_datafiles = datafiles;
176 std::vector<TBranch*> branches;
178 if (not isMulticlass) {
// Single output: one float ("/F") branch for this expert.
181 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
// Time only the application of the expert.
182 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
184 auto results = expertLocal->apply(data);
185 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
186 std::chrono::duration<double, std::milli> application_time = stop - start;
187 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
188 for (
auto& r : results) {
// Multiclass output: one float ("/F") branch per class.
195 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
197 branches.push_back(tree.Branch(branchname.c_str(), &result, (branchname +
"/F").c_str()));
// Time only the application of the expert.
199 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
200 auto results = expertLocal->applyMulticlass(data);
201 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
202 std::chrono::duration<double, std::milli> application_time = stop - start;
203 B2INFO(
"Elapsed application time in ms " << application_time.count() <<
" for " << general_options.m_identifier);
204 for (
auto& r : results) {
205 for (
unsigned int iClass = 0; iClass < general_options.m_nClasses; ++iClass) {
// Fill the branch belonging to this class for the current result.
207 branches[iClass]->Fill();
// Copy the target into its own branch unless it was cleared above.
214 if (not general_options.m_target_variable.empty()) {
216 general_options.m_target_variable);
218 auto target_branch = tree.Branch(branchname.c_str(), &target, (branchname +
"/F").c_str());
219 auto targets = data.getTargets();
220 for (
auto& t : targets) {
222 target_branch->Fill();
// Write the "variables" tree to the output file.
230 file.Write(
"variables");
235 const std::string& custom_weightfile,
const std::string& output_identifier)
237 std::ifstream ifile(custom_weightfile);
239 B2FATAL(
"Input weight file: " << custom_weightfile <<
" does not exist!");
245 weightfile.
addFile(general_options.m_identifier +
"_Weightfile", custom_weightfile);
246 std::string output_weightfile(custom_weightfile);
247 if (!output_identifier.empty()) {
248 std::regex to_replace(
"(\\.\\S+$)");
249 std::string replacement =
"_" + output_identifier +
"$0";
250 output_weightfile = std::regex_replace(output_weightfile, to_replace, replacement);
258 unsigned int number_of_enabled_meta_trainings = 0;
259 if (meta_options.m_use_splot)
260 number_of_enabled_meta_trainings++;
261 if (meta_options.m_use_sideband_subtraction)
262 number_of_enabled_meta_trainings++;
263 if (meta_options.m_use_reweighting)
264 number_of_enabled_meta_trainings++;
266 if (number_of_enabled_meta_trainings > 1) {
267 B2ERROR(
"You enabled more than one meta training option. You can only use one (sPlot, SidebandSubstraction or Reweighting)");
271 if (meta_options.m_use_splot) {
272 teacher_splot(general_options, specific_options, meta_options);
273 }
else if (meta_options.m_use_sideband_subtraction) {
275 }
else if (meta_options.m_use_reweighting) {
288 if (general_options.m_method.empty()) {
289 general_options.m_method = specific_options.getMethod();
291 if (general_options.m_method != specific_options.getMethod()) {
292 B2ERROR(
"The method specified in the general options is in conflict with the provided specific option:" << general_options.m_method
293 <<
" " << specific_options.getMethod());
298 if (supported_interfaces.find(general_options.m_method) != supported_interfaces.end()) {
299 auto teacherLocal = supported_interfaces[general_options.m_method]->getTeacher(general_options, specific_options);
300 std::chrono::high_resolution_clock::time_point start = std::chrono::high_resolution_clock::now();
301 auto weightfile = teacherLocal->train(data);
302 std::chrono::high_resolution_clock::time_point stop = std::chrono::high_resolution_clock::now();
303 std::chrono::duration<double, std::milli> training_time = stop - start;
304 B2INFO(
"Elapsed training time in ms " << training_time.count() <<
" for " << general_options.m_identifier);
306 auto expertLocal = supported_interfaces[general_options.m_method]->getExpert();
307 expertLocal->load(weightfile);
310 B2ERROR(
"Interface doesn't support chosen method" << general_options.m_method);
311 throw std::runtime_error(
"Interface doesn't support chosen method" + general_options.m_method);
322 if (meta_options.m_splot_combined)
323 data_general_options.
m_identifier = general_options.m_identifier +
"_splot.xml";
330 discriminant_general_options.
m_variables = {meta_options.m_splot_variable};
331 ROOTDataset discriminant_dataset(discriminant_general_options);
333 discriminant_general_options.
m_target_variable = general_options.m_target_variable;
336 mc_general_options.
m_datafiles = meta_options.m_splot_mc_files;
337 mc_general_options.
m_variables = {meta_options.m_splot_variable};
343 auto data_feature = discriminant_dataset.
getFeature(0);
344 auto data_weights = discriminant_dataset.
getWeights();
350 std::vector<double> data(100, 0);
351 double total_data = 0.0;
352 for (
unsigned int iEvent = 0; iEvent < data_dataset.
getNumberOfEvents(); ++iEvent) {
353 data[binning.
getBin(data_feature[iEvent])] += data_weights[iEvent];
354 total_data += data_weights[iEvent];
361 float best_yield = 0.0;
362 double best_chi2 = 1000000000.0;
363 bool empty_bin =
false;
364 for (
double yield = 0; yield < total_data; yield += 1) {
366 for (
unsigned int iBin = 0; iBin < 100; ++iBin) {
367 double deviation = (data[iBin] - (yield * binning.
m_signal_pdf[iBin] + (total_data - yield) * binning.
m_bckgrd_pdf[iBin]) *
370 chi2 += deviation * deviation / data[iBin];
374 if (chi2 < best_chi2) {
381 B2WARNING(
"Encountered empty bin in data histogram during fit of the components for sPlot");
384 B2INFO(
"sPlot best yield " << best_yield);
392 if (meta_options.m_splot_boosted) {
394 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
395 SPlotDataset splot_dataset(boost_general_options, data_dataset, getBoostWeights(discriminant_dataset, binning), signalFraction);
396 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, splot_dataset);
398 SPlotDataset aplot_dataset(data_general_options, data_dataset, getAPlotWeights(discriminant_dataset, binning,
399 boost_expert->apply(data_dataset)), signalFraction);
400 auto splot_expert =
teacher_dataset(data_general_options, specific_options, aplot_dataset);
401 if (not meta_options.m_splot_combined)
404 SPlotDataset splot_dataset(data_general_options, data_dataset, getSPlotWeights(discriminant_dataset, binning), signalFraction);
405 auto splot_expert =
teacher_dataset(data_general_options, specific_options, splot_dataset);
406 if (not meta_options.m_splot_combined)
410 mc_general_options.
m_identifier = general_options.m_identifier +
"_pdf.xml";
411 mc_general_options.
m_method =
"PDF";
414 auto pdf_expert =
teacher_dataset(mc_general_options, pdf_options, mc_dataset);
417 combination_general_options.
m_method =
"Combination";
418 combination_general_options.
m_variables.push_back(meta_options.m_splot_variable);
421 auto combination_expert =
teacher_dataset(combination_general_options, combination_options, data_dataset);
423 return combination_expert;
430 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
431 meta_options.m_reweighting_variable) != general_options.m_variables.end()) {
432 B2ERROR(
"You cannot use the reweighting variable as a feature in your training");
438 data_general_options.
m_datafiles = meta_options.m_reweighting_data_files;
442 mc_general_options.
m_datafiles = meta_options.m_reweighting_mc_files;
445 CombinedDataset boost_dataset(general_options, data_dataset, mc_dataset);
448 boost_general_options.
m_identifier = general_options.m_identifier +
"_boost.xml";
450 auto boost_expert =
teacher_dataset(boost_general_options, specific_options, boost_dataset);
453 reweighter_general_options.
m_identifier = meta_options.m_reweighting_identifier;
454 reweighter_general_options.
m_method =
"Reweighter";
457 reweighter_specific_options.
m_variable = meta_options.m_reweighting_variable;
459 if (meta_options.m_reweighting_variable !=
"") {
461 meta_options.m_reweighting_variable) == reweighter_general_options.
m_spectators.end() and
463 meta_options.m_reweighting_variable) == reweighter_general_options.
m_variables.end() and
464 reweighter_general_options.
m_target_variable != meta_options.m_reweighting_variable and
465 reweighter_general_options.
m_weight_variable != meta_options.m_reweighting_variable) {
466 reweighter_general_options.
m_spectators.push_back(meta_options.m_reweighting_variable);
471 auto reweight_expert =
teacher_dataset(reweighter_general_options, reweighter_specific_options, dataset);
472 auto weights = reweight_expert->apply(dataset);
474 auto expertLocal =
teacher_dataset(general_options, specific_options, reweighted_dataset);
484 if (std::find(general_options.m_variables.begin(), general_options.m_variables.end(),
485 meta_options.m_sideband_variable) != general_options.m_variables.end()) {
486 B2ERROR(
"You cannot use the sideband variable as a feature in your training");
492 meta_options.m_sideband_variable) == data_general_options.
m_spectators.end()) {
493 data_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
498 mc_general_options.
m_datafiles = meta_options.m_sideband_mc_files;
500 meta_options.m_sideband_variable) == mc_general_options.
m_spectators.end()) {
501 mc_general_options.
m_spectators.push_back(meta_options.m_sideband_variable);
506 SidebandDataset sideband_dataset(sideband_general_options, data_dataset, mc_dataset, meta_options.m_sideband_variable);
507 auto expertLocal =
teacher_dataset(general_options, specific_options, sideband_dataset);
A class that describes the interval of experiments/runs for which an object in the database is valid.
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
static void initSupportedInterfaces()
Static function which initializes all supported interfaces, has to be called once before getSupportedI...
Binning of a data distribution. Provides PDF and CDF values of the distribution per bin.
std::vector< float > m_bckgrd_pdf
Background pdf of data distribution per bin.
std::vector< float > m_signal_pdf
Signal pdf of data distribution per bin.
std::vector< float > m_boundaries
Boundaries of data distribution, including minimum and maximum value as first and last boundary.
double m_bckgrd_yield
Background yield in data distribution.
double m_signal_yield
Signal yield in data distribution.
static Binning CreateEqualFrequency(const std::vector< float > &data, const std::vector< float > &weights, const std::vector< bool > &isSignal, unsigned int nBins)
Create an equal frequency (aka equal-statistics) binning.
unsigned int getBin(float datapoint) const
Gets the bin corresponding to the given datapoint.
Options for the Combination MVA method.
std::vector< std::string > m_weightfiles
Weightfiles of all methods we want to combine.
Wraps two other Datasets, one containing signal, the other background events. Used by the reweighting ...
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
virtual std::vector< bool > getSignals()
Returns all isSignal flags.
General options which are shared by all MVA trainings.
std::vector< std::string > m_datafiles
Name of the datafiles containing the training data.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_method
Name of the MVA method to use.
std::string m_target_variable
Target variable (branch name) defining the target.
std::string m_identifier
Identifier containing the finished training.
Options for the PDF MVA method.
Provides a dataset from a ROOT file. This is the usually used dataset providing training data to the m...
virtual unsigned int getNumberOfEvents() const override
Returns the number of events in this dataset.
virtual std::vector< float > getFeature(unsigned int iFeature) override
Returns all values of one feature in a std::vector<float>
virtual std::vector< float > getWeights() override
Returns all values of the weights in a std::vector<float>
Options for the Reweighter MVA method.
std::string m_weightfile
Weightfile of the reweighting expert.
std::string m_variable
Variable which decides if the reweighter is applied or not.
Dataset for Reweighting. Wraps a dataset and provides each data-point with a new weight.
Dataset for sPlot. Wraps a dataset and provides each data-point twice, once as signal and once as back...
Dataset for Sideband Subtraction. Wraps a dataset and provides each data-point with a new weight.
Specific Options, all method Options have to inherit from this class.
static void expert(const std::vector< std::string > &filenames, const std::vector< std::string > &datafiles, const std::string &treename, const std::string &outputfile, int experiment=0, int run=0, int event=0, bool copy_target=true)
Convenience function applies experts on given data.
static void upload_array(const std::vector< std::string > &filenames, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads an array of weightfiles to the database.
static void upload(const std::string &filename, const std::string &identifier, int exp1=0, int run1=0, int exp2=-1, int run2=-1)
Convenience function which uploads a given weightfile to the database.
static void download(const std::string &identifier, const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which downloads a given weightfile from the database.
static std::unique_ptr< Belle2::MVA::Expert > teacher_sideband_subtraction(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a sideband subtraction training, convenience function.
static std::unique_ptr< Belle2::MVA::Expert > teacher_reweighting(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs a MC vs data pre-training and afterwards reweighted training, convenience function.
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
static void extract(const std::string &filename, const std::string &directory)
Convenience function which extracts the expertise in a given weightfile into a temporary directory.
static std::unique_ptr< Belle2::MVA::Expert > teacher_dataset(GeneralOptions general_options, const SpecificOptions &specific_options, Dataset &data)
Convenience function which performs a training on a dataset.
static std::string info(const std::string &filename)
Print information about the classifier stored in the given weightfile.
static void save_custom_weightfile(const GeneralOptions &general_options, const SpecificOptions &specific_options, const std::string &custom_weightfile, const std::string &output_identifier="")
Convenience function which saves a pre-existing weightfile in a mva package-compliant format.
static std::unique_ptr< Belle2::MVA::Expert > teacher_splot(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options)
Performs an splot training, convenience function.
static bool available(const std::string &filename, int experiment=0, int run=0, int event=0)
Convenience function which checks if an expertise is available.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
static Weightfile loadFromXMLFile(const std::string &filename)
Static function which loads a Weightfile from a XML file.
static void save(Weightfile &weightfile, const std::string &filename, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile to a file.
static void saveToXMLFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a XML file.
void addOptions(const Options &options)
Add an Option object to the xml tree.
static Weightfile loadFromROOTFile(const std::string &filename)
Static function which loads a Weightfile from a ROOT file.
static Weightfile load(const std::string &filename, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from a file or from the database.
static Weightfile loadFromDatabase(const std::string &identifier, const Belle2::EventMetaData &emd=Belle2::EventMetaData(0, 0, 0))
Static function which loads a Weightfile from the basf2 condition database.
static void saveToROOTFile(Weightfile &weightfile, const std::string &filename)
Static function which saves a Weightfile to a ROOT file.
static void saveArrayToDatabase(const std::vector< Weightfile > &weightfiles, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves an array of Weightfile objects in the basf2 condition database.
static void saveToDatabase(Weightfile &weightfile, const std::string &identifier, const Belle2::IntervalOfValidity &iov=Belle2::IntervalOfValidity(0, 0, -1, -1))
Static function which saves a Weightfile in the basf2 condition database.
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.