11 #include <mva/methods/TMVA.h>
12 #include <framework/logging/Logger.h>
13 #include <framework/utilities/MakeROOTCompatible.h>
14 #include <framework/utilities/ScopeGuard.h>
16 #include <TPluginManager.h>
18 #include <boost/algorithm/string.hpp>
19 #include <boost/filesystem/operations.hpp>
31 int version = pt.get<
int>(
"TMVA_version");
33 B2ERROR(
"Unkown weightfile version " << std::to_string(version));
34 throw std::runtime_error(
"Unkown weightfile version " + std::to_string(version));
36 m_method = pt.get<std::string>(
"TMVA_method");
37 m_type = pt.get<std::string>(
"TMVA_type");
38 m_config = pt.get<std::string>(
"TMVA_config");
42 m_prefix = pt.get<std::string>(
"TMVA_prefix");
47 pt.put(
"TMVA_version", 1);
49 pt.put(
"TMVA_type",
m_type);
59 po::options_description description(
"TMVA options");
60 description.add_options()
61 (
"tmva_method", po::value<std::string>(&
m_method),
"TMVA Method Name")
62 (
"tmva_type", po::value<std::string>(&
m_type),
"TMVA Method Type (e.g. Plugin, BDT, ...)")
63 (
"tmva_config", po::value<std::string>(&
m_config),
"TMVA Configuration string for the method")
64 (
"tmva_working_directory", po::value<std::string>(&
m_workingDirectory),
"TMVA working directory which stores e.g. TMVA.root")
65 (
"tmva_factory", po::value<std::string>(&
m_factoryOption),
"TMVA Factory options passed to TMVAFactory constructor")
67 "TMVA Preprare options passed to prepareTrainingAndTestTree function");
86 description.add_options()
87 (
"tmva_transform2probability", po::value<bool>(&
transform2probability),
"TMVA Transform output of classifier to a probability");
95 unsigned int numberOfClasses = pt.get<
unsigned int>(
"TMVA_number_classes", 1);
97 for (
unsigned int i = 0; i < numberOfClasses; ++i) {
98 m_classes[i] = pt.get<std::string>(std::string(
"TMVA_classes") + std::to_string(i));
106 pt.put(
"TMVA_number_classes",
m_classes.size());
107 for (
unsigned int i = 0; i <
m_classes.size(); ++i) {
108 pt.put(std::string(
"TMVA_classes") + std::to_string(i),
m_classes[i]);
115 description.add_options()
116 (
"tmva_classes", po::value<std::vector<std::string>>(&
m_classes)->required()->multitoken(),
117 "class name identifiers for multi-class mode");
122 specific_options(_specific_options) { }
124 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
130 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
131 data_loader.PrepareTrainingAndTestTree(
"", specific_options.m_prepareOption);
133 factory.PrepareTrainingAndTestTree(
"", specific_options.m_prepareOption);
136 if (specific_options.m_type ==
"Plugins") {
137 auto base = std::string(
"TMVA@@MethodBase");
138 auto regexp1 = std::string(
".*_") + specific_options.m_method + std::string(
".*");
139 auto regexp2 = std::string(
".*") + specific_options.m_method + std::string(
".*");
140 auto className = std::string(
"TMVA::Method") + specific_options.m_method;
141 auto ctor1 = std::string(
"Method") + specific_options.m_method + std::string(
"(TMVA::DataSetInfo&,TString)");
142 auto ctor2 = std::string(
"Method") + specific_options.m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
143 auto pluginName = std::string(
"TMVA") + specific_options.m_method;
145 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
146 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
149 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
150 if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
152 if (!factory.BookMethod(specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
154 B2ERROR(
"TMVA Method with name " + specific_options.m_method +
" cannot be booked.");
157 Weightfile weightfile;
158 std::string logfilename = weightfile.generateFileName(
".log");
164 auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
165 auto saved_stdout = dup(STDOUT_FILENO);
168 factory.TrainAllMethods();
169 factory.TestAllMethods();
170 factory.EvaluateAllMethods();
173 dup2(saved_stdout, STDOUT_FILENO);
178 weightfile.addOptions(m_general_options);
179 weightfile.addFile(
"TMVA_Weightfile", std::string(
"TMVA/weights/") + jobName +
"_" + specific_options.m_method +
".weights.xml");
180 weightfile.addFile(
"TMVA_Logfile", logfilename);
183 std::string begin =
"Ranking input variables (method specific)";
184 std::string end =
"-----------------------------------";
186 std::ifstream file(logfilename, std::ios::in);
187 std::map<std::string, float> feature_importances;
189 while (std::getline(file, line)) {
190 if (state == 0 && line.find(begin) != std::string::npos) {
194 if (state >= 1 and state <= 4) {
199 if (line.find(end) != std::string::npos)
201 std::vector<std::string> strs;
202 boost::split(strs, line, boost::is_any_of(
":"));
203 std::string variable = strs[2];
204 boost::trim(variable);
206 float importance = std::stof(strs[3]);
207 feature_importances[variable] = importance;
210 weightfile.addFeatureImportance(feature_importances);
218 const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
219 specific_options(_specific_options) { }
224 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
225 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
226 unsigned int numberOfEvents = training_data.getNumberOfEvents();
230 char* directory_template = strdup(
"/tmp/Basf2TMVA.XXXXXX");
231 directory = mkdtemp(directory_template);
232 free(directory_template);
240 TFile classFile((jobName +
".root").c_str(),
"RECREATE");
243 TMVA::Tools::Instance();
244 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
245 TMVA::DataLoader data_loader(jobName);
252 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
261 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
268 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
274 auto* signal_tree =
new TTree(
"signal_tree",
"signal_tree");
275 auto* background_tree =
new TTree(
"background_tree",
"background_tree");
277 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
279 &training_data.m_input[iFeature]);
281 &training_data.m_input[iFeature]);
284 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
286 &training_data.m_spectators[iSpectator]);
288 &training_data.m_spectators[iSpectator]);
291 signal_tree->Branch(
"__weight__", &training_data.m_weight);
292 background_tree->Branch(
"__weight__", &training_data.m_weight);
294 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
295 training_data.loadEvent(iEvent);
296 if (training_data.m_isSignal) {
299 background_tree->Fill();
303 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
304 data_loader.AddSignalTree(signal_tree);
305 data_loader.AddBackgroundTree(background_tree);
306 auto weightfile =
trainFactory(factory, data_loader, jobName);
308 factory.AddSignalTree(signal_tree);
309 factory.AddBackgroundTree(background_tree);
314 weightfile.addSignalFraction(training_data.getSignalFraction());
317 delete background_tree;
320 boost::filesystem::remove_all(directory);
328 const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
329 specific_options(_specific_options) { }
334 (void) training_data;
340 specific_options(_specific_options) { }
345 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
346 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
347 unsigned int numberOfEvents = training_data.getNumberOfEvents();
351 char* directory_template = strdup(
"/tmp/Basf2TMVA.XXXXXX");
352 directory = mkdtemp(directory_template);
353 free(directory_template);
361 TFile classFile((jobName +
".root").c_str(),
"RECREATE");
364 TMVA::Tools::Instance();
365 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
366 TMVA::DataLoader data_loader(jobName);
372 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
381 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
388 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
395 auto* regression_tree =
new TTree(
"regression_tree",
"regression_tree");
397 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
399 &training_data.m_input[iFeature]);
401 for (
unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
403 &training_data.m_spectators[iSpectator]);
406 &training_data.m_target);
408 regression_tree->Branch(
"__weight__", &training_data.m_weight);
410 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
411 training_data.loadEvent(iEvent);
412 regression_tree->Fill();
415 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
416 data_loader.AddRegressionTree(regression_tree);
419 auto weightfile =
trainFactory(factory, data_loader, jobName);
421 factory.AddRegressionTree(regression_tree);
428 delete regression_tree;
431 boost::filesystem::remove_all(directory);
442 TMVA::Tools::Instance();
444 m_expert = std::make_unique<TMVA::Reader>(
"!Color:!Silent");
447 weightfile.getOptions(general_options);
449 for (
unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
453 for (
unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
457 if (weightfile.containsElement(
"TMVA_Logfile")) {
458 std::string custom_weightfile = weightfile.generateFileName(
"logfile");
459 weightfile.getFile(
"TMVA_Logfile", custom_weightfile);
473 std::string custom_weightfile = weightfile.generateFileName(std::string(
"_") +
specific_options.
m_method +
".weights.xml");
474 weightfile.getFile(
"TMVA_Weightfile", custom_weightfile);
479 auto base = std::string(
"TMVA@@MethodBase");
484 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
487 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
488 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
489 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
493 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
504 std::string custom_weightfile = weightfile.generateFileName(std::string(
"_") +
specific_options.
m_method +
".weights.xml");
505 weightfile.getFile(
"TMVA_Weightfile", custom_weightfile);
510 auto base = std::string(
"TMVA@@MethodBase");
515 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
518 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
519 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
520 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
524 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
535 std::string custom_weightfile = weightfile.generateFileName(std::string(
"_") +
specific_options.
m_method +
".weights.xml");
536 weightfile.getFile(
"TMVA_Weightfile", custom_weightfile);
541 auto base = std::string(
"TMVA@@MethodBase");
546 auto ctor2 = std::string(
"Method") +
specific_options.
m_method + std::string(
"(TString&,TString&,TMVA::DataSetInfo&,TString&)");
549 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
550 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
551 B2INFO(
"Registered new TMVA Plugin named " << pluginName);
555 B2FATAL(
"Could not set up expert! Please see preceding error message from TMVA!");
563 std::vector<float> probabilities(test_data.getNumberOfEvents());
564 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
565 test_data.loadEvent(iEvent);
575 return probabilities;
582 std::vector<float> probabilities(test_data.getNumberOfEvents());
583 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
584 test_data.loadEvent(iEvent);
591 return probabilities;
597 std::vector<float> prediction(test_data.getNumberOfEvents());
598 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
599 test_data.loadEvent(iEvent);