9 #include <mva/methods/TMVA.h> 
   10 #include <framework/logging/Logger.h> 
   11 #include <framework/utilities/MakeROOTCompatible.h> 
   12 #include <framework/utilities/ScopeGuard.h> 
   14 #include <TPluginManager.h> 
   16 #include <boost/algorithm/string.hpp> 
   17 #include <boost/filesystem/operations.hpp> 
   29       int version = pt.get<
int>(
"TMVA_version");
 
   31         B2ERROR(
"Unknown weightfile version " << std::to_string(version));
 
   32         throw std::runtime_error(
"Unknown weightfile version " + std::to_string(version));
 
   34       m_method = pt.get<std::string>(
"TMVA_method");
 
   35       m_type = pt.get<std::string>(
"TMVA_type");
 
   36       m_config = pt.get<std::string>(
"TMVA_config");
 
   40       m_prefix = pt.get<std::string>(
"TMVA_prefix");
 
   45       pt.put(
"TMVA_version", 1);
 
   47       pt.put(
"TMVA_type", 
m_type);
 
   57       po::options_description description(
"TMVA options");
 
   58       description.add_options()
 
   59       (
"tmva_method", po::value<std::string>(&
m_method), 
"TMVA Method Name")
 
   60       (
"tmva_type", po::value<std::string>(&
m_type), 
"TMVA Method Type (e.g. Plugin, BDT, ...)")
 
   61       (
"tmva_config", po::value<std::string>(&
m_config), 
"TMVA Configuration string for the method")
 
   62       (
"tmva_working_directory", po::value<std::string>(&
m_workingDirectory), 
"TMVA working directory which stores e.g. TMVA.root")
 
   63       (
"tmva_factory", po::value<std::string>(&
m_factoryOption), 
"TMVA Factory options passed to TMVAFactory constructor")
 
   65        "TMVA Preprare options passed to prepareTrainingAndTestTree function");
 
   84       description.add_options()
 
   85       (
"tmva_transform2probability", po::value<bool>(&
transform2probability), 
"TMVA Transform output of classifier to a probability");
 
   93       unsigned int numberOfClasses = pt.get<
unsigned int>(
"TMVA_number_classes", 1);
 
   95       for (
unsigned int i = 0; i < numberOfClasses; ++i) {
 
   96         m_classes[i] = pt.get<std::string>(std::string(
"TMVA_classes") + std::to_string(i));
 
  104       pt.put(
"TMVA_number_classes", 
m_classes.size());
 
  105       for (
unsigned int i = 0; i < 
m_classes.size(); ++i) {
 
  106         pt.put(std::string(
"TMVA_classes") + std::to_string(i), 
m_classes[i]);
 
  113       description.add_options()
 
  114       (
"tmva_classes", po::value<std::vector<std::string>>(&
m_classes)->required()->multitoken(),
 
  115        "class name identifiers for multi-class mode");
 
  120       specific_options(_specific_options) { }
 
  127         auto base = std::string(
"TMVA@@MethodBase");
 
  132         auto ctor2 = std::string(
"Method") + 
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
 
  135         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
 
  136         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
 
  150       auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
 
  151       auto saved_stdout = dup(STDOUT_FILENO);
 
  154       factory.TrainAllMethods();
 
  155       factory.TestAllMethods();
 
  156       factory.EvaluateAllMethods();
 
  159       dup2(saved_stdout, STDOUT_FILENO);
 
  166       weightfile.
addFile(
"TMVA_Logfile", logfilename);
 
  169       std::string begin = 
"Ranking input variables (method specific)";
 
  170       std::string end = 
"-----------------------------------";
 
  172       std::ifstream file(logfilename, std::ios::in);
 
  173       std::map<std::string, float> feature_importances;
 
  175       while (std::getline(file, line)) {
 
  176         if (state == 0 && line.find(begin) != std::string::npos) {
 
  180         if (state >= 1 and state <= 4) {
 
  185           if (line.find(end) != std::string::npos)
 
  187           std::vector<std::string> strs;
 
  188           boost::split(strs, line, boost::is_any_of(
":"));
 
  189           std::string variable = strs[2];
 
  190           boost::trim(variable);
 
  192           float importance = std::stof(strs[3]);
 
  193           feature_importances[variable] = importance;
 
  205       specific_options(_specific_options) { }
 
  210       unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
 
  211       unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
 
  212       unsigned int numberOfEvents = training_data.getNumberOfEvents();
 
  216         char* directory_template = strdup(
"/tmp/Basf2TMVA.XXXXXX");
 
  217         directory = mkdtemp(directory_template);
 
  218         free(directory_template);
 
  227       TFile classFile((jobName + 
".root").c_str(), 
"RECREATE");
 
  230       TMVA::Tools::Instance();
 
  231       TMVA::DataLoader data_loader(jobName);
 
  247       auto* signal_tree = 
new TTree(
"signal_tree", 
"signal_tree");
 
  248       auto* background_tree = 
new TTree(
"background_tree", 
"background_tree");
 
  250       for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
 
  252                             &training_data.m_input[iFeature]);
 
  254                                 &training_data.m_input[iFeature]);
 
  257       for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
 
  259                             &training_data.m_spectators[iSpectator]);
 
  261                                 &training_data.m_spectators[iSpectator]);
 
  264       signal_tree->Branch(
"__weight__", &training_data.m_weight);
 
  265       background_tree->Branch(
"__weight__", &training_data.m_weight);
 
  267       for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
 
  268         training_data.loadEvent(iEvent);
 
  269         if (training_data.m_isSignal) {
 
  272           background_tree->Fill();
 
  276       data_loader.AddSignalTree(signal_tree);
 
  277       data_loader.AddBackgroundTree(background_tree);
 
  278       auto weightfile = 
trainFactory(factory, data_loader, jobName);
 
  281       weightfile.addSignalFraction(training_data.getSignalFraction());
 
  284       delete background_tree;
 
  287         boost::filesystem::remove_all(directory);
 
  296       specific_options(_specific_options) { }
 
  301       B2ERROR(
"Training TMVAMulticlass classifiers within the MVA package has not been implemented yet.");
 
  302       (void) training_data;
 
  308       specific_options(_specific_options) { }
 
  313       unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
 
  314       unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
 
  315       unsigned int numberOfEvents = training_data.getNumberOfEvents();
 
  319         char* directory_template = strdup(
"/tmp/Basf2TMVA.XXXXXX");
 
  320         directory = mkdtemp(directory_template);
 
  321         free(directory_template);
 
  330       TFile classFile((jobName + 
".root").c_str(), 
"RECREATE");
 
  333       TMVA::Tools::Instance();
 
  334       TMVA::DataLoader data_loader(jobName);
 
  349       auto* regression_tree = 
new TTree(
"regression_tree", 
"regression_tree");
 
  351       for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
 
  353                                 &training_data.m_input[iFeature]);
 
  355       for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
 
  357                                 &training_data.m_spectators[iSpectator]);
 
  360                               &training_data.m_target);
 
  362       regression_tree->Branch(
"__weight__", &training_data.m_weight);
 
  364       for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
 
  365         training_data.loadEvent(iEvent);
 
  366         regression_tree->Fill();
 
  369       data_loader.AddRegressionTree(regression_tree);
 
  372       auto weightfile = 
trainFactory(factory, data_loader, jobName);
 
  375       delete regression_tree;
 
  378         boost::filesystem::remove_all(directory);
 
  389       TMVA::Tools::Instance();
 
  391       m_expert = std::make_unique<TMVA::Reader>(
"!Color:Silent");
 
  396       for (
unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
 
  400       for (
unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
 
  406         weightfile.
getFile(
"TMVA_Logfile", custom_weightfile);
 
  421       weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
 
  426         auto base = std::string(
"TMVA@@MethodBase");
 
  431         auto ctor2 = std::string(
"Method") + 
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
 
  434         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
 
  435         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
 
  436         B2INFO(
"Registered new TMVA Plugin named " << pluginName);
 
  440         B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
 
  452       weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
 
  457         auto base = std::string(
"TMVA@@MethodBase");
 
  462         auto ctor2 = std::string(
"Method") + 
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
 
  465         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
 
  466         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
 
  467         B2INFO(
"Registered new TMVA Plugin named " << pluginName);
 
  471         B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
 
  483       weightfile.
getFile(
"TMVA_Weightfile", custom_weightfile);
 
  488         auto base = std::string(
"TMVA@@MethodBase");
 
  493         auto ctor2 = std::string(
"Method") + 
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
 
  496         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
 
  497         gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
 
  498         B2INFO(
"Registered new TMVA Plugin named " << pluginName);
 
  502         B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
 
  510       std::vector<float> probabilities(test_data.getNumberOfEvents());
 
  511       for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
 
  512         test_data.loadEvent(iEvent);
 
  522       return probabilities;
 
  529       std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
 
  531       for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
 
  532         test_data.loadEvent(iEvent);
 
  539       return probabilities;
 
  545       std::vector<float> prediction(test_data.getNumberOfEvents());
 
  546       for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
 
  547         test_data.loadEvent(iEvent);
 
Abstract base class of all Datasets given to the MVA interface. The current event can always be access...
General options which are shared by all MVA trainings.
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
std::string m_weight_variable
Weight variable (branch name) defining the weights.
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
std::string m_target_variable
Target variable (branch name) defining the target.
TMVAOptionsClassification specific_options
Method specific options.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
float expert_signalFraction
Signal fraction used to calculate the probability.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
TMVAOptionsMulticlass specific_options
Method specific options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
TMVAOptionsRegression specific_options
Method specific options.
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Options for the TMVA Classification MVA method.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
bool transform2probability
Transform output of method to a probability.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Options for the TMVA Multiclass MVA method.
std::vector< std::string > m_classes
Class name identifiers.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Options for the TMVA Regression MVA method.
Options for the TMVA MVA method.
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
std::string m_prefix
Prefix used for all files generated by TMVA.
std::string m_config
TMVA config string for the chosen method.
std::string m_method
TMVA method name.
virtual po::options_description getDescription() override
Returns a program options description for all available options.
std::string m_factoryOption
Factory options passed to the TMVA factory.
std::string m_type
TMVA method type.
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
TMVAOptionsClassification specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
TMVAOptionsRegression specific_options
Method specific options.
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Teacher for the TMVA MVA method.
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
TMVAOptions specific_options
Method specific options.
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
GeneralOptions m_general_options
GeneralOptions containing all shared options.
The Weightfile class serializes all information about a training into an xml tree.
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from an MVA library) to our Weightfile.
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
void addOptions(const Options &options)
Add an Option object to the xml tree.
void getOptions(Options &options) const
Fills an Option object from the xml tree.
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
float getSignalFraction() const
Loads the signal fraction from the xml tree.
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library).
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
static std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Abstract base class for different kinds of events.