Belle II Software  release-08-01-10
FANN.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 #ifndef INCLUDE_GUARD_BELLE2_MVA_FANN_HEADER
11 #define INCLUDE_GUARD_BELLE2_MVA_FANN_HEADER
12 
13 #include <mva/interface/Options.h>
14 #include <mva/interface/Teacher.h>
15 #include <mva/interface/Expert.h>
16 
17 #include <fann.h>
18 
19 namespace Belle2 {
24  namespace MVA {
25 
29  class FANNOptions : public SpecificOptions {
30 
31  public:
36  virtual void load(const boost::property_tree::ptree& pt) override;
37 
42  virtual void save(boost::property_tree::ptree& pt) const override;
43 
47  virtual po::options_description getDescription() override;
48 
52  virtual std::string getMethod() const override { return "FANN"; }
53 
58  std::vector<unsigned int> getHiddenLayerNeurons(unsigned int nf) const;
59 
60  unsigned int m_max_epochs = 10000;
61  bool m_verbose_mode = true;
62  std::string m_hidden_layers_architecture = "3*N";
64  std::string m_hidden_activiation_function = "FANN_SIGMOID_SYMMETRIC";
65  std::string m_output_activiation_function = "FANN_SIGMOID_SYMMETRIC";
66  std::string m_error_function = "FANN_ERRORFUNC_LINEAR";
67  std::string m_training_method = "FANN_TRAIN_RPROP";
69  double m_validation_fraction = 0.5;
70  unsigned int m_random_seeds =
71  3;
72  unsigned int m_test_rate =
73  500;
74  unsigned int m_number_of_threads = 8;
77  bool m_scale_features = true;
78  bool m_scale_target = true;
80  };
81 
85  class FANNTeacher : public Teacher {
86 
87  public:
93  FANNTeacher(const GeneralOptions& general_options, const FANNOptions& specific_options);
94 
99  virtual Weightfile train(Dataset& training_data) const override;
100 
101  private:
103  };
104 
105 
109  class FANNExpert : public MVA::Expert {
110 
111  public:
112 
116  virtual ~FANNExpert();
117 
122  virtual void load(Weightfile& weightfile) override;
123 
128  virtual std::vector<float> apply(Dataset& test_data) const override;
129 
130  private:
132  struct fann* m_ann = nullptr;
133  };
134 
135  }
137 }
138 #endif
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
Abstract base class of all Experts. Each MVA library has its own implementation of this class,...
Definition: Expert.h:31
Expert for the FANN MVA method.
Definition: FANN.h:109
virtual ~FANNExpert()
Destructor of FANN Expert.
Definition: FANN.cc:312
struct fann * m_ann
Pointer to FANN expert.
Definition: FANN.h:132
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FANN.cc:333
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FANN.cc:319
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:131
Options for the FANN MVA method.
Definition: FANN.h:29
double m_validation_fraction
Fraction of the training sample used for validation in order to avoid overtraining.
Definition: FANN.h:69
virtual std::string getMethod() const override
Return method name.
Definition: FANN.h:52
std::string m_hidden_layers_architecture
String containing the architecture of hidden neurons.
Definition: FANN.h:62
bool m_scale_features
Scale features before training.
Definition: FANN.h:77
bool m_verbose_mode
Whether to report the training status.
Definition: FANN.h:61
unsigned int m_random_seeds
Number of times the training is repeated with a new random seed for the weights.
Definition: FANN.h:70
std::string m_error_function
Loss function.
Definition: FANN.h:66
virtual po::options_description getDescription() override
Returns a program options description for all available options.
Definition: FANNOptions.cc:69
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:74
unsigned int m_test_rate
Rate at which the error on the validation sample is compared with the previous one.
Definition: FANN.h:72
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:64
bool m_scale_target
Scale target before training.
Definition: FANN.h:78
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:93
virtual void load(const boost::property_tree::ptree &pt) override
Load mechanism to load Options from an XML tree.
Definition: FANNOptions.cc:21
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:67
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:65
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:60
virtual void save(boost::property_tree::ptree &pt) const override
Save mechanism to store Options in an XML tree.
Definition: FANNOptions.cc:48
Teacher for the FANN MVA method.
Definition: FANN.h:85
FANNTeacher(const GeneralOptions &general_options, const FANNOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FANN.cc:26
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:102
virtual Weightfile train(Dataset &training_data) const override
Train an MVA method using the given dataset, returning a Weightfile.
Definition: FANN.cc:30
General options which are shared by all MVA trainings.
Definition: Options.h:62
Specific Options; all method Options have to inherit from this class.
Definition: Options.h:98
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
The Weightfile class serializes all information about a training into an XML tree.
Definition: Weightfile.h:38
Abstract base class for different kinds of events.