Belle II Software  release-08-01-10
basf2_mva_teacher.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/interface/Interface.h>
10 #include <mva/interface/Options.h>
11 #include <mva/utility/Utility.h>
12 
13 #include <iostream>
14 #include <string>
15 #include <cerrno>
16 
17 namespace po = boost::program_options;
18 using namespace Belle2::MVA;
19 
20 int main(int argc, char* argv[])
21 {
22 
24 
25  GeneralOptions general_options;
26  po::options_description general_description(general_options.getDescription());
27 
28  MetaOptions meta_options;
29  po::options_description meta_description(meta_options.getDescription());
30 
31  // Order of additional options
32  // Loop over all classes from multi-class
33  // Loop over all hyperparameters -> needs also apply functionality from expert
34  // Do sPlot boost
35 
36  std::map<std::string, std::unique_ptr<SpecificOptions>> specific_options;
37 
38  for (auto& interface : AbstractInterface::getSupportedInterfaces()) {
39  specific_options.emplace(interface.second->getName(), interface.second->getOptions());
40  }
41 
42  po::variables_map vm;
43 
44  try {
45  po::options_description cmdline_description;
46  cmdline_description.add(general_description);
47  cmdline_description.add(meta_description);
48 
49  po::parsed_options parsed = po::command_line_parser(argc, argv).options(cmdline_description).allow_unregistered().run();
50  po::store(parsed, vm);
51 
52  if (vm.count("help")) {
53  if (vm.count("method")) {
54  std::string method = vm["method"].as<std::string>();
55  if (specific_options.find(method) != specific_options.end()) {
56  std::cout << specific_options[method]->getDescription() << std::endl;
57  } else {
58  std::cerr << "Provided method is unknown" << std::endl;
59  }
60  } else {
61  std::cout << general_description << std::endl;
62  std::cout << meta_description << std::endl;
63  }
64  return 1;
65  }
66  po::notify(vm);
67 
68  if (vm.count("method")) {
69  std::string method = vm["method"].as<std::string>();
70  if (specific_options.find(method) != specific_options.end()) {
71  cmdline_description.add(specific_options[method]->getDescription());
72  po::parsed_options specific_parsed = po::command_line_parser(argc, argv).options(cmdline_description).run();
73  po::store(specific_parsed, vm);
74  po::notify(vm);
75  } else {
76  std::cerr << "Provided method is unknown" << std::endl;
77  return 1;
78  }
79  } else {
80  std::cerr << "You must provide a method" << std::endl;
81  return 1;
82  }
83  } catch (po::error& err) {
84  std::cerr << "Error: " << err.what() << "\n";
85  return 1;
86  }
87 
88  //Check if method is available
89  if (specific_options.find(general_options.m_method) == specific_options.end()) {
90  std::cerr << "Unknown method " << general_options.m_method << std::endl;
91  }
92 
93  Belle2::MVA::Utility::teacher(general_options, *specific_options[general_options.m_method], meta_options);
94 
95  return 0;
96 
97 }
98 
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
General options which are shared by all MVA trainings.
Definition: Options.h:62
Meta Options which modify the underlying training by doing sPlot, Multiclass and HyperparameterSearch...
Definition: Options.h:111
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
Definition: Utility.cc:253
int main(int argc, char **argv)
Run all tests.
Definition: test_main.cc:91