Belle II Software  release-05-01-25
TMVA.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2016 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Thomas Keck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
#include <mva/methods/TMVA.h>
#include <framework/logging/Logger.h>
#include <framework/utilities/MakeROOTCompatible.h>
#include <framework/utilities/ScopeGuard.h>

#include <TPluginManager.h>

#include <boost/algorithm/string.hpp>
#include <boost/filesystem/operations.hpp>

#include <fcntl.h>
#include <stdlib.h>
#include <unistd.h>

#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
21 
22 namespace Belle2 {
27  namespace MVA {
28 
29  void TMVAOptions::load(const boost::property_tree::ptree& pt)
30  {
31  int version = pt.get<int>("TMVA_version");
32  if (version != 1) {
33  B2ERROR("Unkown weightfile version " << std::to_string(version));
34  throw std::runtime_error("Unkown weightfile version " + std::to_string(version));
35  }
36  m_method = pt.get<std::string>("TMVA_method");
37  m_type = pt.get<std::string>("TMVA_type");
38  m_config = pt.get<std::string>("TMVA_config");
39  m_factoryOption = pt.get<std::string>("TMVA_factoryOption");
40  m_prepareOption = pt.get<std::string>("TMVA_prepareOption");
41  m_workingDirectory = pt.get<std::string>("TMVA_workingDirectory");
42  m_prefix = pt.get<std::string>("TMVA_prefix");
43  }
44 
45  void TMVAOptions::save(boost::property_tree::ptree& pt) const
46  {
47  pt.put("TMVA_version", 1);
48  pt.put("TMVA_method", m_method);
49  pt.put("TMVA_type", m_type);
50  pt.put("TMVA_config", m_config);
51  pt.put("TMVA_factoryOption", m_factoryOption);
52  pt.put("TMVA_prepareOption", m_prepareOption);
53  pt.put("TMVA_workingDirectory", m_workingDirectory);
54  pt.put("TMVA_prefix", m_prefix);
55  }
56 
57  po::options_description TMVAOptions::getDescription()
58  {
59  po::options_description description("TMVA options");
60  description.add_options()
61  ("tmva_method", po::value<std::string>(&m_method), "TMVA Method Name")
62  ("tmva_type", po::value<std::string>(&m_type), "TMVA Method Type (e.g. Plugin, BDT, ...)")
63  ("tmva_config", po::value<std::string>(&m_config), "TMVA Configuration string for the method")
64  ("tmva_working_directory", po::value<std::string>(&m_workingDirectory), "TMVA working directory which stores e.g. TMVA.root")
65  ("tmva_factory", po::value<std::string>(&m_factoryOption), "TMVA Factory options passed to TMVAFactory constructor")
66  ("tmva_prepare", po::value<std::string>(&m_prepareOption),
67  "TMVA Preprare options passed to prepareTrainingAndTestTree function");
68  return description;
69  }
70 
71  void TMVAOptionsClassification::load(const boost::property_tree::ptree& pt)
72  {
74  transform2probability = pt.get<bool>("TMVA_transform2probability");
75  }
76 
77  void TMVAOptionsClassification::save(boost::property_tree::ptree& pt) const
78  {
80  pt.put("TMVA_transform2probability", transform2probability);
81  }
82 
83  po::options_description TMVAOptionsClassification::getDescription()
84  {
85  po::options_description description = TMVAOptions::getDescription();
86  description.add_options()
87  ("tmva_transform2probability", po::value<bool>(&transform2probability), "TMVA Transform output of classifier to a probability");
88  return description;
89  }
90 
91  void TMVAOptionsMulticlass::load(const boost::property_tree::ptree& pt)
92  {
94 
95  unsigned int numberOfClasses = pt.get<unsigned int>("TMVA_number_classes", 1);
96  m_classes.resize(numberOfClasses);
97  for (unsigned int i = 0; i < numberOfClasses; ++i) {
98  m_classes[i] = pt.get<std::string>(std::string("TMVA_classes") + std::to_string(i));
99  }
100  }
101 
102  void TMVAOptionsMulticlass::save(boost::property_tree::ptree& pt) const
103  {
104  TMVAOptions::save(pt);
105 
106  pt.put("TMVA_number_classes", m_classes.size());
107  for (unsigned int i = 0; i < m_classes.size(); ++i) {
108  pt.put(std::string("TMVA_classes") + std::to_string(i), m_classes[i]);
109  }
110  }
111 
112  po::options_description TMVAOptionsMulticlass::getDescription()
113  {
114  po::options_description description = TMVAOptions::getDescription();
115  description.add_options()
116  ("tmva_classes", po::value<std::vector<std::string>>(&m_classes)->required()->multitoken(),
117  "class name identifiers for multi-class mode");
118  return description;
119  }
120 
121  TMVATeacher::TMVATeacher(const GeneralOptions& general_options, const TMVAOptions& _specific_options) : Teacher(general_options),
122  specific_options(_specific_options) { }
123 
124 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
125  Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, TMVA::DataLoader& data_loader, std::string& jobName) const
126 #else
127  Weightfile TMVATeacher::trainFactory(TMVA::Factory& factory, std::string& jobName) const
128 #endif
129  {
130 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
131  data_loader.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
132 #else
133  factory.PrepareTrainingAndTestTree("", specific_options.m_prepareOption);
134 #endif
135 
136  if (specific_options.m_type == "Plugins") {
137  auto base = std::string("TMVA@@MethodBase");
138  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
139  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
140  auto className = std::string("TMVA::Method") + specific_options.m_method;
141  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
142  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
143  auto pluginName = std::string("TMVA") + specific_options.m_method;
144 
145  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
146  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
147  }
148 
149 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
150  if (!factory.BookMethod(&data_loader, specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
151 #else
152  if (!factory.BookMethod(specific_options.m_type, specific_options.m_method, specific_options.m_config)) {
153 #endif
154  B2ERROR("TMVA Method with name " + specific_options.m_method + " cannot be booked.");
155  }
156 
157  Weightfile weightfile;
158  std::string logfilename = weightfile.generateFileName(".log");
159 
160  // Pipe stdout into a logfile to get TMVA output, which contains valueable information
161  // which cannot be retreived otherwise!
162  // Hence we do some black magic here
163  // TODO Using ROOT_VERSION 6.08 this should be possible without this workaround
164  auto logfile = open(logfilename.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0666);
165  auto saved_stdout = dup(STDOUT_FILENO);
166  dup2(logfile, 1);
167 
168  factory.TrainAllMethods();
169  factory.TestAllMethods();
170  factory.EvaluateAllMethods();
171 
172  // Reset original output
173  dup2(saved_stdout, STDOUT_FILENO);
174  close(saved_stdout);
175  close(logfile);
176 
177 
178  weightfile.addOptions(m_general_options);
179  weightfile.addFile("TMVA_Weightfile", std::string("TMVA/weights/") + jobName + "_" + specific_options.m_method + ".weights.xml");
180  weightfile.addFile("TMVA_Logfile", logfilename);
181 
182  // We have to parse the TMVA output to get the feature importances, there is no other way currently
183  std::string begin = "Ranking input variables (method specific)";
184  std::string end = "-----------------------------------";
185  std::string line;
186  std::ifstream file(logfilename, std::ios::in);
187  std::map<std::string, float> feature_importances;
188  int state = 0;
189  while (std::getline(file, line)) {
190  if (state == 0 && line.find(begin) != std::string::npos) {
191  state = 1;
192  continue;
193  }
194  if (state >= 1 and state <= 4) {
195  state++;
196  continue;
197  }
198  if (state == 5) {
199  if (line.find(end) != std::string::npos)
200  break;
201  std::vector<std::string> strs;
202  boost::split(strs, line, boost::is_any_of(":"));
203  std::string variable = strs[2];
204  boost::trim(variable);
205  variable = Belle2::invertMakeROOTCompatible(variable);
206  float importance = std::stof(strs[3]);
207  feature_importances[variable] = importance;
208  }
209  }
210  weightfile.addFeatureImportance(feature_importances);
211 
212  return weightfile;
213 
214  }
215 
216 
217  TMVATeacherClassification::TMVATeacherClassification(const GeneralOptions& general_options,
218  const TMVAOptionsClassification& _specific_options) : TMVATeacher(general_options, _specific_options),
219  specific_options(_specific_options) { }
220 
221  Weightfile TMVATeacherClassification::train(Dataset& training_data) const
222  {
223 
224  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
225  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
226  unsigned int numberOfEvents = training_data.getNumberOfEvents();
227 
228  std::string directory = specific_options.m_workingDirectory;
230  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
231  directory = mkdtemp(directory_template);
232  free(directory_template);
233  }
234 
235  auto guard = ScopeGuard::guardWorkingDirectory(directory);
236 
237  std::string jobName = specific_options.m_prefix;
238  if (jobName.empty())
239  jobName = "TMVA";
240  TFile classFile((jobName + ".root").c_str(), "RECREATE");
241  classFile.cd();
242 
243  TMVA::Tools::Instance();
244 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
245  TMVA::DataLoader data_loader(jobName);
246 #endif
247  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
248 
249 
250  // Add variables to the factory
251  for (auto& var : m_general_options.m_variables) {
252 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
253  data_loader.AddVariable(Belle2::makeROOTCompatible(var));
254 #else
255  factory.AddVariable(Belle2::makeROOTCompatible(var));
256 #endif
257  }
258 
259  // Add variables to the factory
260  for (auto& var : m_general_options.m_spectators) {
261 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
262  data_loader.AddSpectator(Belle2::makeROOTCompatible(var));
263 #else
264  factory.AddSpectator(Belle2::makeROOTCompatible(var));
265 #endif
266  }
267 
268 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
269  data_loader.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable));
270 #else
271  factory.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable));
272 #endif
273 
274  auto* signal_tree = new TTree("signal_tree", "signal_tree");
275  auto* background_tree = new TTree("background_tree", "background_tree");
276 
277  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
278  signal_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
279  &training_data.m_input[iFeature]);
280  background_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
281  &training_data.m_input[iFeature]);
282  }
283 
284  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
285  signal_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
286  &training_data.m_spectators[iSpectator]);
287  background_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
288  &training_data.m_spectators[iSpectator]);
289  }
290 
291  signal_tree->Branch("__weight__", &training_data.m_weight);
292  background_tree->Branch("__weight__", &training_data.m_weight);
293 
294  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
295  training_data.loadEvent(iEvent);
296  if (training_data.m_isSignal) {
297  signal_tree->Fill();
298  } else {
299  background_tree->Fill();
300  }
301  }
302 
303 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
304  data_loader.AddSignalTree(signal_tree);
305  data_loader.AddBackgroundTree(background_tree);
306  auto weightfile = trainFactory(factory, data_loader, jobName);
307 #else
308  factory.AddSignalTree(signal_tree);
309  factory.AddBackgroundTree(background_tree);
310  auto weightfile = trainFactory(factory, jobName);
311 #endif
312 
313  weightfile.addOptions(specific_options);
314  weightfile.addSignalFraction(training_data.getSignalFraction());
315 
316  delete signal_tree;
317  delete background_tree;
318 
319  if (specific_options.m_workingDirectory.empty()) {
320  boost::filesystem::remove_all(directory);
321  }
322 
323  return weightfile;
324 
325  }
326 
327  TMVATeacherMulticlass::TMVATeacherMulticlass(const GeneralOptions& general_options,
328  const TMVAOptionsMulticlass& _specific_options) : TMVATeacher(general_options, _specific_options),
329  specific_options(_specific_options) { }
330 
331  // Implement me!
332  Weightfile TMVATeacherMulticlass::train(Dataset& training_data) const
333  {
334  (void) training_data;
335  return Weightfile();
336  }
337 
339  const TMVAOptionsRegression& _specific_options) : TMVATeacher(general_options, _specific_options),
340  specific_options(_specific_options) { }
341 
342  Weightfile TMVATeacherRegression::train(Dataset& training_data) const
343  {
344 
345  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
346  unsigned int numberOfSpectators = training_data.getNumberOfSpectators();
347  unsigned int numberOfEvents = training_data.getNumberOfEvents();
348 
349  std::string directory = specific_options.m_workingDirectory;
351  char* directory_template = strdup("/tmp/Basf2TMVA.XXXXXX");
352  directory = mkdtemp(directory_template);
353  free(directory_template);
354  }
355 
356  auto guard = ScopeGuard::guardWorkingDirectory(directory);
357 
358  std::string jobName = specific_options.m_prefix;
359  if (jobName.empty())
360  jobName = "TMVA";
361  TFile classFile((jobName + ".root").c_str(), "RECREATE");
362  classFile.cd();
363 
364  TMVA::Tools::Instance();
365 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
366  TMVA::DataLoader data_loader(jobName);
367 #endif
368  TMVA::Factory factory(jobName, &classFile, specific_options.m_factoryOption);
369 
370  // Add variables to the factory
371  for (auto& var : m_general_options.m_variables) {
372 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
373  data_loader.AddVariable(Belle2::makeROOTCompatible(var));
374 #else
375  factory.AddVariable(Belle2::makeROOTCompatible(var));
376 #endif
377  }
378 
379  // Add variables to the factory
380  for (auto& var : m_general_options.m_spectators) {
381 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
382  data_loader.AddSpectator(Belle2::makeROOTCompatible(var));
383 #else
384  factory.AddSpectator(Belle2::makeROOTCompatible(var));
385 #endif
386  }
387 
388 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
390 #else
392 #endif
393 
394 
395  auto* regression_tree = new TTree("regression_tree", "regression_tree");
396 
397  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
398  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_variables[iFeature]).c_str(),
399  &training_data.m_input[iFeature]);
400  }
401  for (unsigned int iSpectator = 0; iSpectator < numberOfSpectators; ++iSpectator) {
402  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_spectators[iSpectator]).c_str(),
403  &training_data.m_spectators[iSpectator]);
404  }
405  regression_tree->Branch(Belle2::makeROOTCompatible(m_general_options.m_target_variable).c_str(),
406  &training_data.m_target);
407 
408  regression_tree->Branch("__weight__", &training_data.m_weight);
409 
410  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
411  training_data.loadEvent(iEvent);
412  regression_tree->Fill();
413  }
414 
415 #if ROOT_VERSION_CODE >= ROOT_VERSION(6,8,0)
416  data_loader.AddRegressionTree(regression_tree);
417  data_loader.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
418 
419  auto weightfile = trainFactory(factory, data_loader, jobName);
420 #else
421  factory.AddRegressionTree(regression_tree);
422  factory.SetWeightExpression(Belle2::makeROOTCompatible(m_general_options.m_weight_variable), "Regression");
423 
424  auto weightfile = trainFactory(factory, jobName);
425 #endif
426  weightfile.addOptions(specific_options);
427 
428  delete regression_tree;
429 
430  if (specific_options.m_workingDirectory.empty()) {
431  boost::filesystem::remove_all(directory);
432  }
433 
434  return weightfile;
435 
436  }
437 
438  void TMVAExpert::load(Weightfile& weightfile)
439  {
440 
441  // Initialize TMVA and ROOT stuff
442  TMVA::Tools::Instance();
443 
444  m_expert = std::make_unique<TMVA::Reader>("!Color:!Silent");
445 
446  GeneralOptions general_options;
447  weightfile.getOptions(general_options);
448  m_input_cache.resize(general_options.m_variables.size(), 0);
449  for (unsigned int i = 0; i < general_options.m_variables.size(); ++i) {
450  m_expert->AddVariable(Belle2::makeROOTCompatible(general_options.m_variables[i]), &m_input_cache[i]);
451  }
452  m_spectators_cache.resize(general_options.m_spectators.size(), 0);
453  for (unsigned int i = 0; i < general_options.m_spectators.size(); ++i) {
454  m_expert->AddSpectator(Belle2::makeROOTCompatible(general_options.m_spectators[i]), &m_spectators_cache[i]);
455  }
456 
457  if (weightfile.containsElement("TMVA_Logfile")) {
458  std::string custom_weightfile = weightfile.generateFileName("logfile");
459  weightfile.getFile("TMVA_Logfile", custom_weightfile);
460  }
461 
462  }
463 
464  void TMVAExpertClassification::load(Weightfile& weightfile)
465  {
466 
467  weightfile.getOptions(specific_options);
469  expert_signalFraction = weightfile.getSignalFraction();
470  }
471 
472  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
473  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
474  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
475 
476  TMVAExpert::load(weightfile);
477 
478  if (specific_options.m_type == "Plugins") {
479  auto base = std::string("TMVA@@MethodBase");
480  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
481  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
482  auto className = std::string("TMVA::Method") + specific_options.m_method;
483  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
484  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
485  auto pluginName = std::string("TMVA") + specific_options.m_method;
486 
487  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
488  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
489  B2INFO("Registered new TMVA Plugin named " << pluginName);
490  }
491 
492  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
493  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
494  }
495 
496  }
497 
498  void TMVAExpertMulticlass::load(Weightfile& weightfile)
499  {
500 
501  weightfile.getOptions(specific_options);
502 
503  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
504  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
505  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
506 
507  TMVAExpert::load(weightfile);
508 
509  if (specific_options.m_type == "Plugins") {
510  auto base = std::string("TMVA@@MethodBase");
511  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
512  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
513  auto className = std::string("TMVA::Method") + specific_options.m_method;
514  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
515  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
516  auto pluginName = std::string("TMVA") + specific_options.m_method;
517 
518  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
519  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
520  B2INFO("Registered new TMVA Plugin named " << pluginName);
521  }
522 
523  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
524  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
525  }
526 
527  }
528 
529  void TMVAExpertRegression::load(Weightfile& weightfile)
530  {
531 
532  weightfile.getOptions(specific_options);
533 
534  // TMVA parses the method type for plugins out of the weightfile name, so we must ensure that it has the expected format
535  std::string custom_weightfile = weightfile.generateFileName(std::string("_") + specific_options.m_method + ".weights.xml");
536  weightfile.getFile("TMVA_Weightfile", custom_weightfile);
537 
538  TMVAExpert::load(weightfile);
539 
540  if (specific_options.m_type == "Plugins") {
541  auto base = std::string("TMVA@@MethodBase");
542  auto regexp1 = std::string(".*_") + specific_options.m_method + std::string(".*");
543  auto regexp2 = std::string(".*") + specific_options.m_method + std::string(".*");
544  auto className = std::string("TMVA::Method") + specific_options.m_method;
545  auto ctor1 = std::string("Method") + specific_options.m_method + std::string("(TMVA::DataSetInfo&,TString)");
546  auto ctor2 = std::string("Method") + specific_options.m_method + std::string("(TString&,TString&,TMVA::DataSetInfo&,TString&)");
547  auto pluginName = std::string("TMVA") + specific_options.m_method;
548 
549  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp1.c_str(), className.c_str(), pluginName.c_str(), ctor1.c_str());
550  gROOT->GetPluginManager()->AddHandler(base.c_str(), regexp2.c_str(), className.c_str(), pluginName.c_str(), ctor2.c_str());
551  B2INFO("Registered new TMVA Plugin named " << pluginName);
552  }
553 
554  if (!m_expert->BookMVA(specific_options.m_method, custom_weightfile)) {
555  B2FATAL("Could not set up expert! Please see preceding error message from TMVA!");
556  }
557 
558  }
559 
560  std::vector<float> TMVAExpertClassification::apply(Dataset& test_data) const
561  {
562 
563  std::vector<float> probabilities(test_data.getNumberOfEvents());
564  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
565  test_data.loadEvent(iEvent);
566  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
567  m_input_cache[i] = test_data.m_input[i];
568  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
569  m_spectators_cache[i] = test_data.m_spectators[i];
571  probabilities[iEvent] = m_expert->GetProba(specific_options.m_method, expert_signalFraction);
572  else
573  probabilities[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
574  }
575  return probabilities;
576 
577  }
578 
579  std::vector<float> TMVAExpertMulticlass::apply(Dataset& test_data, const unsigned int classID) const
580  {
581 
582  std::vector<float> probabilities(test_data.getNumberOfEvents());
583  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
584  test_data.loadEvent(iEvent);
585  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
586  m_input_cache[i] = test_data.m_input[i];
587  for (unsigned int i = 0; i < m_spectators_cache.size(); ++i)
588  m_spectators_cache[i] = test_data.m_spectators[i];
589  probabilities[iEvent] = m_expert->EvaluateMulticlass(classID, specific_options.m_method);
590  }
591  return probabilities;
592  }
593 
594  std::vector<float> TMVAExpertRegression::apply(Dataset& test_data) const
595  {
596 
597  std::vector<float> prediction(test_data.getNumberOfEvents());
598  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
599  test_data.loadEvent(iEvent);
600  for (unsigned int i = 0; i < m_input_cache.size(); ++i)
601  m_input_cache[i] = test_data.m_input[i];
602  prediction[iEvent] = m_expert->EvaluateMVA(specific_options.m_method);
603  }
604  return prediction;
605 
606  }
607 
608  }
610 }
Belle2::MVA::TMVAExpertClassification::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:568
Belle2::MVA::TMVAOptions::m_config
std::string m_config
TMVA config string for the chosen method.
Definition: TMVA.h:70
Belle2::MVA::TMVAOptions::m_prepareOption
std::string m_prepareOption
Prepare options passed to prepareTrainingAndTestTree method.
Definition: TMVA.h:76
Belle2::MVA::TMVAOptionsClassification::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:79
Belle2::MVA::TMVATeacherClassification::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:229
Belle2::ScopeGuard::guardWorkingDirectory
static ScopeGuard guardWorkingDirectory()
Create a ScopeGuard of the current working directory.
Definition: ScopeGuard.h:306
Belle2::MVA::TMVAOptionsClassification::save
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:85
Belle2::MVA::TMVATeacherRegression::TMVATeacherRegression
TMVATeacherRegression(const GeneralOptions &general_options, const TMVAOptionsRegression &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:346
Belle2::MVA::TMVAOptionsMulticlass::getDescription
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:120
Belle2::MVA::TMVAOptionsMulticlass::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:99
Belle2::MVA::TMVAOptionsMulticlass::m_classes
std::vector< std::string > m_classes
Class name identifiers.
Definition: TMVA.h:162
Belle2::MVA::Dataset
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:34
Belle2::MVA::Weightfile
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:40
Belle2::invertMakeROOTCompatible
std::string invertMakeROOTCompatible(std::string str)
Invert makeROOTCompatible operation.
Definition: MakeROOTCompatible.cc:89
Belle2::MVA::TMVAOptionsClassification::transform2probability
bool transform2probability
Transform output of method to a probability.
Definition: TMVA.h:119
Belle2::MVA::TMVAOptions::getDescription
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:65
Belle2::MVA::GeneralOptions::m_weight_variable
std::string m_weight_variable
Weight variable (branch name) defining the weights.
Definition: Options.h:92
Belle2::MVA::TMVAExpertClassification::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:472
Belle2::MVA::TMVAExpertMulticlass::specific_options
TMVAOptionsMulticlass specific_options
Method specific options.
Definition: TMVA.h:369
Belle2::MVA::TMVAExpert::m_spectators_cache
std::vector< float > m_spectators_cache
Spectators Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply ...
Definition: TMVA.h:311
Belle2::MVA::TMVATeacherRegression::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:350
Belle2::MVA::TMVATeacher::TMVATeacher
TMVATeacher(const GeneralOptions &general_options, const TMVAOptions &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:129
Belle2::MVA::TMVAOptions::m_prefix
std::string m_prefix
Prefix used for all files generated by TMVA.
Definition: TMVA.h:78
Belle2::MVA::GeneralOptions::m_spectators
std::vector< std::string > m_spectators
Vector of all spectators (branch names) used in the training.
Definition: Options.h:89
Belle2::makeROOTCompatible
std::string makeROOTCompatible(std::string str)
Remove special characters that ROOT dislikes in branch names, e.g.
Definition: MakeROOTCompatible.cc:74
Belle2::MVA::TMVAOptions::save
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:53
Belle2::MVA::Teacher::m_general_options
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:51
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::TMVAExpertMulticlass::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.h:354
Belle2::MVA::Teacher
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
Definition: Teacher.h:31
Belle2::MVA::TMVATeacherRegression::specific_options
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:290
Belle2::MVA::GeneralOptions::m_target_variable
std::string m_target_variable
Target variable (branch name) defining the target.
Definition: Options.h:91
Belle2::MVA::TMVAExpertClassification::expert_signalFraction
float expert_signalFraction
Signal fraction used to calculate the probability.
Definition: TMVA.h:334
Belle2::MVA::GeneralOptions::m_variables
std::vector< std::string > m_variables
Vector of all variables (branch names) used in the training.
Definition: Options.h:88
Belle2::MVA::TMVAExpert::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:446
Belle2::MVA::GeneralOptions
General options which are shared by all MVA trainings.
Definition: Options.h:64
Belle2::MVA::TMVAExpertRegression::specific_options
TMVAOptionsRegression specific_options
Method specific options.
Definition: TMVA.h:392
Belle2::MVA::TMVAOptions::m_method
std::string m_method
tmva method name
Definition: TMVA.h:64
Belle2::MVA::TMVATeacherClassification::specific_options
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:244
Belle2::MVA::TMVAOptions::m_type
std::string m_type
tmva method type
Definition: TMVA.h:65
Belle2::MVA::TMVATeacherMulticlass::TMVATeacherMulticlass
TMVATeacherMulticlass(const GeneralOptions &general_options, const TMVAOptionsMulticlass &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:335
Belle2::MVA::TMVATeacher::trainFactory
Weightfile trainFactory(TMVA::Factory &factory, TMVA::DataLoader &data_loader, std::string &jobName) const
Train a mva method using the given data loader returning a Weightfile.
Definition: TMVA.cc:133
Belle2::MVA::TMVAOptionsClassification::getDescription
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: TMVA.cc:91
Belle2::MVA::TMVATeacherClassification::TMVATeacherClassification
TMVATeacherClassification(const GeneralOptions &general_options, const TMVAOptionsClassification &_specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: TMVA.cc:225
Belle2::MVA::TMVAOptionsMulticlass::save
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: TMVA.cc:110
Belle2::MVA::TMVAExpertClassification::specific_options
TMVAOptionsClassification specific_options
Method specific options.
Definition: TMVA.h:333
Belle2::MVA::TMVAOptions::m_factoryOption
std::string m_factoryOption
Factory options passed to tmva factory.
Definition: TMVA.h:75
Belle2::MVA::TMVAOptionsRegression
Options for the TMVA Regression MVA method.
Definition: TMVA.h:170
Belle2::MVA::TMVAExpert::m_input_cache
std::vector< float > m_input_cache
Input Cache for TMVA::Reader: Otherwise we would have to set the branch addresses in each apply call.
Definition: TMVA.h:309
Belle2::MVA::TMVAExpert::m_expert
std::unique_ptr< TMVA::Reader > m_expert
TMVA::Reader pointer.
Definition: TMVA.h:307
Belle2::MVA::TMVAExpertRegression::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:537
Belle2::MVA::TMVAExpertMulticlass::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: TMVA.cc:506
Belle2::MVA::TMVAOptions::m_workingDirectory
std::string m_workingDirectory
Working directory of TMVA, if empty a temporary directory is used.
Definition: TMVA.h:77
Belle2::MVA::TMVAExpertRegression::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this m_expert onto a dataset.
Definition: TMVA.cc:602
Belle2::MVA::TMVATeacherMulticlass::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: TMVA.cc:340
Belle2::MVA::TMVATeacher
Teacher for the TMVA MVA method.
Definition: TMVA.h:192
Belle2::MVA::TMVAOptions
Options for the TMVA MVA method.
Definition: TMVA.h:38
Belle2::MVA::TMVAOptions::load
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: TMVA.cc:37