Belle II Software  light-2205-abys
TMVA.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/TMVA.h>
10 #include <framework/logging/Logger.h>
11 #include <framework/utilities/MakeROOTCompatible.h>
12 #include <framework/utilities/ScopeGuard.h>
13 
14 #include <TPluginManager.h>
15 
16 #include <boost/algorithm/string.hpp>
17 #include <boost/filesystem/operations.hpp>
18 #include <memory>
19 
20 namespace Belle2 {
25  namespace MVA {
26 
27  void TMVAOptions::load(const boost::property_tree::ptree& pt)
28  {
29  int version = pt.get<int>("TMVA_version");
30  if (version != 1) {
31  B2ERROR("Unknown weightfile version " << std::to_string(version));
32  throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
33  }
34  m_method = pt.get<std::string>("TMVA_method");
35  m_type = pt.get<std::string>("TMVA_type");
36  m_config = pt.get<std::string>("TMVA_config");
37  m_factoryOption = pt.get<std::string>("TMVA_factoryOption");
38  m_prepareOption = pt.get<std::string>("TMVA_prepareOption");
39  m_workingDirectory = pt.get<std::string>("TMVA_workingDirectory");
40  m_prefix = pt.get<std::string>("TMVA_prefix");
41  }
42 
43  void TMVAOptions::save(boost::property_tree::ptree& pt) const
44  {
45  pt.put("TMVA_version", 1);
46  pt.put("TMVA_method", m_method);
47  pt.put("TMVA_type", m_type);
48  pt.put("TMVA_config", m_config);
49  pt.put("TMVA_factoryOption", m_factoryOption);
50  pt.put("TMVA_prepareOption", m_prepareOption);
51  pt.put("TMVA_workingDirectory", m_workingDirectory);
52  pt.put("TMVA_prefix", m_prefix);
53  }
54 
55  po::options_description TMVAOptions::getDescription()
56  {
57  po::options_description description("TMVA options");
58  description.add_options()
59  ("tmva_method", po::value<std::string>(&m_method), "TMVA Method Name")
60  ("tmva_type", po::value<std::string>(&m_type), "TMVA Method Type (e.g. Plugin, BDT, ...)")
61  ("tmva_config", po::value<std::string>(&m_config), "TMVA Configuration string for the method")
62  ("tmva_working_directory", po::value<std::string>(&m_workingDirectory), "TMVA working directory which stores e.g. TMVA.root")
63  ("tmva_factory", po::value<std::string>(&m_factoryOption), "TMVA Factory options passed to TMVAFactory constructor")
64  ("tmva_prepare", po::value<std::string>(&m_prepareOption),
65  "TMVA Preprare options passed to prepareTrainingAndTestTree function");
66  return description;
67  }
68 
// Load the classification-specific option "TMVA_transform2probability" into transform2probability.
// NOTE(review): listing line 71 is missing from this view; by analogy with
// TMVAOptionsMulticlass::save (which calls TMVAOptions::save first) it presumably calls
// TMVAOptions::load(pt) to restore the shared TMVA keys — confirm.
69  void TMVAOptionsClassification::load(const boost::property_tree::ptree& pt)
70  {
72  transform2probability = pt.get<bool>("TMVA_transform2probability");
73  }
74 
// Store transform2probability under the key "TMVA_transform2probability".
// NOTE(review): listing line 77 is missing from this view; by analogy with
// TMVAOptionsMulticlass::save it presumably calls TMVAOptions::save(pt) first — confirm.
75  void TMVAOptionsClassification::save(boost::property_tree::ptree& pt) const
76  {
78  pt.put("TMVA_transform2probability", transform2probability);
79  }
80 
// Body of TMVAOptionsClassification::getDescription (its signature, listing line 81, is missing
// from this view): extends the generic TMVA options with the boolean option
// "tmva_transform2probability", which writes into transform2probability.
82  {
83  po::options_description description = TMVAOptions::getDescription();
84  description.add_options()
85  ("tmva_transform2probability", po::value<bool>(&transform2probability), "TMVA Transform output of classifier to a probability");
86  return description;
87  }
88 
// Read the multiclass class list: "TMVA_number_classes" (defaults to 1 when absent), then one
// "TMVA_classes<i>" string per class, each read without a default value.
// NOTE(review): listing line 91 is missing from this view; by analogy with save() it
// presumably calls TMVAOptions::load(pt) for the shared TMVA keys — confirm.
89  void TMVAOptionsMulticlass::load(const boost::property_tree::ptree& pt)
90  {
92 
93  unsigned int numberOfClasses = pt.get<unsigned int>("TMVA_number_classes", 1);
94  m_classes.resize(numberOfClasses);
95  for (unsigned int i = 0; i < numberOfClasses; ++i) {
96  m_classes[i] = pt.get<std::string>(std::string("TMVA_classes") + std::to_string(i));
97  }
98  }
99 
100  void TMVAOptionsMulticlass::save(boost::property_tree::ptree& pt) const
101  {
102  TMVAOptions::save(pt);
103 
104  pt.put("TMVA_number_classes", m_classes.size());
105  for (unsigned int i = 0; i < m_classes.size(); ++i) {
106  pt.put(std::string("TMVA_classes") + std::to_string(i), m_classes[i]);
107  }
108  }
109 
110  po::options_description TMVAOptionsMulticlass::getDescription()
111  {
112  po::options_description description = TMVAOptions::getDescription();
113  description.add_options()
114  ("tmva_classes", po::value<std::vector<std::string>>(&m_classes)->required()->multitoken(),
115  "class name identifiers for multi-class mode");
116  return description;
117  }
118 
119  TMVATeacher::TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options) : Teacher(general_options),
120  specific_options(_specific_options) { }
121 
// Run training, testing and evaluation for the method booked on `factory`, then bundle the
// general options, the TMVA weight file, the captured TMVA log and the parsed feature
// importances into a new Weightfile.
// NOTE(review): listing line 191 is missing from this view (between the trim of `variable`
// and the std::stof call); presumably it maps the name back with
// MakeROOTCompatible::invertMakeROOTCompatible — confirm.
122  Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, const std::string& jobName) const
123  {
124  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
125 
// For method type "Plugins", register two ROOT plugin handlers so TMVA can instantiate
// TMVA::Method<m_method> by name. NOTE(review): this registration is duplicated verbatim in
// the expert load() methods; a shared helper would keep the copies in sync.
126  if (specific_options.m_type == "Plugins") {
127  auto base = std::string("TMVA@@MethodBase");
128  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
129  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
130  auto className = std::string("TMVA::Method") + specific_options.m_method;
131  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
132  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
133  auto pluginName = std::string("TMVA") + specific_options.m_method;
134 
135  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
136  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
137  }
138 
// NOTE(review): a failed booking is only logged as an error; training below still runs.
139  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
140  B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
141  }
142 
143  Weightfile weightfile;
144  std::string logfilename = weightfile.generateFileName(".log");
145 
146  // Pipe stdout into a logfile to get TMVA output, which contains valuable information
147  // which cannot be retrieved otherwise!
148  // Hence we do some black magic here
149  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
// The file descriptor for stdout is temporarily pointed at the log file (the literal 1 in
// dup2(logfile, 1) is STDOUT_FILENO) and restored from the saved duplicate afterwards.
// NOTE(review): the return values of open(), dup() and dup2() are not checked; if open()
// fails, TMVA output is not captured and the log file may not exist. No explicit flush is
// done before stdout is restored, so output still buffered at that point would reach the
// terminal rather than the log.
150  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
151  auto saved_stdout = dup(STDOUT_FILENO);
152  dup2(logfile, 1);
153 
154  factory.TrainAllMethods();
155  factory.TestAllMethods();
156  factory.EvaluateAllMethods();
157 
158  // Reset original output
159  dup2(saved_stdout, STDOUT_FILENO);
160  close(saved_stdout);
161  close(logfile);
162 
163 
164  weightfile.addOptions(m_general_options);
// NOTE(review): the weight-file path hard-codes the "TMVA/" directory, while the callers name
// the DataLoader after jobName (m_prefix, or "TMVA" when empty). If TMVA writes weights below
// the DataLoader's name, a custom m_prefix would make this path wrong — confirm.
165  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
166  weightfile.addFile("TMVA_Logfile", logfilename);
167 
168  // We have to parse the TMVA output to get the feature importances, there is no other way currently
// Only the first "Ranking input variables (method specific)" table is parsed:
//   state 0    - searching for the header line;
//   states 1-4 - skip the four lines that follow the header;
//   state 5    - split each row on ':' and record field [2] (variable) and field [3]
//                (importance) until the dashed separator ends the table.
// NOTE(review): strs[2]/strs[3] are read without checking strs.size(), and std::stof throws on
// a non-numeric field, so a malformed row is undefined behaviour or an exception.
169  std::string begin = "Ranking input variables (method specific)";
170  std::string end = "-----------------------------------";
171  std::string line;
172  std::ifstream file(logfilename, std::ios::in);
173  std::map<std::string, float> feature_importances;
174  int state = 0;
175  while (std::getline(file, line)) {
176  if (state == 0 && line.find(begin) != std::string::npos) {
177  state = 1;
178  continue;
179  }
180  if (state >= 1 and state <= 4) {
181  state++;
182  continue;
183  }
184  if (state == 5) {
185  if (line.find(end) != std::string::npos)
186  break;
187  std::vector<std::string> strs;
188  boost::split(strs, line, boost::is_any_of(":"));
189  std::string variable = strs[2];
190  boost::trim(variable);
192  float importance = std::stof(strs[3]);
193  feature_importances[variable] = importance;
194  }
195  }
196  weightfile.addFeatureImportance(feature_importances);
197 
198  return weightfile;
199 
200  }
201 
202 
// Remaining part of the TMVATeacherClassification constructor (its first line, listing line
// 203, is missing from this view): forwards the options to TMVATeacher and keeps its own
// copy of the classification options in specific_options.
204  const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
205  specific_options(_specific_options) { }
206 
// Body of TMVATeacherClassification::train (its signature, listing line 207, is missing from
// this view). It:
//   - copies the dataset into a signal and a background TTree inside a working directory;
//   - hands both trees to a TMVA DataLoader and trains via trainFactory();
//   - returns the Weightfile with the classification options and the signal fraction attached.
// NOTE(review): listing lines 245 and 251 are also missing; 251 presumably starts the
// signal_tree->Branch(...) call whose address argument follows — confirm.
208  {
209 
210  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
211  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
212  unsigned int numberOfEvents = training_data.getNumberOfEvents();
213 
214  std::string directory = specific_options.m_workingDirectory;
// Without a configured working directory, train inside a fresh mkdtemp() directory that is
// removed again at the end. NOTE(review): mkdtemp() returns nullptr on failure, and assigning
// nullptr to a std::string is undefined behaviour; the result is not checked.
215  if (specific_options.m_workingDirectory.empty()) {
216  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
217  directory = mkdtemp(directory_template);
218  free(directory_template);
219  }
220 
// NOTE(review): presumably guardWorkingDirectory(directory) switches into `directory` and
// restores the previous working directory when `guard` is destroyed — confirm in ScopeGuard.h.
221  // cppcheck-suppress unreadVariable
222  auto guard = ScopeGuard::guardWorkingDirectory(directory);
223 
224  std::string jobName = specific_options.m_prefix;
225  if (jobName.empty())
226  jobName = "TMVA";
227  TFile classFile((jobName + ".root").c_str(), "RECREATE");
228  classFile.cd();
229 
230  TMVA::Tools::Instance();
231  TMVA::DataLoader data_loader(jobName);
232  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
233 
234 
235  // Add variables to the data loader
236  for (auto& var : m_general_options.m_variables) {
237  data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
238  }
239 
240  // Add spectators to the data loader
241  for (auto& var : m_general_options.m_spectators) {
242  data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
243  }
244 
246 
// The branches point straight at the dataset's public buffers (m_input, m_spectators,
// m_weight); presumably loadEvent() refreshes them before each Fill() — confirm.
// NOTE(review): the trees are managed with raw new/delete; if trainFactory() throws they leak.
247  auto* signal_tree = new TTree("signal_tree", "signal_tree");
248  auto* background_tree = new TTree("background_tree", "background_tree");
249 
250  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
252  &training_data.m_input[iFeature]);
253  background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
254  &training_data.m_input[iFeature]);
255  }
256 
257  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
258  signal_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
259  &training_data.m_spectators[iSpectator]);
260  background_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
261  &training_data.m_spectators[iSpectator]);
262  }
263 
264  signal_tree->Branch("__weight__", &training_data.m_weight);
265  background_tree->Branch("__weight__", &training_data.m_weight);
266 
// Route every event into exactly one of the two trees according to m_isSignal.
267  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
268  training_data.loadEvent(iEvent);
269  if (training_data.m_isSignal) {
270  signal_tree->Fill();
271  } else {
272  background_tree->Fill();
273  }
274  }
275 
276  data_loader.AddSignalTree(signal_tree);
277  data_loader.AddBackgroundTree(background_tree);
278  auto weightfile = trainFactory(factory, data_loader, jobName);
279 
280  weightfile.addOptions(specific_options);
281  weightfile.addSignalFraction(training_data.getSignalFraction());
282 
283  delete signal_tree;
284  delete background_tree;
285 
// NOTE(review): the temporary directory is removed while `guard` is still alive and while
// classFile is still open; both are only released when this scope ends.
286  if (specific_options.m_workingDirectory.empty()) {
287  boost::filesystem::remove_all(directory);
288  }
289 
290  return weightfile;
291 
292  }
293 
// Remaining part of the TMVATeacherMulticlass constructor (its first line, listing line 294,
// is missing from this view): forwards the options to TMVATeacher and keeps its own copy of
// the multiclass options in specific_options.
295  const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
296  specific_options(_specific_options) { }
297 
// TMVATeacherMulticlass::train (its signature, listing line 299, is missing from this view).
298  // TODO: multiclass training is not implemented yet; training_data is ignored and an empty Weightfile is returned.
300  {
301  (void) training_data;
302  return Weightfile();
303  }
304 
// Remaining part of the TMVATeacherRegression constructor (its first line, listing line 305,
// is missing from this view): forwards the options to TMVATeacher and keeps its own copy of
// the regression options in specific_options.
306  const TMVAOptionsRegression& _specific_options) : TMVATeacher(general_options, _specific_options),
307  specific_options(_specific_options) { }
308 
// Body of TMVATeacherRegression::train (its signature, listing line 309, is missing from this
// view). It:
//   - copies the dataset into a single regression TTree inside a working directory;
//   - hands the tree to a TMVA DataLoader and trains via trainFactory();
//   - returns the Weightfile with the regression options attached.
// NOTE(review): listing lines 346 and 358 are also missing; 358 presumably starts the
// regression_tree->Branch(...) call for the target, whose address argument follows — confirm.
310  {
311 
312  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
313  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
314  unsigned int numberOfEvents = training_data.getNumberOfEvents();
315 
316  std::string directory = specific_options.m_workingDirectory;
// Without a configured working directory, train inside a fresh mkdtemp() directory that is
// removed again at the end. NOTE(review): mkdtemp() returns nullptr on failure, and assigning
// nullptr to a std::string is undefined behaviour; the result is not checked.
317  if (specific_options.m_workingDirectory.empty()) {
318  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
319  directory = mkdtemp(directory_template);
320  free(directory_template);
321  }
322 
// NOTE(review): presumably guardWorkingDirectory(directory) switches into `directory` and
// restores the previous working directory when `guard` is destroyed — confirm in ScopeGuard.h.
323  // cppcheck-suppress unreadVariable
324  auto guard = ScopeGuard::guardWorkingDirectory(directory);
325 
326  std::string jobName = specific_options.m_prefix;
327  if (jobName.empty())
328  jobName = "TMVA";
329  TFile classFile((jobName + ".root").c_str(), "RECREATE");
330  classFile.cd();
331 
332  TMVA::Tools::Instance();
333  TMVA::DataLoader data_loader(jobName);
334  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
335 
336  // Add variables to the data loader
337  for (auto& var : m_general_options.m_variables) {
338  data_loader.AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
339  }
340 
341  // Add spectators to the data loader
342  for (auto& var : m_general_options.m_spectators) {
343  data_loader.AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(var));
344  }
345 
347 
// NOTE(review): the tree is managed with raw new/delete; if trainFactory() throws it leaks.
348  auto* regression_tree = new TTree("regression_tree", "regression_tree");
349 
350  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
351  regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
352  &training_data.m_input[iFeature]);
353  }
354  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
355  regression_tree->Branch(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
356  &training_data.m_spectators[iSpectator]);
357  }
359  &training_data.m_target);
360 
361  regression_tree->Branch("__weight__", &training_data.m_weight);
362 
363  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
364  training_data.loadEvent(iEvent);
365  regression_tree->Fill();
366  }
367 
368  data_loader.AddRegressionTree(regression_tree);
// NOTE(review): the tree's weight branch is named "__weight__", but the weight expression uses
// the ROOT-compatible form of m_weight_variable; the two agree only if that name is
// "__weight__" — confirm.
369  data_loader.SetWeightExpression(Belle2::MakeROOTCompatible::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
370 
371  auto weightfile = trainFactory(factory, data_loader, jobName);
372  weightfile.addOptions(specific_options);
373 
374  delete regression_tree;
375 
// NOTE(review): the temporary directory is removed while `guard` is still alive and while
// classFile is still open; both are only released when this scope ends.
376  if (specific_options.m_workingDirectory.empty()) {
377  boost::filesystem::remove_all(directory);
378  }
379 
380  return weightfile;
381 
382  }
383 
384  void TMVAExpert::load(Weightfile& weightfile)
385  {
386 
387  // Initialize TMVA and ROOT stuff
388  TMVA::Tools::Instance();
389 
390  m_expert = std::make_unique<TMVA::Reader>("!Color:!Silent");
391 
392  GeneralOptions general_options;
393  weightfile.getOptions(general_options);
394  m_input_cache.resize(general_options.m_variables.size(), 0);
395  for (unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
396  m_expert->AddVariable(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_variables[i]), &m_input_cache[i]);
397  }
398  m_spectators_cache.resize(general_options.m_spectators.size(), 0);
399  for (unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
400  m_expert->AddSpectator(Belle2::MakeROOTCompatible::makeROOTCompatible(general_options.m_spectators[i]), &m_spectators_cache[i]);
401  }
402 
403  if (weightfile.containsElement("TMVA_Logfile")) {
404  std::string custom_weightfile = weightfile.generateFileName("logfile");
405  weightfile.getFile("TMVA_Logfile", custom_weightfile);
406  }
407 
408  }
409 
// Body of TMVAExpertClassification::load (its signature, listing line 410, is missing from this
// view). It:
//   - restores the classification options;
//   - extracts the TMVA weight file to a temporary file named "..._<method>.weights.xml";
//   - sets up the Reader via TMVAExpert::load();
//   - registers plugin handlers for method type "Plugins";
//   - books the method, aborting with B2FATAL if booking fails.
// NOTE(review): listing lines 414-415 are missing as well; together with the closing brace
// below they presumably set expert_signalFraction from weightfile.getSignalFraction() when
// transform2probability is enabled — confirm.
411  {
412 
413  weightfile.getOptions(specific_options);
416  }
417 
418  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
419  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
420  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
421 
422  TMVAExpert::load(weightfile);
423 
// NOTE(review): this plugin registration is duplicated verbatim in trainFactory() and the other
// expert load() methods; a shared helper would keep the copies in sync.
424  if (specific_options.m_type == "Plugins") {
425  auto base = std::string("TMVA@@MethodBase");
426  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
427  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
428  auto className = std::string("TMVA::Method") + specific_options.m_method;
429  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
430  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
431  auto pluginName = std::string("TMVA") + specific_options.m_method;
432 
433  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
434  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
435  B2INFO("Registered new TMVA Plugin named " << pluginName);
436  }
437 
438  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
439  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
440  }
441 
442  }
443 
// Body of TMVAExpertMulticlass::load (its signature, listing line 444, is missing from this
// view). It:
//   - restores the multiclass options;
//   - extracts the TMVA weight file to a temporary file named "..._<method>.weights.xml";
//   - sets up the Reader via TMVAExpert::load();
//   - registers plugin handlers for method type "Plugins";
//   - books the method, aborting with B2FATAL if booking fails.
445  {
446 
447  weightfile.getOptions(specific_options);
448 
449  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
450  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
451  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
452 
453  TMVAExpert::load(weightfile);
454 
// NOTE(review): this plugin registration is duplicated verbatim in trainFactory() and the other
// expert load() methods; a shared helper would keep the copies in sync.
455  if (specific_options.m_type == "Plugins") {
456  auto base = std::string("TMVA@@MethodBase");
457  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
458  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
459  auto className = std::string("TMVA::Method") + specific_options.m_method;
460  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
461  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
462  auto pluginName = std::string("TMVA") + specific_options.m_method;
463 
464  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
465  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
466  B2INFO("Registered new TMVA Plugin named " << pluginName);
467  }
468 
469  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
470  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
471  }
472 
473  }
474 
// Body of TMVAExpertRegression::load (its signature, listing line 475, is missing from this
// view). It:
//   - restores the regression options;
//   - extracts the TMVA weight file to a temporary file named "..._<method>.weights.xml";
//   - sets up the Reader via TMVAExpert::load();
//   - registers plugin handlers for method type "Plugins";
//   - books the method, aborting with B2FATAL if booking fails.
476  {
477 
478  weightfile.getOptions(specific_options);
479 
480  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
481  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
482  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
483 
484  TMVAExpert::load(weightfile);
485 
// NOTE(review): this plugin registration is duplicated verbatim in trainFactory() and the other
// expert load() methods; a shared helper would keep the copies in sync.
486  if (specific_options.m_type == "Plugins") {
487  auto base = std::string("TMVA@@MethodBase");
488  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
489  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
490  auto className = std::string("TMVA::Method") + specific_options.m_method;
491  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
492  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
493  auto pluginName = std::string("TMVA") + specific_options.m_method;
494 
495  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
496  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
497  B2INFO("Registered new TMVA Plugin named " << pluginName);
498  }
499 
500  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
501  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
502  }
503 
504  }
505 
// Evaluate the classification method on every event of test_data. Each event is copied into the
// input and spectator caches the Reader is bound to. Then either TMVA's GetProba (using
// expert_signalFraction) or the raw EvaluateMVA output is stored, one value per event.
// NOTE(review): listing line 516 is missing from this view; presumably it is
// `if (specific_options.transform2probability)`, selecting GetProba — confirm.
506  std::vector<float> TMVAExpertClassification::apply(Dataset& test_data) const
507  {
508 
509  std::vector<float> probabilities(test_data.getNumberOfEvents());
510  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
511  test_data.loadEvent(iEvent);
512  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
513  m_input_cache[i] = test_data.m_input[i];
514  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
515  m_spectators_cache[i] = test_data.m_spectators[i];
517  probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
518  else
519  probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
520  }
521  return probabilities;
522 
523  }
524 
525  std::vector<std::vector<float>> TMVAExpertMulticlass::applyMulticlass(Dataset& test_data) const
526  {
527 
528  std::vector<std::vector<float>> probabilities(test_data.getNumberOfEvents());
529 
530  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
531  test_data.loadEvent(iEvent);
532  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
533  m_input_cache[i] = test_data.m_input[i];
534  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
535  m_spectators_cache[i] = test_data.m_spectators[i];
536  probabilities[iEvent] = m_expert->EvaluateMulticlass(specific_options.m_method);
537  }
538  return probabilities;
539  }
540 
541  std::vector<float> TMVAExpertRegression::apply(Dataset& test_data) const
542  {
543 
544  std::vector<float> prediction(test_data.getNumberOfEvents());
545  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
546  test_data.loadEvent(iEvent);
547  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
548  m_input_cache[i] = test_data.m_input[i];
549  prediction[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
550  }
551  return prediction;
552 
553  }
554 
555  }
557 }
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
General options which are shared by all MVA trainings.
Definition: Options.h:62
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:86
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:90
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:87
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:89
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:320
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:506
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:321
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:410
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:355
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:444
virtual std::vector< std::vector< float > > applyMulticlass(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:525
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:541
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:378
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:475
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:296
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:294
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:298
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:384
Options for the TMVA Classification MVA method.
Definition: TMVA.h:80
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:81
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:115
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:69
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:75
Options for the TMVA Multiclass MVA method.
Definition: TMVA.h:122
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:158
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:110
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:89
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:100
Options for the TMVA Regression MVA method.
Definition: TMVA.h:166
Options for the TMVA MVA method.
Definition: TMVA.h:34
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:72
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:74
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:66
std::string m_method
tmva method name
Definition: TMVA.h:60
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:55
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:71
std::string m_type
tmva method type
Definition: TMVA.h:61
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:73
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:27
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:43
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:203
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:231
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:207
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:294
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:299
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:305
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:277
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:309
Teacher for the TMVA MVA method.
Definition: TMVA.h:188
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:119
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, const std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:122
TMVAOptions specific_options
Method specific options.
Definition: TMVA.h:207
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:114
bool containsElement(const std::string &identifier) const
Returns true if given element is stored in the property tree.
Definition: Weightfile.h:160
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:61
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:66
void addFeatureImportance(const std::map< std::string, float > &importance)
Add variable importance.
Definition: Weightfile.cc:71
float getSignalFraction() const
Loads the signal fraction from the xml tree.
Definition: Weightfile.cc:99
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:104
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:137
static std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
static std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Definition: ScopeGuard.h:296
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:23