Belle II Software  release-08-01-10
TMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/TMVA.h>
10 #include <framework/logging/Logger.h>
11 #include <framework/utilities/MakeROOTCompatible.h>
12 #include <framework/utilities/ScopeGuard.h>
13 
14 #include <TPluginManager.h>
15 
16 #include <boost/algorithm/string.hpp>
17 #include <filesystem>
18 #include <memory>
19 
20 namespace Belle2 {
25  namespace MVA {
26 
27  void TMVAOptions::load(const boost::property_tree::ptree& pt)
28  {
29  int version = pt.get<int>("TMVA_version");
30  if (version != 1) {
31  B2ERROR("Unknown weightfile version " << std::to_string(version));
32  throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
33  }
34  m_method = pt.get<std::string>("TMVA_method");
35  m_type = pt.get<std::string>("TMVA_type");
36  m_config = pt.get<std::string>("TMVA_config");
37  m_factoryOption = pt.get<std::string>("TMVA_factoryOption");
38  m_prepareOption = pt.get<std::string>("TMVA_prepareOption");
39  m_workingDirectory = pt.get<std::string>("TMVA_workingDirectory");
40  m_prefix = pt.get<std::string>("TMVA_prefix");
41  }
42 
43  void TMVAOptions::save(boost::property_tree::ptree& pt) const
44  {
45  pt.put("TMVA_version", 1);
46  pt.put("TMVA_method", m_method);
47  pt.put("TMVA_type", m_type);
48  pt.put("TMVA_config", m_config);
49  pt.put("TMVA_factoryOption", m_factoryOption);
50  pt.put("TMVA_prepareOption", m_prepareOption);
51  pt.put("TMVA_workingDirectory", m_workingDirectory);
52  pt.put("TMVA_prefix", m_prefix);
53  }
54 
55  po::options_description TMVAOptions::getDescription()
56  {
57  po::options_description description("TMVA options");
58  description.add_options()
59  ("tmva_method", po::value<std::string>(&m_method), "TMVA Method Name")
60  ("tmva_type", po::value<std::string>(&m_type), "TMVA Method Type (e.g. Plugin, BDT, ...)")
61  ("tmva_config", po::value<std::string>(&m_config), "TMVA Configuration string for the method")
62  ("tmva_working_directory", po::value<std::string>(&m_workingDirectory), "TMVA working directory which stores e.g. TMVA.root")
63  ("tmva_factory", po::value<std::string>(&m_factoryOption), "TMVA Factory options passed to TMVAFactory constructor")
64  ("tmva_prepare", po::value<std::string>(&m_prepareOption),
65  "TMVA Preprare options passed to prepareTrainingAndTestTree function");
66  return description;
67  }
68 
69  void TMVAOptionsClassification::load(const boost::property_tree::ptree& pt)
70  {
72  transform2probability = pt.get<bool>("TMVA_transform2probability");
73  }
74 
75  void TMVAOptionsClassification::save(boost::property_tree::ptree& pt) const
76  {
78  pt.put("TMVA_transform2probability", transform2probability);
79  }
80 
82  {
83  po::options_description description = TMVAOptions::getDescription();
84  description.add_options()
85  ("tmva_transform2probability", po::value<bool>(&transform2probability), "TMVA Transform output of classifier to a probability");
86  return description;
87  }
88 
89  void TMVAOptionsMulticlass::load(const boost::property_tree::ptree& pt)
90  {
92 
93  unsigned int numberOfClasses = pt.get<unsigned int>("TMVA_number_classes", 1);
94  m_classes.resize(numberOfClasses);
95  for (unsigned int i = 0; i < numberOfClasses; ++i) {
96  m_classes[i] = pt.get<std::string>(std::string("TMVA_classes") + std::to_string(i));
97  }
98  }
99 
100  void TMVAOptionsMulticlass::save(boost::property_tree::ptree& pt) const
101  {
102  TMVAOptions::save(pt);
103 
104  pt.put("TMVA_number_classes", m_classes.size());
105  for (unsigned int i = 0; i < m_classes.size(); ++i) {
106  pt.put(std::string("TMVA_classes") + std::to_string(i), m_classes[i]);
107  }
108  }
109 
110  po::options_description TMVAOptionsMulticlass::getDescription()
111  {
112  po::options_description description = TMVAOptions::getDescription();
113  description.add_options()
114  ("tmva_classes", po::value<std::vector<std::string>>(&m_classes)->required()->multitoken(),
115  "class name identifiers for multi-class mode");
116  return description;
117  }
118 
119  TMVATeacher::TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options) : Teacher(general_options),
120  specific_options(_specific_options) { }
121 
122  Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const
123  {
124  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
125 
126  if (specific_options.m_type == "Plugins") {
127  auto base = std::string("TMVA@@MethodBase");
128  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
129  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
130  auto className = std::string("TMVA::Method") + specific_options.m_method;
131  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
132  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
133  auto pluginName = std::string("TMVA") + specific_options.m_method;
134 
135  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
136  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
137  }
138 
139  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
140  B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
141  }
142 
143  Weightfile weightfile;
144  std::string logfilename = weightfile.generateFileName(".log");
145 
146  // Pipe stdout into a logfile to get TMVA output, which contains valuable information
147  // which cannot be retrieved otherwise!
148  // Hence we do some black magic here
149  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
150  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
151  auto saved_stdout = dup(STDOUT_FILENO);
152  dup2(logfile, 1);
153 
154  factory.TrainAllMethods();
155  factory.TestAllMethods();
156  factory.EvaluateAllMethods();
157 
158  // Reset original output
159  dup2(saved_stdout, STDOUT_FILENO);
160  close(saved_stdout);
161  close(logfile);
162 
163 
164  weightfile.addOptions(m_general_options);
165  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
166  weightfile.addFile("TMVA_Logfile", logfilename);
167 
168  // We have to parse the TMVA output to get the feature importances, there is no other way currently
169  std::string begin = "Ranking input variables (method specific)";
170  std::string end = "-----------------------------------";
171  std::string line;
172  std::ifstream file(logfilename, std::ios::in);
173  std::map<std::string, float> feature_importances;
174  int state = 0;
175  while (std::getline(file, line)) {
176  if (state == 0 && line.find(begin) != std::string::npos) {
177  state = 1;
178  continue;
179  }
180  if (state >= 1 and state <= 4) {
181  state++;
182  continue;
183  }
184  if (state == 5) {
185  if (line.find(end) != std::string::npos)
186  break;
187  std::vector<std::string> strs;
188  boost::split(strs, line, boost::is_any_of(":"));
189  std::string variable = strs[2];
190  boost::trim(variable);
192  float importance = std::stof(strs[3]);
193  feature_importances[variable] = importance;
194  }
195  }
196  weightfile.addFeatureImportance(feature_importances);
197 
198  return weightfile;
199 
200  }
201 
202 
204  const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
205  specific_options(_specific_options) { }
206 
208  {
209 
210  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212  unsigned int numberOfEvents = training_data.getNumberOfEvents();
213 
214  std::string directory = specific_options.m_workingDirectory;
215  if (specific_options.m_workingDirectory.empty()) {
216  char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
217  directory = mkdtemp(directory_template);
218  free(directory_template);
219  }
220 
221  // cppcheck-suppress unreadVariable
222  auto guard = ScopeGuard::guardWorkingDirectory(directory);
223 
224  std::string jobName = specific_options.m_prefix;
225  if (jobName.empty())
226  jobName = "TMVA";
227  TFile classFile((jobName + ".root").c_str(), "RECREATE");
228  classFile.cd();
229 
230  TMVA::Tools::Instance();
231  TMVA::DataLoader data_loader(jobName);
232  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
233 
234 
235  // Add variables to the factory
236  for (auto& var : m_general_options.m_variables) {
237  data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
238  }
239 
240  // Add variables to the factory
241  for (auto& var : m_general_options.m_spectators) {
242  data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
243  }
244 
246 
247  auto* signal_tree = new TTree("signal_tree", "signal_tree");
248  auto* background_tree = new TTree("background_tree", "background_tree");
249 
250  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
252  &training_data.m_input[iFeature]);
253  background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
254  &training_data.m_input[iFeature]);
255  }
256 
257  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
258  signal_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
259  &training_data.m_spectators[iSpectator]);
260  background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
261  &training_data.m_spectators[iSpectator]);
262  }
263 
264  signal_tree->Branch("__weight__", &training_data.m_weight);
265  background_tree->Branch("__weight__", &training_data.m_weight);
266 
267  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
268  training_data.loadEvent(iEvent);
269  if (training_data.m_isSignal) {
270  signal_tree->Fill();
271  } else {
272  background_tree->Fill();
273  }
274  }
275 
276  data_loader.AddSignalTree(signal_tree);
277  data_loader.AddBackgroundTree(background_tree);
278  auto weightfile = trainFactory(factory, data_loader, jobName);
279 
280  weightfile.addOptions(specific_options);
281  weightfile.addSignalFraction(training_data.getSignalFraction());
282 
283  delete signal_tree;
284  delete background_tree;
285 
286  if (specific_options.m_workingDirectory.empty()) {
287  std::filesystem::remove_all(directory);
288  }
289 
290  return weightfile;
291 
292  }
293 
295  const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
296  specific_options(_specific_options) { }
297 
298  // Implement me!
300  {
301  B2ERROR("Training TMVAMulticlass classifiers within the MVA package has not been implemented yet.");
302  (void) training_data;
303  return Weightfile();
304  }
305 
307  const TMVAOptionsRegression& _specific_options) : TMVATeacher(general_options, _specific_options),
308  specific_options(_specific_options) { }
309 
311  {
312 
313  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
314  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
315  unsigned int numberOfEvents = training_data.getNumberOfEvents();
316 
317  std::string directory = specific_options.m_workingDirectory;
318  if (specific_options.m_workingDirectory.empty()) {
319  char* directory_template = strdup((std::filesystem::temp_directory_path() / "Basf2TMVA.XXXXXX").c_str());
320  directory = mkdtemp(directory_template);
321  free(directory_template);
322  }
323 
324  // cppcheck-suppress unreadVariable
325  auto guard = ScopeGuard::guardWorkingDirectory(directory);
326 
327  std::string jobName = specific_options.m_prefix;
328  if (jobName.empty())
329  jobName = "TMVA";
330  TFile classFile((jobName + ".root").c_str(), "RECREATE");
331  classFile.cd();
332 
333  TMVA::Tools::Instance();
334  TMVA::DataLoader data_loader(jobName);
335  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
336 
337  // Add variables to the factory
338  for (auto& var : m_general_options.m_variables) {
339  data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
340  }
341 
342  // Add variables to the factory
343  for (auto& var : m_general_options.m_spectators) {
344  data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
345  }
346 
348 
349  auto* regression_tree = new TTree("regression_tree", "regression_tree");
350 
351  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
352  regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
353  &training_data.m_input[iFeature]);
354  }
355  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
356  regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
357  &training_data.m_spectators[iSpectator]);
358  }
360  &training_data.m_target);
361 
362  regression_tree->Branch("__weight__", &training_data.m_weight);
363 
364  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
365  training_data.loadEvent(iEvent);
366  regression_tree->Fill();
367  }
368 
369  data_loader.AddRegressionTree(regression_tree);
370  data_loader.SetWeightExpression(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
371 
372  auto weightfile = trainFactory(factory, data_loader, jobName);
373  weightfile.addOptions(specific_options);
374 
375  delete regression_tree;
376 
377  if (specific_options.m_workingDirectory.empty()) {
378  std::filesystem::remove_all(directory);
379  }
380 
381  return weightfile;
382 
383  }
384 
385  void TMVAExpert::load(Weightfile& weightfile)
386  {
387 
388  // Initialize TMVA and ROOT stuff
389  TMVA::Tools::Instance();
390 
391  m_expert = std::make_unique<TMVA::Reader>("!Color:Silent");
392 
393  GeneralOptions general_options;
394  weightfile.getOptions(general_options);
395  m_input_cache.resize(general_options.m_variables.size(), 0);
396  for (unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
397  m_expert->AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_variables[i]), &m_input_cache[i]);
398  }
399  m_spectators_cache.resize(general_options.m_spectators.size(), 0);
400  for (unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
401  m_expert->AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_spectators[i]), &m_spectators_cache[i]);
402  }
403 
404  if (weightfile.containsElement("TMVA_Logfile")) {
405  std::string custom_weightfile = weightfile.generateFileName("logfile");
406  weightfile.getFile("TMVA_Logfile", custom_weightfile);
407  }
408 
409  }
410 
412  {
413 
414  weightfile.getOptions(specific_options);
417  }
418 
419  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
420  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
421  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
422 
423  TMVAExpert::load(weightfile);
424 
425  if (specific_options.m_type == "Plugins") {
426  auto base = std::string("TMVA@@MethodBase");
427  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
428  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
429  auto className = std::string("TMVA::Method") + specific_options.m_method;
430  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
431  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
432  auto pluginName = std::string("TMVA") + specific_options.m_method;
433 
434  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
435  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
436  B2INFO("Registered new TMVA Plugin named " << pluginName);
437  }
438 
439  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
440  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
441  }
442 
443  }
444 
446  {
447 
448  weightfile.getOptions(specific_options);
449 
450  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
451  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
452  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
453 
454  TMVAExpert::load(weightfile);
455 
456  if (specific_options.m_type == "Plugins") {
457  auto base = std::string("TMVA@@MethodBase");
458  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
459  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
460  auto className = std::string("TMVA::Method") + specific_options.m_method;
461  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
462  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
463  auto pluginName = std::string("TMVA") + specific_options.m_method;
464 
465  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
466  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
467  B2INFO("Registered new TMVA Plugin named " << pluginName);
468  }
469 
470  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
471  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
472  }
473 
474  }
475 
477  {
478 
479  weightfile.getOptions(specific_options);
480 
481  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
482  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
483  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
484 
485  TMVAExpert::load(weightfile);
486 
487  if (specific_options.m_type == "Plugins") {
488  auto base = std::string("TMVA@@MethodBase");
489  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
490  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
491  auto className = std::string("TMVA::Method") + specific_options.m_method;
492  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
493  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
494  auto pluginName = std::string("TMVA") + specific_options.m_method;
495 
496  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
497  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
498  B2INFO("Registered new TMVA Plugin named " << pluginName);
499  }
500 
501  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
502  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
503  }
504 
505  }
506 
507  std::vector<float> TMVAExpertClassification::apply(Dataset& test_data) const
508  {
509 
510  std::vector<float> probabilities(test_data.getNumberOfEvents());
511  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
512  test_data.loadEvent(iEvent);
513  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
514  m_input_cache[i] = test_data.m_input[i];
515  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
516  m_spectators_cache[i] = test_data.m_spectators[i];
518  probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
519  else
520  probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
521  }
522  return probabilities;
523 
524  }
525 
526  std::vector<std::vector<float>> TMVAExpertMulticlass::applyMulticlass(Dataset& test_data) const
527  {
528 
529  std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
530 
531  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
532  test_data.loadEvent(iEvent);
533  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
534  m_input_cache[i] = test_data.m_input[i];
535  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
536  m_spectators_cache[i] = test_data.m_spectators[i];
537  probabilities[iEvent] = m_expert->EvaluateMulticlass(specific_options.m_method);
538  }
539  return probabilities;
540  }
541 
542  std::vector<float> TMVAExpertRegression::apply(Dataset& test_data) const
543  {
544 
545  std::vector<float> prediction(test_data.getNumberOfEvents());
546  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
547  test_data.loadEvent(iEvent);
548  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
549  m_input_cache[i] = test_data.m_input[i];
550  prediction[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
551  }
552  return prediction;
553 
554  }
555 
556  }
558 }
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:91
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:90
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:507
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:411
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:445
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:526
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:542
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:476
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:385
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
tmva method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:71
std::string m_type
tmva method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:43
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:203
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:207
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:294
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:299
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:306
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:310
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:119
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:72
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:100
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
static std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Definition: ScopeGuard.h:296
Abstract base class for different kinds of events.