9#include <mva/methods/TMVA.h>
10#include <framework/logging/Logger.h>
11#include <framework/utilities/MakeROOTCompatible.h>
12#include <framework/utilities/ScopeGuard.h>
14#include <TPluginManager.h>
16#include <boost/algorithm/string.hpp>
29 int version = pt.get<
int>(
"TMVA_version");
31 B2ERROR(
"Unknown weightfile version " << std::to_string(version));
32 throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
34 m_method = pt.get<std::string>(
"TMVA_method");
35 m_type = pt.get<std::string>(
"TMVA_type");
36 m_config = pt.get<std::string>(
"TMVA_config");
40 m_prefix = pt.get<std::string>(
"TMVA_prefix");
45 pt.put(
"TMVA_version", 1);
47 pt.put(
"TMVA_type",
m_type);
57 po::options_description description(
"TMVA options");
58 description.add_options()
59 (
"tmva_method", po::value<std::string>(&
m_method),
"TMVA Method Name")
60 (
"tmva_type", po::value<std::string>(&
m_type),
"TMVA Method Type (e.g. Plugin, BDT, ...)")
61 (
"tmva_config", po::value<std::string>(&
m_config),
"TMVA Configuration string for the method")
62 (
"tmva_working_directory", po::value<std::string>(&
m_workingDirectory),
"TMVA working directory which stores e.g. TMVA.root")
63 (
"tmva_factory", po::value<std::string>(&
m_factoryOption),
"TMVA Factory options passed to TMVAFactory constructor")
65 "TMVA Preprare options passed to prepareTrainingAndTestTree function");
84 description.add_options()
85 (
"tmva_transform2probability", po::value<bool>(&
transform2probability),
"TMVA Transform output of classifier to a probability");
93 unsigned int numberOfClasses = pt.get<
unsigned int>(
"TMVA_number_classes", 1);
95 for (
unsigned int i = 0; i < numberOfClasses; ++i) {
96 m_classes[i] = pt.get<std::string>(std::string(
"TMVA_classes") + std::to_string(i));
104 pt.put(
"TMVA_number_classes",
m_classes.size());
105 for (
unsigned int i = 0; i <
m_classes.size(); ++i) {
106 pt.put(std::string(
"TMVA_classes") + std::to_string(i),
m_classes[i]);
113 description.add_options()
114 (
"tmva_classes", po::value<std::vector<std::string>>(&
m_classes)->required()->multitoken(),
115 "class name identifiers for multi-class mode");
120 specific_options(_specific_options) { }
127 auto base = std::string(
"TMVA@@MethodBase");
132 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
135 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
136 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
150 auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
151 auto saved_stdout = dup(STDOUT_FILENO);
154 factory.TrainAllMethods();
155 factory.TestAllMethods();
156 factory.EvaluateAllMethods();
159 dup2(saved_stdout, STDOUT_FILENO);
166 weightfile.
addFile(
"TMVA_Logfile", logfilename);
169 std::string begin =
"Ranking input variables (method specific)";
170 std::string end =
"-----------------------------------";
172 std::ifstream file(logfilename, std::ios::in);
173 std::map<std::string, float> feature_importances;
175 while (std::getline(file, line)) {
176 if (state == 0 && line.find(begin) != std::string::npos) {
180 if (state >= 1 and state <= 4) {
185 if (line.find(end) != std::string::npos)
187 std::vector<std::string> strs;
188 boost::split(strs, line, boost::is_any_of(
":"));
189 std::string variable = strs[2];
190 boost::trim(variable);
192 float importance = std::stof(strs[3]);
193 feature_importances[variable] = importance;
205 specific_options(_specific_options) { }
210 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212 unsigned int numberOfEvents = training_data.getNumberOfEvents();
216 char* directory_template = strdup((std::filesystem::temp_directory_path() /
"Basf2TMVA.XXXXXX").c_str());
217 directory = mkdtemp(directory_template);
218 free(directory_template);
227 TFile classFile((jobName +
".root").c_str(),
"RECREATE");
230 TMVA::Tools::Instance();
231 TMVA::DataLoader data_loader(jobName);
247 auto* signal_tree =
new TTree(
"signal_tree",
"signal_tree");
248 auto* background_tree =
new TTree(
"background_tree",
"background_tree");
250 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
252 &training_data.m_input[iFeature]);
254 &training_data.m_input[iFeature]);
257 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
259 &training_data.m_spectators[iSpectator]);
261 &training_data.m_spectators[iSpectator]);
264 signal_tree->Branch(
"__weight__", &training_data.m_weight);
265 background_tree->Branch(
"__weight__", &training_data.m_weight);
267 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
268 training_data.loadEvent(iEvent);
269 if (training_data.m_isSignal) {
272 background_tree->Fill();
276 data_loader.AddSignalTree(signal_tree);
277 data_loader.AddBackgroundTree(background_tree);
278 auto weightfile =
trainFactory(factory, data_loader, jobName);
281 weightfile.addSignalFraction(training_data.getSignalFraction());
284 delete background_tree;
287 std::filesystem::remove_all(directory);
296 specific_options(_specific_options) { }
301 B2ERROR(
"Training TMVAMulticlass classifiers within the MVA package has not been implemented yet.");
302 (void) training_data;
308 specific_options(_specific_options) { }
313 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
314 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
315 unsigned int numberOfEvents = training_data.getNumberOfEvents();
319 char* directory_template = strdup((std::filesystem::temp_directory_path() /
"Basf2TMVA.XXXXXX").c_str());
320 directory = mkdtemp(directory_template);
321 free(directory_template);
330 TFile classFile((jobName +
".root").c_str(),
"RECREATE");
333 TMVA::Tools::Instance();
334 TMVA::DataLoader data_loader(jobName);
349 auto* regression_tree =
new TTree(
"regression_tree",
"regression_tree");
351 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
353 &training_data.m_input[iFeature]);
355 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
357 &training_data.m_spectators[iSpectator]);
360 &training_data.m_target);
362 regression_tree->Branch(
"__weight__", &training_data.m_weight);
364 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
365 training_data.loadEvent(iEvent);
366 regression_tree->Fill();
369 data_loader.AddRegressionTree(regression_tree);
372 auto weightfile =
trainFactory(factory, data_loader, jobName);
375 delete regression_tree;
378 std::filesystem::remove_all(directory);
389 TMVA::Tools::Instance();
391 m_expert = std::make_unique<TMVA::Reader>(
"!Color:Silent");
396 for (
unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
400 for (
unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
406 weightfile.
getFile(
"TMVA_Logfile", custom_weightfile);
421 weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
426 auto base = std::string(
"TMVA@@MethodBase");
431 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
434 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
435 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
436 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
440 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
452 weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
457 auto base = std::string(
"TMVA@@MethodBase");
462 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
465 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
466 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
467 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
471 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
483 weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
488 auto base = std::string(
"TMVA@@MethodBase");
493 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
496 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
497 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
498 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
502 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
510 std::vector<float> probabilities(test_data.getNumberOfEvents());
511 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
512 test_data.loadEvent(iEvent);
522 return probabilities;
529 std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
531 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
532 test_data.loadEvent(iEvent);
539 return probabilities;
545 std::vector<float> prediction(test_data.getNumberOfEvents());
546 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
547 test_data.loadEvent(iEvent);
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_target_variable
Target variable (branch name) defining the target.
TMVAOptionsClassification specific_options
Method specific options.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
float expert_signalFraction
Signal fraction used to calculate the probability.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
TMVAOptionsMulticlass specific_options
Method specific options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
TMVAOptionsRegression specific_options
Method specific options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Options for the TMVA Classification MVA method.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
bool transform2probability
Transform output of method to a probability.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Options for the TMVA Multiclass MVA method.
std::vector< std::string > m_classes
Class name identifiers.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Options for the TMVA Regression MVA method.
Options for the TMVA MVA method.
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
std::string m_prefix
Prefix used for all files generated by TMVA.
std::string m_config
TMVA config string for the chosen method.
std::string m_method
TMVA method name.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::string m_factoryOption
Factory options passed to the TMVA factory.
std::string m_type
TMVA method type.
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
TMVAOptionsClassification specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
TMVAOptionsRegression specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
Teacher for the TMVA MVA method.
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train an MVA method using the given data loader, returning a Weightfile.
TMVAOptions specific_options
Method specific options.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
float getSignalFraction() const
Loads the signal fraction from the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
static std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Abstract base class for different kinds of events.