Belle II Software light-2406-ragdoll
TMVA.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
#include <mva/methods/TMVA.h>
#include <framework/logging/Logger.h>
#include <framework/utilities/MakeROOTCompatible.h>
#include <framework/utilities/ScopeGuard.h>

#include <TPluginManager.h>

#include <boost/algorithm/string.hpp>

#include <cstdio>
#include <exception>
#include <filesystem>
#include <iostream>
#include <memory>
19
20namespace Belle2 {
25 namespace MVA {
26
27 void TMVAOptions::load(const boost::property_tree::ptree& pt)
28 {
29 int version = pt.get<int>("TMVA_version");
30 if (version != 1) {
31 B2ERROR("Unknown weightfile version " << std::to_string(version));
32 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
33 }
34 m_method = pt.get<std::string>("TMVA_method");
35 m_type = pt.get<std::string>("TMVA_type");
36 m_config = pt.get<std::string>("TMVA_config");
37 m_factoryOption = pt.get<std::string>("TMVA_factoryOption");
38 m_prepareOption = pt.get<std::string>("TMVA_prepareOption");
39 m_workingDirectory = pt.get<std::string>("TMVA_workingDirectory");
40 m_prefix = pt.get<std::string>("TMVA_prefix");
41 }
42
43 void TMVAOptions::save(boost::property_tree::ptree& pt) const
44 {
45 pt.put("TMVA_version", 1);
46 pt.put("TMVA_method", m_method);
47 pt.put("TMVA_type", m_type);
48 pt.put("TMVA_config", m_config);
49 pt.put("TMVA_factoryOption", m_factoryOption);
50 pt.put("TMVA_prepareOption", m_prepareOption);
51 pt.put("TMVA_workingDirectory", m_workingDirectory);
52 pt.put("TMVA_prefix", m_prefix);
53 }
54
55 po::options_description TMVAOptions::getDescription()
56 {
57 po::options_description description("TMVA options");
58 description.add_options()
59 ("tmva_method", po::value<std::string>(&m_method), "TMVA Method Name")
60 ("tmva_type", po::value<std::string>(&m_type), "TMVA Method Type (e.g. Plugin, BDT, ...)")
61 ("tmva_config", po::value<std::string>(&m_config), "TMVA Configuration string for the method")
62 ("tmva_working_directory", po::value<std::string>(&m_workingDirectory), "TMVA working directory which stores e.g. TMVA.root")
63 ("tmva_factory", po::value<std::string>(&m_factoryOption), "TMVA Factory options passed to TMVAFactory constructor")
64 ("tmva_prepare", po::value<std::string>(&m_prepareOption),
65 "TMVA Preprare options passed to prepareTrainingAndTestTree function");
66 return description;
67 }
68
// Load the classification-specific TMVA options from a property tree.
// Reads the boolean key "TMVA_transform2probability" into transform2probability.
// NOTE(review): source line 71 is missing from this listing; it presumably calls the base-class
// TMVAOptions::load(pt) so the generic "TMVA_*" keys are read as well — verify against the original file.
69 void TMVAOptionsClassification::load(const boost::property_tree::ptree& pt)
70 {
72 transform2probability = pt.get<bool>("TMVA_transform2probability");
73 }
74
// Save the classification-specific TMVA options into a property tree.
// Writes transform2probability under the key "TMVA_transform2probability".
// NOTE(review): source line 77 is missing from this listing; it presumably calls the base-class
// TMVAOptions::save(pt) so the generic "TMVA_*" keys are written as well — verify against the original file.
75 void TMVAOptionsClassification::save(boost::property_tree::ptree& pt) const
76 {
78 pt.put("TMVA_transform2probability", transform2probability);
79 }
80
// Command-line options of the TMVA classification method: the generic TMVA options extended by
// "tmva_transform2probability", which is bound to transform2probability.
// NOTE(review): the function signature (source line 81) is missing from this listing; according to the
// Doxygen cross-reference (TMVA.cc:81) it is TMVAOptionsClassification::getDescription().
82 {
83 po::options_description description = TMVAOptions::getDescription();
84 description.add_options()
85 ("tmva_transform2probability", po::value<bool>(&transform2probability), "TMVA Transform output of classifier to a probability");
86 return description;
87 }
88
// Load the multiclass-specific TMVA options from a property tree.
// The class names are stored as "TMVA_classes0", "TMVA_classes1", ...; their count is read from
// "TMVA_number_classes", which defaults to 1 if the key is absent.
// NOTE(review): source line 91 is missing from this listing; it presumably calls the base-class
// TMVAOptions::load(pt) — verify against the original file.
89 void TMVAOptionsMulticlass::load(const boost::property_tree::ptree& pt)
90 {
92
93 unsigned int numberOfClasses = pt.get<unsigned int>("TMVA_number_classes", 1);
94 m_classes.resize(numberOfClasses);
95 for (unsigned int i = 0; i < numberOfClasses; ++i) {
96 m_classes[i] = pt.get<std::string>(std::string("TMVA_classes") + std::to_string(i));
97 }
98 }
99
// Save the multiclass-specific TMVA options into a property tree.
// Writes the number of classes to "TMVA_number_classes" and each class name i to "TMVA_classes<i>".
// NOTE(review): source line 102 is missing from this listing; it presumably calls the base-class
// TMVAOptions::save(pt) — verify against the original file.
100 void TMVAOptionsMulticlass::save(boost::property_tree::ptree& pt) const
101 {
103
104 pt.put("TMVA_number_classes", m_classes.size());
105 for (unsigned int i = 0; i < m_classes.size(); ++i) {
106 pt.put(std::string("TMVA_classes") + std::to_string(i), m_classes[i]);
107 }
108 }
109
// Command-line options of the TMVA multiclass method: the generic TMVA options extended by the
// mandatory, multi-token option "tmva_classes", which is bound to m_classes.
// NOTE(review): the function signature (source line 110) is missing from this listing; according to the
// Doxygen cross-reference (TMVA.cc:110) it is TMVAOptionsMulticlass::getDescription().
111 {
112 po::options_description description = TMVAOptions::getDescription();
113 description.add_options()
114 ("tmva_classes", po::value<std::vector<std::string>>(&m_classes)->required()->multitoken(),
115 "class name identifiers for multi-class mode");
116 return description;
117 }
118
119 TMVATeacher::TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options) : Teacher(general_options),
120 specific_options(_specific_options) { }
121
122 Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const
123 {
124 data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
125
126 if (specific_options.m_type == "Plugins") {
127 auto base = std::string("TMVA@@MethodBase");
128 auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
129 auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
130 auto className = std::string("TMVA::Method") + specific_options.m_method;
131 auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
132 auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
133 auto pluginName = std::string("TMVA") + specific_options.m_method;
134
135 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
136 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
137 }
138
139 if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
140 B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
141 }
142
143 Weightfile weightfile;
144 std::string logfilename = weightfile.generateFileName(".log");
145
146 // Pipe stdout into a logfile to get TMVA output, which contains valuable information
147 // which cannot be retrieved otherwise!
148 // Hence we do some black magic here
149 // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
150 auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
151 auto saved_stdout = dup(STDOUT_FILENO);
152 dup2(logfile, 1);
153
154 factory.TrainAllMethods();
155 factory.TestAllMethods();
156 factory.EvaluateAllMethods();
157
158 // Reset original output
159 dup2(saved_stdout, STDOUT_FILENO);
160 close(saved_stdout);
161 close(logfile);
162
163
164 weightfile.addOptions(m_general_options);
165 weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
166 weightfile.addFile("TMVA_Logfile", logfilename);
167
168 // We have to parse the TMVA output to get the feature importances, there is no other way currently
169 std::string begin = "Ranking input variables (method specific)";
170 std::string end = "-----------------------------------";
171 std::string line;
172 std::ifstream file(logfilename, std::ios::in);
173 std::map<std::string, float> feature_importances;
174 int state = 0;
175 while (std::getline(file, line)) {
176 if (state == 0 && line.find(begin) != std::string::npos) {
177 state = 1;
178 continue;
179 }
180 if (state >= 1 and state <= 4) {
181 state++;
182 continue;
183 }
184 if (state == 5) {
185 if (line.find(end) != std::string::npos)
186 break;
187 std::vector<std::string> strs;
188 boost::split(strs, line, boost::is_any_of(":"));
189 std::string variable = strs[2];
190 boost::trim(variable);
192 float importance = std::stof(strs[3]);
193 feature_importances[variable] = importance;
194 }
195 }
196 weightfile.addFeatureImportance(feature_importances);
197
198 return weightfile;
199
200 }
201
202
// Constructor of TMVATeacherClassification: forwards both option sets to TMVATeacher and keeps its own
// copy of the classification options in specific_options.
// NOTE(review): the first line of this constructor (source line 203) is missing from this listing;
// according to the Doxygen cross-reference it is
// TMVATeacherClassification(const GeneralOptions& general_options, const TMVAOptionsClassification& _specific_options).
204 const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
205 specific_options(_specific_options) { }
206
// Train a TMVA classification method on the given dataset and return the resulting Weightfile.
// Steps:
//  1. choose the working directory (the configured one, or a fresh "Basf2TMVA.XXXXXX" temp directory)
//     and switch into it for the duration of the call,
//  2. create "<jobName>.root", a TMVA::DataLoader and a TMVA::Factory, all named jobName
//     (jobName is m_prefix, or "TMVA" if the prefix is empty),
//  3. copy the events into a signal tree and a background tree according to m_isSignal,
//  4. train via TMVATeacher::trainFactory and add the classification options and the signal fraction.
// NOTE(review): the function signature (source line 207) is missing from this listing; according to the
// Doxygen cross-reference it is Weightfile TMVATeacherClassification::train(Dataset& training_data) const.
// Lines 215, 245, 251, 258 and 286 are missing too. From the surrounding braces, 215 and 286 presumably
// open "only when no working directory was configured" conditions. 251 and 258 presumably add the
// matching signal_tree branches — verify against the original file.
208 {
209
210 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212 unsigned int numberOfEvents = training_data.getNumberOfEvents();
213
214 std::string directory = specific_options.m_workingDirectory;
// NOTE(review): mkdtemp returns NULL on failure, which would then be assigned to std::string — consider checking.
216 char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
217 directory = mkdtemp(directory_template);
218 free(directory_template);
219 }
220
// Restores the previous working directory when this function returns.
221 // cppcheck-suppress unreadVariable
222 auto guard = ScopeGuard::guardWorkingDirectory(directory);
223
224 std::string jobName = specific_options.m_prefix;
225 if (jobName.empty())
226 jobName = "TMVA";
227 TFile classFile((jobName + ".root").c_str(), "RECREATE");
228 classFile.cd();
229
230 TMVA::Tools::Instance();
231 TMVA::DataLoader data_loader(jobName);
232 TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
233
234
235 // Add variables to the data loader
236 for (auto& var : m_general_options.m_variables) {
237 data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
238 }
239
240 // Add spectators to the data loader
241 for (auto& var : m_general_options.m_spectators) {
242 data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
243 }
244
246
// The branches point directly at the dataset's per-event members, so every Fill() after loadEvent()
// stores the currently loaded event.
// NOTE(review): the trees are owned through raw new/delete; if trainFactory throws, they leak.
247 auto* signal_tree = new TTree("signal_tree", "signal_tree");
248 auto* background_tree = new TTree("background_tree", "background_tree");
249
250 for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
252 &training_data.m_input[iFeature]);
253 background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
254 &training_data.m_input[iFeature]);
255 }
256
257 for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
259 &training_data.m_spectators[iSpectator]);
260 background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
261 &training_data.m_spectators[iSpectator]);
262 }
263
264 signal_tree->Branch("__weight__", &training_data.m_weight);
265 background_tree->Branch("__weight__", &training_data.m_weight);
266
267 for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
268 training_data.loadEvent(iEvent);
269 if (training_data.m_isSignal) {
270 signal_tree->Fill();
271 } else {
272 background_tree->Fill();
273 }
274 }
275
276 data_loader.AddSignalTree(signal_tree);
277 data_loader.AddBackgroundTree(background_tree);
278 auto weightfile = trainFactory(factory, data_loader, jobName);
279
280 weightfile.addOptions(specific_options);
281 weightfile.addSignalFraction(training_data.getSignalFraction());
282
283 delete signal_tree;
284 delete background_tree;
285
287 std::filesystem::remove_all(directory);
288 }
289
290 return weightfile;
291
292 }
293
// Constructor of TMVATeacherMulticlass: forwards both option sets to TMVATeacher and keeps its own
// copy of the multiclass options in specific_options.
// NOTE(review): the first line of this constructor (source line 294) is missing from this listing;
// according to the Doxygen cross-reference it is
// TMVATeacherMulticlass(const GeneralOptions& general_options, const TMVAOptionsMulticlass& _specific_options).
295 const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
296 specific_options(_specific_options) { }
297
// Training entry point of TMVATeacherMulticlass. Multiclass training is not implemented: the function
// logs an error, ignores the dataset and returns a default-constructed Weightfile.
// NOTE(review): the function signature (source line 299) is missing from this listing; according to the
// Doxygen cross-reference it is Weightfile TMVATeacherMulticlass::train(Dataset& training_data) const.
298 // Implement me!
300 {
301 B2ERROR("Training TMVAMulticlass classifiers within the MVA package has not been implemented yet.");
302 (void) training_data;
303 return Weightfile();
304 }
305
// Constructor of TMVATeacherRegression: forwards both option sets to TMVATeacher and keeps its own
// copy of the regression options in specific_options.
// NOTE(review): the first line of this constructor (source line 306) is missing from this listing;
// according to the Doxygen cross-reference it is
// TMVATeacherRegression(const GeneralOptions& general_options, const TMVAOptionsRegression& _specific_options).
307 const TMVAOptionsRegression& _specific_options) : TMVATeacher(general_options, _specific_options),
308 specific_options(_specific_options) { }
309
// Train a TMVA regression method on the given dataset and return the resulting Weightfile.
// Steps:
//  1. choose the working directory (the configured one, or a fresh "Basf2TMVA.XXXXXX" temp directory)
//     and switch into it for the duration of the call,
//  2. create "<jobName>.root", a TMVA::DataLoader and a TMVA::Factory, all named jobName
//     (jobName is m_prefix, or "TMVA" if the prefix is empty),
//  3. copy all events into a single regression tree,
//  4. set the weight expression, train via TMVATeacher::trainFactory and add the regression options.
// NOTE(review): the function signature (source line 310) is missing from this listing; according to the
// Doxygen cross-reference it is Weightfile TMVATeacherRegression::train(Dataset& training_data) const.
// Lines 318, 347, 359 and 377 are missing too. From the surrounding braces, 318 and 377 presumably open
// "only when no working directory was configured" conditions. 359 presumably adds the target branch
// (m_target_variable) — verify against the original file.
311 {
312
313 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
314 unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
315 unsigned int numberOfEvents = training_data.getNumberOfEvents();
316
317 std::string directory = specific_options.m_workingDirectory;
// NOTE(review): mkdtemp returns NULL on failure, which would then be assigned to std::string — consider checking.
319 char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
320 directory = mkdtemp(directory_template);
321 free(directory_template);
322 }
323
// Restores the previous working directory when this function returns.
324 // cppcheck-suppress unreadVariable
325 auto guard = ScopeGuard::guardWorkingDirectory(directory);
326
327 std::string jobName = specific_options.m_prefix;
328 if (jobName.empty())
329 jobName = "TMVA";
330 TFile classFile((jobName + ".root").c_str(), "RECREATE");
331 classFile.cd();
332
333 TMVA::Tools::Instance();
334 TMVA::DataLoader data_loader(jobName);
335 TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
336
337 // Add variables to the data loader
338 for (auto& var : m_general_options.m_variables) {
339 data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
340 }
341
342 // Add spectators to the data loader
343 for (auto& var : m_general_options.m_spectators) {
344 data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
345 }
346
348
// The branches point directly at the dataset's per-event members, so every Fill() after loadEvent()
// stores the currently loaded event.
// NOTE(review): the tree is owned through raw new/delete; if trainFactory throws, it leaks.
349 auto* regression_tree = new TTree("regression_tree", "regression_tree");
350
351 for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
352 regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
353 &training_data.m_input[iFeature]);
354 }
355 for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
356 regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
357 &training_data.m_spectators[iSpectator]);
358 }
360 &training_data.m_target);
361
362 regression_tree->Branch("__weight__", &training_data.m_weight);
363
364 for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
365 training_data.loadEvent(iEvent);
366 regression_tree->Fill();
367 }
368
// NOTE(review): the weight branch is always named "__weight__", but the weight expression uses the
// ROOT-compatible m_weight_variable; these only agree if the weight variable maps to "__weight__" — confirm.
369 data_loader.AddRegressionTree(regression_tree);
370 data_loader.SetWeightExpression(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
371
372 auto weightfile = trainFactory(factory, data_loader, jobName);
373 weightfile.addOptions(specific_options);
374
375 delete regression_tree;
376
378 std::filesystem::remove_all(directory);
379 }
380
381 return weightfile;
382
383 }
384
385 void TMVAExpert::load(Weightfile& weightfile)
386 {
387
388 // Initialize TMVA and ROOT stuff
389 TMVA::Tools::Instance();
390
391 m_expert = std::make_unique<TMVA::Reader>("!Color:Silent");
392
393 GeneralOptions general_options;
394 weightfile.getOptions(general_options);
395 m_input_cache.resize(general_options.m_variables.size(), 0);
396 for (unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
397 m_expert->AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_variables[i]), &m_input_cache[i]);
398 }
399 m_spectators_cache.resize(general_options.m_spectators.size(), 0);
400 for (unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
401 m_expert->AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_spectators[i]), &m_spectators_cache[i]);
402 }
403
404 if (weightfile.containsElement("TMVA_Logfile")) {
405 std::string custom_weightfile = weightfile.generateFileName("logfile");
406 weightfile.getFile("TMVA_Logfile", custom_weightfile);
407 }
408
409 }
410
// Load the TMVA classification expert from a Weightfile:
//  1. read the classification options,
//  2. extract the TMVA weight file into a temporary file whose name ends in "_<method>.weights.xml",
//  3. set up the reader via TMVAExpert::load,
//  4. register ROOT plugin handlers if the method type is "Plugins",
//  5. book the method on the reader; failure is fatal (B2FATAL).
// NOTE(review): the function signature (source line 411) is missing from this listing; according to the
// Doxygen cross-reference it is TMVAExpertClassification::load(Weightfile& weightfile). Lines 415-416 are
// missing as well. The closing brace at 417 suggests they open a conditional block, possibly one that sets
// expert_signalFraction from the weightfile's signal fraction — verify against the original file.
412 {
413
414 weightfile.getOptions(specific_options);
417 }
418
419 // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
420 std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
421 weightfile.getFile("TMVA_Weightfile", custom_weightfile);
422
423 TMVAExpert::load(weightfile);
424
// Same plugin-handler registration as in TMVATeacher::trainFactory (candidate for a shared helper).
425 if (specific_options.m_type == "Plugins") {
426 auto base = std::string("TMVA@@MethodBase");
427 auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
428 auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
429 auto className = std::string("TMVA::Method") + specific_options.m_method;
430 auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
431 auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
432 auto pluginName = std::string("TMVA") + specific_options.m_method;
433
434 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
435 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
436 B2INFO("Registered new TMVA Plugin named " << pluginName);
437 }
438
439 if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
440 B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
441 }
442
443 }
444
// Load the TMVA multiclass expert from a Weightfile:
//  1. read the multiclass options,
//  2. extract the TMVA weight file into a temporary file whose name ends in "_<method>.weights.xml",
//  3. set up the reader via TMVAExpert::load,
//  4. register ROOT plugin handlers if the method type is "Plugins",
//  5. book the method on the reader; failure is fatal (B2FATAL).
// NOTE(review): the function signature (source line 445) is missing from this listing; according to the
// Doxygen cross-reference it is TMVAExpertMulticlass::load(Weightfile& weightfile).
446 {
447
448 weightfile.getOptions(specific_options);
449
450 // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
451 std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
452 weightfile.getFile("TMVA_Weightfile", custom_weightfile);
453
454 TMVAExpert::load(weightfile);
455
// Same plugin-handler registration as in TMVATeacher::trainFactory (candidate for a shared helper).
456 if (specific_options.m_type == "Plugins") {
457 auto base = std::string("TMVA@@MethodBase");
458 auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
459 auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
460 auto className = std::string("TMVA::Method") + specific_options.m_method;
461 auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
462 auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
463 auto pluginName = std::string("TMVA") + specific_options.m_method;
464
465 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
466 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
467 B2INFO("Registered new TMVA Plugin named " << pluginName);
468 }
469
470 if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
471 B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
472 }
473
474 }
475
// Load the TMVA regression expert from a Weightfile:
//  1. read the regression options,
//  2. extract the TMVA weight file into a temporary file whose name ends in "_<method>.weights.xml",
//  3. set up the reader via TMVAExpert::load,
//  4. register ROOT plugin handlers if the method type is "Plugins",
//  5. book the method on the reader; failure is fatal (B2FATAL).
// NOTE(review): the function signature (source line 476) is missing from this listing; according to the
// Doxygen cross-reference it is TMVAExpertRegression::load(Weightfile& weightfile).
477 {
478
479 weightfile.getOptions(specific_options);
480
481 // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
482 std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
483 weightfile.getFile("TMVA_Weightfile", custom_weightfile);
484
485 TMVAExpert::load(weightfile);
486
// Same plugin-handler registration as in TMVATeacher::trainFactory (candidate for a shared helper).
487 if (specific_options.m_type == "Plugins") {
488 auto base = std::string("TMVA@@MethodBase");
489 auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
490 auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
491 auto className = std::string("TMVA::Method") + specific_options.m_method;
492 auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
493 auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
494 auto pluginName = std::string("TMVA") + specific_options.m_method;
495
496 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
497 gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
498 B2INFO("Registered new TMVA Plugin named " << pluginName);
499 }
500
501 if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
502 B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
503 }
504
505 }
506
// Evaluate the TMVA classification method for every event of test_data.
// Each event is loaded, and its features and spectators are copied into the caches the reader was bound
// to in TMVAExpert::load. The result is either m_expert->GetProba(method, expert_signalFraction) or the
// raw m_expert->EvaluateMVA(method) output. Returns one value per event, in event order.
// NOTE(review): source line 517, which selects between the two branches, is missing from this listing;
// it presumably tests specific_options.transform2probability — verify against the original file.
507 std::vector<float> TMVAExpertClassification::apply(Dataset& test_data) const
508 {
509
510 std::vector<float> probabilities(test_data.getNumberOfEvents());
511 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
512 test_data.loadEvent(iEvent);
513 for (unsigned int i = 0; i < m_input_cache.size(); ++i)
514 m_input_cache[i] = test_data.m_input[i];
515 for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
516 m_spectators_cache[i] = test_data.m_spectators[i];
518 probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
519 else
520 probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
521 }
522 return probabilities;
523
524 }
525
526 std::vector<std::vector<float>> TMVAExpertMulticlass::applyMulticlass(Dataset& test_data) const
527 {
528
529 std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
530
531 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
532 test_data.loadEvent(iEvent);
533 for (unsigned int i = 0; i < m_input_cache.size(); ++i)
534 m_input_cache[i] = test_data.m_input[i];
535 for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
536 m_spectators_cache[i] = test_data.m_spectators[i];
537 probabilities[iEvent] = m_expert->EvaluateMulticlass(specific_options.m_method);
538 }
539 return probabilities;
540 }
541
542 std::vector<float> TMVAExpertRegression::apply(Dataset& test_data) const
543 {
544
545 std::vector<float> prediction(test_data.getNumberOfEvents());
546 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
547 test_data.loadEvent(iEvent);
548 for (unsigned int i = 0; i < m_input_cache.size(); ++i)
549 m_input_cache[i] = test_data.m_input[i];
550 prediction[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
551 }
552 return prediction;
553
554 }
555
556 }
558}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:507
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:411
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:445
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:526
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:542
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:476
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:385
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
tmva method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:71
std::string m_type
tmva method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:43
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:203
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:207
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:294
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:299
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:306
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:310
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:119
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
static std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Definition: ScopeGuard.h:296
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24