Belle II Software  release-06-02-00
TMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
#include <mva/methods/TMVA.h>
#include <framework/logging/Logger.h>
#include <framework/utilities/MakeROOTCompatible.h>
#include <framework/utilities/ScopeGuard.h>

#include <TPluginManager.h>

#include <boost/algorithm/string.hpp>
#include <boost/filesystem/operations.hpp>

#include <fcntl.h>
#include <unistd.h>

#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
19 
20 namespace Belle2 {
25  namespace MVA {
26 
27  void TMVAOptions::load(const boost::property_tree::ptree& pt)
28  {
29  int version = pt.get<int>("TMVA_version");
30  if (version != 1) {
31  B2ERROR("Unknown weightfile version " << std::to_string(version));
32  throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
33  }
34  m_method = pt.get<std::string>("TMVA_method");
35  m_type = pt.get<std::string>("TMVA_type");
36  m_config = pt.get<std::string>("TMVA_config");
37  m_factoryOption = pt.get<std::string>("TMVA_factoryOption");
38  m_prepareOption = pt.get<std::string>("TMVA_prepareOption");
39  m_workingDirectory = pt.get<std::string>("TMVA_workingDirectory");
40  m_prefix = pt.get<std::string>("TMVA_prefix");
41  }
42 
43  void TMVAOptions::save(boost::property_tree::ptree& pt) const
44  {
45  pt.put("TMVA_version", 1);
46  pt.put("TMVA_method", m_method);
47  pt.put("TMVA_type", m_type);
48  pt.put("TMVA_config", m_config);
49  pt.put("TMVA_factoryOption", m_factoryOption);
50  pt.put("TMVA_prepareOption", m_prepareOption);
51  pt.put("TMVA_workingDirectory", m_workingDirectory);
52  pt.put("TMVA_prefix", m_prefix);
53  }
54 
55  po::options_description TMVAOptions::getDescription()
56  {
57  po::options_description description("TMVA options");
58  description.add_options()
59  ("tmva_method", po::value<std::string>(&m_method), "TMVA Method Name")
60  ("tmva_type", po::value<std::string>(&m_type), "TMVA Method Type (e.g. Plugin, BDT, ...)")
61  ("tmva_config", po::value<std::string>(&m_config), "TMVA Configuration string for the method")
62  ("tmva_working_directory", po::value<std::string>(&m_workingDirectory), "TMVA working directory which stores e.g. TMVA.root")
63  ("tmva_factory", po::value<std::string>(&m_factoryOption), "TMVA Factory options passed to TMVAFactory constructor")
64  ("tmva_prepare", po::value<std::string>(&m_prepareOption),
65  "TMVA Preprare options passed to prepareTrainingAndTestTree function");
66  return description;
67  }
68 
// Reads the classification specific option "TMVA_transform2probability" from the property tree.
// NOTE(review): the listing's embedded numbering jumps from 70 to 72, so one source line is not
// shown here — presumably the call TMVAOptions::load(pt) for the common TMVA options; confirm
// against the original TMVA.cc.
69  void TMVAOptionsClassification::load(const boost::property_tree::ptree& pt)
70  {
72  transform2probability = pt.get<bool>("TMVA_transform2probability");
73  }
74 
// Writes the classification specific option "TMVA_transform2probability" into the property tree.
// NOTE(review): the embedded numbering jumps from 76 to 78; the missing line is presumably the
// call TMVAOptions::save(pt) for the common TMVA options — confirm against the original TMVA.cc.
75  void TMVAOptionsClassification::save(boost::property_tree::ptree& pt) const
76  {
78  pt.put("TMVA_transform2probability", transform2probability);
79  }
80 
82  {
// Extends the common TMVA program options by "tmva_transform2probability", which is written
// directly into transform2probability.
// NOTE(review): the signature line (81) is not shown in this listing; the tooltip index at the
// end lists TMVAOptionsClassification::getDescription() as defined at TMVA.cc:81.
83  po::options_description description = TMVAOptions::getDescription();
84  description.add_options()
85  ("tmva_transform2probability", po::value<bool>(&transform2probability), "TMVA Transform output of classifier to a probability");
86  return description;
87  }
88 
// Reads the multiclass specific options: "TMVA_number_classes" (defaulting to 1 if absent)
// followed by one "TMVA_classes<i>" entry per class name.
// NOTE(review): the embedded numbering jumps from 90 to 92; the missing line is presumably the
// call TMVAOptions::load(pt) for the common TMVA options — confirm against the original TMVA.cc.
89  void TMVAOptionsMulticlass::load(const boost::property_tree::ptree& pt)
90  {
92 
93  unsigned int numberOfClasses = pt.get<unsigned int>("TMVA_number_classes", 1);
94  m_classes.resize(numberOfClasses);
95  for (unsigned int i = 0; i < numberOfClasses; ++i) {
96  m_classes[i] = pt.get<std::string>(std::string("TMVA_classes") + std::to_string(i));
97  }
98  }
99 
100  void TMVAOptionsMulticlass::save(boost::property_tree::ptree& pt) const
101  {
102  TMVAOptions::save(pt);
103 
104  pt.put("TMVA_number_classes", m_classes.size());
105  for (unsigned int i = 0; i < m_classes.size(); ++i) {
106  pt.put(std::string("TMVA_classes") + std::to_string(i), m_classes[i]);
107  }
108  }
109 
110  po::options_description TMVAOptionsMulticlass::getDescription()
111  {
112  po::options_description description = TMVAOptions::getDescription();
113  description.add_options()
114  ("tmva_classes", po::value<std::vector<std::string>>(&m_classes)->required()->multitoken(),
115  "class name identifiers for multi-class mode");
116  return description;
117  }
118 
119  TMVATeacher::TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options) : Teacher(general_options),
120  specific_options(_specific_options) { }
121 
122  Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const
123  {
124  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
125 
126  if (specific_options.m_type == "Plugins") {
127  auto base = std::string("TMVA@@MethodBase");
128  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
129  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
130  auto className = std::string("TMVA::Method") + specific_options.m_method;
131  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
132  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
133  auto pluginName = std::string("TMVA") + specific_options.m_method;
134 
135  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
136  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
137  }
138 
139  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
140  B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
141  }
142 
143  Weightfile weightfile;
144  std::string logfilename = weightfile.generateFileName(".log");
145 
146  // Pipe stdout into a logfile to get TMVA output, which contains valuable information
147  // which cannot be retrieved otherwise!
148  // Hence we do some black magic here
149  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
150  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
151  auto saved_stdout = dup(STDOUT_FILENO);
152  dup2(logfile, 1);
153 
154  factory.TrainAllMethods();
155  factory.TestAllMethods();
156  factory.EvaluateAllMethods();
157 
158  // Reset original output
159  dup2(saved_stdout, STDOUT_FILENO);
160  close(saved_stdout);
161  close(logfile);
162 
163 
164  weightfile.addOptions(m_general_options);
165  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
166  weightfile.addFile("TMVA_Logfile", logfilename);
167 
168  // We have to parse the TMVA output to get the feature importances, there is no other way currently
169  std::string begin = "Ranking input variables (method specific)";
170  std::string end = "-----------------------------------";
171  std::string line;
172  std::ifstream file(logfilename, std::ios::in);
173  std::map<std::string, float> feature_importances;
174  int state = 0;
175  while (std::getline(file, line)) {
176  if (state == 0 && line.find(begin) != std::string::npos) {
177  state = 1;
178  continue;
179  }
180  if (state >= 1 and state <= 4) {
181  state++;
182  continue;
183  }
184  if (state == 5) {
185  if (line.find(end) != std::string::npos)
186  break;
187  std::vector<std::string> strs;
188  boost::split(strs, line, boost::is_any_of(":"));
189  std::string variable = strs[2];
190  boost::trim(variable);
191  variable = Belle2::invertMakeROOTCompatible(variable);
192  float importance = std::stof(strs[3]);
193  feature_importances[variable] = importance;
194  }
195  }
196  weightfile.addFeatureImportance(feature_importances);
197 
198  return weightfile;
199 
200  }
201 
202 
204  const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
205  specific_options(_specific_options) { }
// Constructor of TMVATeacherClassification: forwards both option sets to TMVATeacher and keeps its
// own copy of the classification options.
// NOTE(review): the first signature line (203) is not shown in this listing; the tooltip index at
// the end lists this constructor at TMVA.cc:203.
206 
208  {
// Trains a TMVA classification. The dataset is split into a signal tree and a background tree
// according to m_isSignal, and both are handed to trainFactory together with the variables,
// spectators and weight expression.
// NOTE(review): the signature line (207) is not shown in this listing; the tooltip index at the
// end lists TMVATeacherClassification::train(Dataset&) at TMVA.cc:207.
209 
210  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212  unsigned int numberOfEvents = training_data.getNumberOfEvents();
213 
214  std::string directory = specific_options.m_workingDirectory;
// If no working directory is configured, a fresh temporary directory is used; it is removed again
// at the end of this function.
// NOTE(review): neither strdup nor mkdtemp is checked for failure; mkdtemp returns nullptr on
// error, and assigning nullptr to std::string is undefined behaviour.
215  if (specific_options.m_workingDirectory.empty()) {
216  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
217  directory = mkdtemp(directory_template);
218  free(directory_template);
219  }
220 
// NOTE(review): presumably switches into `directory` and restores the previous working directory
// when `guard` goes out of scope — confirm in ScopeGuard.h.
221  auto guard = ScopeGuard::guardWorkingDirectory(directory);
222 
223  std::string jobName = specific_options.m_prefix;
224  if (jobName.empty())
225  jobName = "TMVA";
226  TFile classFile((jobName + ".root").c_str(), "RECREATE");
227  classFile.cd();
228 
229  TMVA::Tools::Instance();
230  TMVA::DataLoader data_loader(jobName);
231  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
232 
233 
234  // Add variables to the data loader
235  for (auto& var : m_general_options.m_variables) {
236  data_loader.AddVariable(Belle2::makeROOTCompatible(var));
237  }
238 
239  // Add spectators to the data loader
240  for (auto& var : m_general_options.m_spectators) {
241  data_loader.AddSpectator(Belle2::makeROOTCompatible(var));
242  }
243 
244  data_loader.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable));
245 
// Both trees branch onto the same per-event buffers of training_data; each event is filled into
// exactly one of them below.
246  auto* signal_tree = new TTree("signal_tree", "signal_tree");
247  auto* background_tree = new TTree("background_tree", "background_tree");
248 
249  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
250  signal_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
251  &training_data.m_input[iFeature]);
252  background_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
253  &training_data.m_input[iFeature]);
254  }
255 
256  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
257  signal_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
258  &training_data.m_spectators[iSpectator]);
259  background_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
260  &training_data.m_spectators[iSpectator]);
261  }
262 
263  signal_tree->Branch("__weight__", &training_data.m_weight);
264  background_tree->Branch("__weight__", &training_data.m_weight);
265 
266  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
267  training_data.loadEvent(iEvent);
268  if (training_data.m_isSignal) {
269  signal_tree->Fill();
270  } else {
271  background_tree->Fill();
272  }
273  }
274 
275  data_loader.AddSignalTree(signal_tree);
276  data_loader.AddBackgroundTree(background_tree);
277  auto weightfile = trainFactory(factory, data_loader, jobName);
278 
279  weightfile.addOptions(specific_options);
280  weightfile.addSignalFraction(training_data.getSignalFraction());
281 
// NOTE(review): the trees are owned manually; if trainFactory throws, they are not deleted and
// the temporary directory is not removed. Consider std::unique_ptr or a scope guard.
282  delete signal_tree;
283  delete background_tree;
284 
285  if (specific_options.m_workingDirectory.empty()) {
286  boost::filesystem::remove_all(directory);
287  }
288 
289  return weightfile;
290 
291  }
292 
294  const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
295  specific_options(_specific_options) { }
// Constructor of TMVATeacherMulticlass: forwards both option sets to TMVATeacher and keeps its own
// copy of the multiclass options.
// NOTE(review): the first signature line (293) is not shown in this listing; the tooltip index at
// the end lists this constructor at TMVA.cc:293.
296 
297  // Implement me!
299  {
// NOTE(review): multiclass training with TMVA is not implemented. This stub ignores the dataset
// and returns an empty Weightfile. The signature line (298) is not shown in this listing; the
// tooltip index at the end lists TMVATeacherMulticlass::train(Dataset&) at TMVA.cc:298.
300  (void) training_data;
301  return Weightfile();
302  }
303 
305  const TMVAOptionsRegression& _specific_options) : TMVATeacher(general_options, _specific_options),
306  specific_options(_specific_options) { }
// Constructor of TMVATeacherRegression: forwards both option sets to TMVATeacher and keeps its own
// copy of the regression options.
// NOTE(review): the first signature line (304) is not shown in this listing; the tooltip index at
// the end lists this constructor at TMVA.cc:304.
307 
309  {
// Trains a TMVA regression. All events go into a single regression tree (inputs, spectators,
// target and weight), which is handed to trainFactory.
// NOTE(review): the signature line (308) is not shown in this listing; the tooltip index at the
// end lists TMVATeacherRegression::train(Dataset&) at TMVA.cc:308.
310 
311  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
312  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
313  unsigned int numberOfEvents = training_data.getNumberOfEvents();
314 
315  std::string directory = specific_options.m_workingDirectory;
// If no working directory is configured, a fresh temporary directory is used; it is removed again
// at the end of this function.
// NOTE(review): neither strdup nor mkdtemp is checked for failure; mkdtemp returns nullptr on
// error, and assigning nullptr to std::string is undefined behaviour.
316  if (specific_options.m_workingDirectory.empty()) {
317  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
318  directory = mkdtemp(directory_template);
319  free(directory_template);
320  }
321 
// NOTE(review): presumably switches into `directory` and restores the previous working directory
// when `guard` goes out of scope — confirm in ScopeGuard.h.
322  auto guard = ScopeGuard::guardWorkingDirectory(directory);
323 
324  std::string jobName = specific_options.m_prefix;
325  if (jobName.empty())
326  jobName = "TMVA";
327  TFile classFile((jobName + ".root").c_str(), "RECREATE");
328  classFile.cd();
329 
330  TMVA::Tools::Instance();
331  TMVA::DataLoader data_loader(jobName);
332  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
333 
334  // Add variables to the data loader
335  for (auto& var : m_general_options.m_variables) {
336  data_loader.AddVariable(Belle2::makeROOTCompatible(var));
337  }
338 
339  // Add spectators to the data loader
340  for (auto& var : m_general_options.m_spectators) {
341  data_loader.AddSpectator(Belle2::makeROOTCompatible(var));
342  }
343 
// NOTE(review): the embedded numbering jumps from 343 to 345. The missing line presumably
// registers the target variable with the data loader (its branch is created below) — confirm.
345 
346  auto* regression_tree = new TTree("regression_tree", "regression_tree");
347 
348  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
349  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
350  &training_data.m_input[iFeature]);
351  }
352  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
353  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
354  &training_data.m_spectators[iSpectator]);
355  }
356  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_target_variable).c_str(),
357  &training_data.m_target);
358 
359  regression_tree->Branch("__weight__", &training_data.m_weight);
360 
361  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
362  training_data.loadEvent(iEvent);
363  regression_tree->Fill();
364  }
365 
366  data_loader.AddRegressionTree(regression_tree);
367  data_loader.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
368 
369  auto weightfile = trainFactory(factory, data_loader, jobName);
370  weightfile.addOptions(specific_options);
371 
// NOTE(review): the tree is owned manually; if trainFactory throws, it is not deleted and the
// temporary directory is not removed. Consider std::unique_ptr or a scope guard.
372  delete regression_tree;
373 
374  if (specific_options.m_workingDirectory.empty()) {
375  boost::filesystem::remove_all(directory);
376  }
377 
378  return weightfile;
379 
380  }
381 
382  void TMVAExpert::load(Weightfile& weightfile)
383  {
384 
385  // Initialize TMVA and ROOT stuff
386  TMVA::Tools::Instance();
387 
388  m_expert = std::make_unique<TMVA::Reader>("!Color:!Silent");
389 
390  GeneralOptions general_options;
391  weightfile.getOptions(general_options);
392  m_input_cache.resize(general_options.m_variables.size(), 0);
393  for (unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
394  m_expert->AddVariable(Belle2::makeROOTCompatible(general_options.m_variables[i]), &m_input_cache[i]);
395  }
396  m_spectators_cache.resize(general_options.m_spectators.size(), 0);
397  for (unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
398  m_expert->AddSpectator(Belle2::makeROOTCompatible(general_options.m_spectators[i]), &m_spectators_cache[i]);
399  }
400 
401  if (weightfile.containsElement("TMVA_Logfile")) {
402  std::string custom_weightfile = weightfile.generateFileName("logfile");
403  weightfile.getFile("TMVA_Logfile", custom_weightfile);
404  }
405 
406  }
407 
409  {
// Loads a TMVA classification expert. Steps: restore the options, extract the TMVA weight file,
// set up the reader via TMVAExpert::load, register plugin handlers if needed, and book the method.
// NOTE(review): the signature line (408) and lines 412-413 are not shown in this listing. The stray
// closing brace (414) suggests a conditional block — presumably reading the signal fraction into
// expert_signalFraction when transform2probability is set. Confirm against TMVA.cc.
410 
411  weightfile.getOptions(specific_options);
414  }
415 
416  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
417  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
418  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
419 
420  TMVAExpert::load(weightfile);
421 
// Register the same plugin handlers as TMVATeacher::trainFactory, so that ROOT can find
// TMVA::Method<m_method> for "Plugins" methods
422  if (specific_options.m_type == "Plugins") {
423  auto base = std::string("TMVA@@MethodBase");
424  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
425  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
426  auto className = std::string("TMVA::Method") + specific_options.m_method;
427  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
428  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
429  auto pluginName = std::string("TMVA") + specific_options.m_method;
430 
431  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
432  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
433  B2INFO("Registered new TMVA Plugin named " << pluginName);
434  }
435 
// A failure to book the method is fatal; TMVA reports the reason itself
436  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
437  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
438  }
439 
440  }
441 
443  {
// Loads a TMVA multiclass expert. Steps: restore the options, extract the TMVA weight file,
// set up the reader via TMVAExpert::load, register plugin handlers if needed, and book the method.
// NOTE(review): the signature line (442) is not shown in this listing; the tooltip index at the
// end lists TMVAExpertMulticlass::load(Weightfile&) at TMVA.cc:442.
444 
445  weightfile.getOptions(specific_options);
446 
447  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
448  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
449  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
450 
451  TMVAExpert::load(weightfile);
452 
// Register the same plugin handlers as TMVATeacher::trainFactory, so that ROOT can find
// TMVA::Method<m_method> for "Plugins" methods
453  if (specific_options.m_type == "Plugins") {
454  auto base = std::string("TMVA@@MethodBase");
455  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
456  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
457  auto className = std::string("TMVA::Method") + specific_options.m_method;
458  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
459  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
460  auto pluginName = std::string("TMVA") + specific_options.m_method;
461 
462  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
463  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
464  B2INFO("Registered new TMVA Plugin named " << pluginName);
465  }
466 
// A failure to book the method is fatal; TMVA reports the reason itself
467  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
468  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
469  }
470 
471  }
472 
474  {
// Loads a TMVA regression expert. Steps: restore the options, extract the TMVA weight file,
// set up the reader via TMVAExpert::load, register plugin handlers if needed, and book the method.
// NOTE(review): the signature line (473) is not shown in this listing; the tooltip index at the
// end lists TMVAExpertRegression::load(Weightfile&) at TMVA.cc:473.
475 
476  weightfile.getOptions(specific_options);
477 
478  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
479  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
480  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
481 
482  TMVAExpert::load(weightfile);
483 
// Register the same plugin handlers as TMVATeacher::trainFactory, so that ROOT can find
// TMVA::Method<m_method> for "Plugins" methods
484  if (specific_options.m_type == "Plugins") {
485  auto base = std::string("TMVA@@MethodBase");
486  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
487  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
488  auto className = std::string("TMVA::Method") + specific_options.m_method;
489  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
490  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
491  auto pluginName = std::string("TMVA") + specific_options.m_method;
492 
493  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
494  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
495  B2INFO("Registered new TMVA Plugin named " << pluginName);
496  }
497 
// A failure to book the method is fatal; TMVA reports the reason itself
498  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
499  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
500  }
501 
502  }
503 
// Evaluates the classification expert on every event of test_data. The inputs and spectators of
// each event are copied into the reader's buffers first.
// NOTE(review): the embedded numbering jumps from 513 to 515. The missing line is presumably
// `if (specific_options.transform2probability)`, which selects GetProba (a probability using
// expert_signalFraction) over the raw EvaluateMVA output. Confirm against TMVA.cc.
504  std::vector<float> TMVAExpertClassification::apply(Dataset& test_data) const
505  {
506 
507  std::vector<float> probabilities(test_data.getNumberOfEvents());
508  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
509  test_data.loadEvent(iEvent);
510  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
511  m_input_cache[i] = test_data.m_input[i];
512  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
513  m_spectators_cache[i] = test_data.m_spectators[i];
515  probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
516  else
517  probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
518  }
519  return probabilities;
520 
521  }
522 
523  std::vector<std::vector<float>> TMVAExpertMulticlass::applyMulticlass(Dataset& test_data) const
524  {
525 
526  std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
527 
528  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
529  test_data.loadEvent(iEvent);
530  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
531  m_input_cache[i] = test_data.m_input[i];
532  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
533  m_spectators_cache[i] = test_data.m_spectators[i];
534  probabilities[iEvent] = m_expert->EvaluateMulticlass(specific_options.m_method);
535  }
536  return probabilities;
537  }
538 
539  std::vector<float> TMVAExpertRegression::apply(Dataset& test_data) const
540  {
541 
542  std::vector<float> prediction(test_data.getNumberOfEvents());
543  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
544  test_data.loadEvent(iEvent);
545  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
546  m_input_cache[i] = test_data.m_input[i];
547  prediction[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
548  }
549  return prediction;
550 
551  }
552 
553  }
555 }
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:31
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:90
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:89
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:504
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:408
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:442
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:523
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:539
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:473
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:382
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
tmva method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:71
std::string m_type
tmva method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:43
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:203
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:207
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:293
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:298
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:304
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:308
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:119
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:114
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:61
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:71
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:99
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:104
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:137
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Definition: ScopeGuard.h:296
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
Abstract base class for different kinds of events.