Belle II Software light-2406-ragdoll
FANNOptions.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <mva/methods/FANN.h>
10
11#include <framework/logging/Logger.h>
12#include <TFormula.h>
13
14namespace Belle2 {
19 namespace MVA {
20
21 void FANNOptions::load(const boost::property_tree::ptree& pt)
22 {
23
24 int version = pt.get<int>("FANN_version");
25 if (version != 1) {
26 B2ERROR("Unknown weightfile version " << std::to_string(version));
27 throw std::runtime_error("Unknown weightfile version " + std::to_string(version));
28 }
29 m_max_epochs = pt.get<unsigned int>("FANN_max_epochs");
30 m_verbose_mode = pt.get<bool>("FANN_verbose_mode");
31
32 m_hidden_layers_architecture = pt.get<std::string>("FANN_hidden_layers_architecture");
33
34 m_hidden_activiation_function = pt.get<std::string>("FANN_hidden_activation_function");
35 m_output_activiation_function = pt.get<std::string>("FANN_output_activation_function");
36 m_error_function = pt.get<std::string>("FANN_error_function");
37 m_training_method = pt.get<std::string>("FANN_training_method");
38 m_validation_fraction = pt.get<double>("FANN_validation_fraction");
39 m_random_seeds = pt.get<unsigned int>("FANN_random_seeds");
40 m_test_rate = pt.get<unsigned int>("FANN_test_rate");
41 m_number_of_threads = pt.get<unsigned int>("FANN_number_of_threads");
42
43 m_scale_features = pt.get<bool>("FANN_scale_features");
44 m_scale_target = pt.get<bool>("FANN_scale_target");
45
46 }
47
48 void FANNOptions::save(boost::property_tree::ptree& pt) const
49 {
50 pt.put("FANN_version", 1);
51 pt.put("FANN_max_epochs", m_max_epochs);
52 pt.put("FANN_verbose_mode", m_verbose_mode);
53 pt.put("FANN_hidden_layers_architecture", m_hidden_layers_architecture);
54 pt.put("FANN_hidden_activation_function", m_hidden_activiation_function);
55 pt.put("FANN_output_activation_function", m_output_activiation_function);
56 pt.put("FANN_error_function", m_error_function);
57 pt.put("FANN_training_method", m_training_method);
58 pt.put("FANN_validation_fraction", m_validation_fraction);
59 pt.put("FANN_random_seeds", m_random_seeds);
60 pt.put("FANN_test_rate", m_test_rate);
61 pt.put("FANN_number_of_threads", m_number_of_threads);
62
63 pt.put("FANN_scale_features", m_scale_features);
64 pt.put("FANN_scale_target", m_scale_target);
65
66
67 }
68
69 po::options_description FANNOptions::getDescription()
70 {
71 po::options_description description("FANN options");
72 description.add_options()
73 ("max_epochs", po::value<unsigned int>(&m_max_epochs), "Number of iEpochs")
74 ("verbose_mode", po::value<bool>(&m_verbose_mode), "Prints out the training status or not")
75 ("hidden_layers_architecture", po::value<std::string>(&m_hidden_layers_architecture),
76 "Architecture with number of neurons in each hidden layer")
77 ("hidden_activiation_function", po::value<std::string>(&m_hidden_activiation_function),
78 "Name of acitvation function used for hidden layers")
79 ("output_activiation_function", po::value<std::string>(&m_output_activiation_function),
80 "Name of acitvation function used for output layer")
81 ("error_function", po::value<std::string>(&m_error_function), "Name of error function")
82 ("training_method", po::value<std::string>(&m_training_method), "Method used for backpropagation")
83 ("validation_fraction", po::value<double>(&m_validation_fraction), "Fraction of training sample used for validation.")
84 ("random_seeds", po::value<unsigned int>(&m_random_seeds),
85 "Number of times the training is repeated with a new weight random seed.")
86 ("test_rate", po::value<unsigned int>(&m_test_rate), "Rate of iEpochs to check the validation error")
87 ("number_of_threads", po::value<unsigned int>(&m_number_of_threads), "Number of threads for parallel training")
88 ("scale_features", po::value<bool>(&m_scale_features), "Boolean indicating if features should be scaled or not")
89 ("scale_target", po::value<bool>(&m_scale_target), "Boolean indicating if target should be scaled or not");
90 return description;
91 }
92
93 std::vector<unsigned int> FANNOptions::getHiddenLayerNeurons(unsigned int nf) const
94 {
95 std::vector<unsigned int> hiddenLayers;
96 std::stringstream iLayers(m_hidden_layers_architecture);
97 std::string layer;
98 while (std::getline(iLayers, layer, ',')) {
99 for (auto& character : layer) {
100 if (character == 'N') character = 'x';
101 }
102 auto* iLayerSize = new TFormula("iLayerSize", layer.c_str());
103 hiddenLayers.push_back(iLayerSize->Eval(nf));
104 }
105 return hiddenLayers;
106 }
107 }
109}
double m_validation_fraction
Fraction of training sample used for validation in order to avoid overtraining.
Definition: FANN.h:69
std::string m_hidden_layers_architecture
String containing the architecture of hidden neurons.
Definition: FANN.h:62
bool m_scale_features
Scale features before training.
Definition: FANN.h:77
bool m_verbose_mode
Sets to report training status or not.
Definition: FANN.h:61
unsigned int m_random_seeds
Number of times the training is repeated with a new weight random seed.
Definition: FANN.h:70
std::string m_error_function
Loss function.
Definition: FANN.h:66
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FANNOptions.cc:69
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:74
unsigned int m_test_rate
Error on validation is compared with the one before.
Definition: FANN.h:72
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:64
bool m_scale_target
Scale target before training.
Definition: FANN.h:78
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:93
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from a xml tree.
Definition: FANNOptions.cc:21
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:67
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:65
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:60
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in a xml tree.
Definition: FANNOptions.cc:48
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24