Belle II Software development
basf2_mva_teacher.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/interface/Interface.h>
10#include <mva/interface/Options.h>
11#include <mva/utility/Utility.h>
12
13#include <iostream>
14#include <string>
15#include <cerrno>
16
17namespace po = boost::program_options;
18using namespace Belle2::MVA;
19
20int main(int argc, char* argv[])
21{
22
24
25 GeneralOptions general_options;
26 po::options_description general_description(general_options.getDescription());
27
28 MetaOptions meta_options;
29 po::options_description meta_description(meta_options.getDescription());
30
31 // Order of additional options
32 // Loop over all classes from multi-class
33 // Loop over all hyperparameters -> needs also apply functionality from expert
34 // Do sPlot boost
35
36 std::map<std::string, std::unique_ptr<SpecificOptions>> specific_options;
37
38 for (auto& interface : AbstractInterface::getSupportedInterfaces()) {
39 specific_options.emplace(interface.second->getName(), interface.second->getOptions());
40 }
41
42 po::variables_map vm;
43
44 try {
45 po::options_description cmdline_description;
46 cmdline_description.add(general_description);
47 cmdline_description.add(meta_description);
48
49 po::parsed_options parsed = po::command_line_parser(argc, argv).options(cmdline_description).allow_unregistered().run();
50 po::store(parsed, vm);
51
52 if (vm.count("help")) {
53 if (vm.count("method")) {
54 std::string method = vm["method"].as<std::string>();
55 if (specific_options.find(method) != specific_options.end()) {
56 std::cout << specific_options[method]->getDescription() << std::endl;
57 } else {
58 std::cerr << "Provided method is unknown" << std::endl;
59 }
60 } else {
61 std::cout << general_description << std::endl;
62 std::cout << meta_description << std::endl;
63 }
64 return 1;
65 }
66 po::notify(vm);
67
68 if (vm.count("method")) {
69 std::string method = vm["method"].as<std::string>();
70 if (specific_options.find(method) != specific_options.end()) {
71 cmdline_description.add(specific_options[method]->getDescription());
72 po::parsed_options specific_parsed = po::command_line_parser(argc, argv).options(cmdline_description).run();
73 po::store(specific_parsed, vm);
74 po::notify(vm);
75 } else {
76 std::cerr << "Provided method is unknown" << std::endl;
77 return 1;
78 }
79 } else {
80 std::cerr << "You must provide a method" << std::endl;
81 return 1;
82 }
83 } catch (po::error& err) {
84 std::cerr << "Error: " << err.what() << "\n";
85 return 1;
86 }
87
88 //Check if method is available
89 if (specific_options.find(general_options.m_method) == specific_options.end()) {
90 std::cerr << "Unknown method " << general_options.m_method << std::endl;
91 }
92
93 Belle2::MVA::Utility::teacher(general_options, *specific_options[general_options.m_method], meta_options);
94
95 return 0;
96
97}
98
static void initSupportedInterfaces()
Static function which initializes all supported interfaces; it has to be called once before getSupportedInterfaces().
Definition: Interface.cc:45
static std::map< std::string, AbstractInterface * > getSupportedInterfaces()
Returns interfaces supported by the MVA Interface.
Definition: Interface.h:53
General options which are shared by all MVA trainings.
Definition: Options.h:62
Meta Options which modify the underlying training by doing sPlot, Multiclass and HyperparameterSearch...
Definition: Options.h:111
static void teacher(const GeneralOptions &general_options, const SpecificOptions &specific_options, const MetaOptions &meta_options=MetaOptions())
Convenience function which performs a training with the given options.
Definition: Utility.cc:253