Belle II Software light-2406-ragdoll
FANN.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10#ifndef INCLUDE_GUARD_BELLE2_MVA_FANN_HEADER
11#define INCLUDE_GUARD_BELLE2_MVA_FANN_HEADER
12
13#include <mva/interface/Options.h>
14#include <mva/interface/Teacher.h>
15#include <mva/interface/Expert.h>
16
17#include <fann.h>
18
19namespace Belle2 {
24 namespace MVA {
25
30
31 public:
36 virtual void load(const boost::property_tree::ptree& pt) override;
37
42 virtual void save(boost::property_tree::ptree& pt) const override;
43
47 virtual po::options_description getDescription() override;
48
52 virtual std::string getMethod() const override { return "FANN"; }
53
58 std::vector<unsigned int> getHiddenLayerNeurons(unsigned int nf) const;
59
60 unsigned int m_max_epochs = 10000;
61 bool m_verbose_mode = true;
62 std::string m_hidden_layers_architecture = "3*N";
64 std::string m_hidden_activiation_function = "FANN_SIGMOID_SYMMETRIC";
65 std::string m_output_activiation_function = "FANN_SIGMOID_SYMMETRIC";
66 std::string m_error_function = "FANN_ERRORFUNC_LINEAR";
67 std::string m_training_method = "FANN_TRAIN_RPROP";
69 double m_validation_fraction = 0.5;
70 unsigned int m_random_seeds =
71 3;
72 unsigned int m_test_rate =
73 500;
74 unsigned int m_number_of_threads = 8;
77 bool m_scale_features = true;
78 bool m_scale_target = true;
80 };
81
85 class FANNTeacher : public Teacher {
86
87 public:
93 FANNTeacher(const GeneralOptions& general_options, const FANNOptions& specific_options);
94
99 virtual Weightfile train(Dataset& training_data) const override;
100
101 private:
103 };
104
105
109 class FANNExpert : public MVA::Expert {
110
111 public:
112
116 virtual ~FANNExpert();
117
122 virtual void load(Weightfile& weightfile) override;
123
128 virtual std::vector<float> apply(Dataset& test_data) const override;
129
130 private:
132 struct fann* m_ann = nullptr;
133 };
134
135 }
137}
138#endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
Expert for the FANN MVA method.
Definition: FANN.h:109
virtual ~FANNExpert()
Destructor of FANN Expert.
Definition: FANN.cc:312
struct fann * m_ann
Pointer to FANN expert.
Definition: FANN.h:132
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FANN.cc:333
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FANN.cc:319
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:131
Options for the FANN MVA method.
Definition: FANN.h:29
double m_validation_fraction
Fraction of training sample used for validation in order to avoid overtraining.
Definition: FANN.h:69
virtual std::string getMethod() const override
Return method name.
Definition: FANN.h:52
std::string m_hidden_layers_architecture
String containing the architecture of hidden neurons.
Definition: FANN.h:62
bool m_scale_features
Scale features before training.
Definition: FANN.h:77
bool m_verbose_mode
Whether to report the training status.
Definition: FANN.h:61
unsigned int m_random_seeds
Number of times the training is repeated with a new weight random seed.
Definition: FANN.h:70
std::string m_error_function
Loss function.
Definition: FANN.h:66
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FANNOptions.cc:69
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:74
unsigned int m_test_rate
Error on validation is compared with the one before.
Definition: FANN.h:72
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:64
bool m_scale_target
Scale target before training.
Definition: FANN.h:78
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:93
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: FANNOptions.cc:21
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:67
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:65
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:60
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: FANNOptions.cc:48
Teacher for the FANN MVA method.
Definition: FANN.h:85
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:102
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: FANN.cc:30
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific Options, all method Options have to inherit from this class.
Definition: Options.h:98
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.
Definition: ClusterUtils.h:24