Belle II Software  release-08-02-04
FANN.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <mva/methods/FANN.h>
10 
11 #include <framework/logging/Logger.h>
12 
13 #ifdef HAS_OPENMP
14 #include <parallel_fann.hpp>
15 #else
16 #include <fann.h>
17 #endif
18 
19 namespace Belle2 {
24  namespace MVA {
25 
26  FANNTeacher::FANNTeacher(const GeneralOptions& general_options, const FANNOptions& specific_options) : Teacher(general_options),
27  m_specific_options(specific_options) { }
28 
29 
30  Weightfile FANNTeacher::train(Dataset& training_data) const
31  {
32 
33  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
34  unsigned int numberOfEvents = training_data.getNumberOfEvents();
35 
36  std::vector<unsigned int> hiddenLayers = m_specific_options.getHiddenLayerNeurons(numberOfFeatures);
37  unsigned int number_of_layers = 2;
38  for (unsigned int hiddenLayer : hiddenLayers) {
39  if (hiddenLayer > 0) {
40  number_of_layers++;
41  }
42  }
43 
44  auto layers = std::unique_ptr<unsigned int[]>(new unsigned int[number_of_layers]);
45  layers[0] = numberOfFeatures;
46  for (unsigned int i = 0; i < hiddenLayers.size(); ++i) {
47  if (hiddenLayers[i] > 0) {
48  layers[i + 1] = hiddenLayers[i];
49  }
50  }
51  layers[number_of_layers - 1] = 1;
52 
53  struct fann* ann = fann_create_standard_array(number_of_layers, layers.get());
54 
55  std::map<std::string, enum fann_activationfunc_enum> activationFunctions;
56  unsigned int i = 0;
57  for (auto& name : FANN_ACTIVATIONFUNC_NAMES) {
58  activationFunctions[name] = fann_activationfunc_enum(i);
59  i++;
60  }
61 
62 #ifdef HAS_OPENMP
63  typedef float (*FnPtr)(struct fann * ann, struct fann_train_data * data, const unsigned int threadnumb);
64  std::map<std::string, FnPtr> trainingMethods;
65  trainingMethods["FANN_TRAIN_RPROP"] = parallel_fann::train_epoch_irpropm_parallel;
66  trainingMethods["FANN_TRAIN_BATCH"] = parallel_fann::train_epoch_batch_parallel;
67  trainingMethods["FANN_TRAIN_QUICKPROP"] = parallel_fann::train_epoch_quickprop_parallel;
68  trainingMethods["FANN_TRAIN_SARPROP"] = parallel_fann::train_epoch_sarprop_parallel;
69  trainingMethods["FANN_TRAIN_INCREMENTAL"] = nullptr;
70 #else
71  std::map<std::string, enum fann_train_enum> trainingMethods;
72  i = 0;
73  for (auto& name : FANN_TRAIN_NAMES) {
74  trainingMethods[name] = fann_train_enum(i);
75  i++;
76  }
77 #endif
78 
79  std::map<std::string, enum fann_errorfunc_enum> errorFunctions;
80  i = 0;
81  for (auto& name : FANN_ERRORFUNC_NAMES) {
82  errorFunctions[name] = fann_errorfunc_enum(i);
83  i++;
84  }
85 
86  if (activationFunctions.find(m_specific_options.m_hidden_activiation_function) == activationFunctions.end()) {
87  B2ERROR("Coulnd't find activation function named " << m_specific_options.m_hidden_activiation_function);
88  throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_hidden_activiation_function);
89  }
90 
91  if (activationFunctions.find(m_specific_options.m_output_activiation_function) == activationFunctions.end()) {
92  B2ERROR("Coulnd't find activation function named " << m_specific_options.m_output_activiation_function);
93  throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_output_activiation_function);
94  }
95 
96  if (errorFunctions.find(m_specific_options.m_error_function) == errorFunctions.end()) {
97  B2ERROR("Coulnd't find training method function named " << m_specific_options.m_error_function);
98  throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_error_function);
99  }
100 
101  if (trainingMethods.find(m_specific_options.m_training_method) == trainingMethods.end()) {
102  B2ERROR("Coulnd't find training method function named " << m_specific_options.m_training_method);
103  throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_training_method);
104  }
105 
107  B2ERROR("m_max_epochs should be larger than 0 " << m_specific_options.m_max_epochs);
108  throw std::runtime_error("m_max_epochs should be larger than 0. The given value is " + std::to_string(
110  }
111 
113  B2ERROR("m_random_seeds should be larger than 0 " << m_specific_options.m_random_seeds);
114  throw std::runtime_error("m_random_seeds should be larger than 0. The given value is " + std::to_string(
116  }
117 
119  B2ERROR("m_test_rate should be larger than 0 " << m_specific_options.m_test_rate);
120  throw std::runtime_error("m_test_rate should be larger than 0. The given value is " + std::to_string(
122  }
123 
125  B2ERROR("m_number_of_threads should be larger than 0. The given value is " << m_specific_options.m_number_of_threads);
126  throw std::runtime_error("m_number_of_threads should be larger than 0. The given value is " +
127  std::to_string(m_specific_options.m_number_of_threads));
128  }
129 
130  // set network parameters
131  fann_set_activation_function_hidden(ann, activationFunctions[m_specific_options.m_hidden_activiation_function]);
132  fann_set_activation_function_output(ann, activationFunctions[m_specific_options.m_output_activiation_function]);
133  fann_set_train_error_function(ann, errorFunctions[m_specific_options.m_error_function]);
134 
135 
136  double nTestingAndValidationEvents = numberOfEvents * m_specific_options.m_validation_fraction;
137  unsigned int nTestingEvents = int(nTestingAndValidationEvents * 0.5); // Number of events in the test sample.
138  unsigned int nValidationEvents = int(nTestingAndValidationEvents * 0.5);
139  unsigned int nTrainingEvents = numberOfEvents - nValidationEvents - nTestingEvents;
140 
141  if (nTestingAndValidationEvents < 1) {
142  B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
144  ". The total number of events is " << numberOfEvents << ". numberOfEvents * m_validation_fraction has to be larger than one");
145  throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * m_validation_fraction has to be larger than one");
146  }
147 
148  if (nTrainingEvents < 1) {
149  B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
151  ". The total number of events is " << numberOfEvents << ". numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
152  throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
153  }
154 
155  // training set
156  struct fann_train_data* train_data =
157  fann_create_train(nTrainingEvents, numberOfFeatures, 1);
158  for (unsigned iEvent = 0; iEvent < nTrainingEvents; ++iEvent) {
159  training_data.loadEvent(iEvent);
160  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
161  train_data->input[iEvent][iFeature] = training_data.m_input[iFeature];
162  }
163  train_data->output[iEvent][0] = training_data.m_target;
164  }
165  // validation set
166  struct fann_train_data* valid_data =
167  fann_create_train(nValidationEvents, numberOfFeatures, 1);
168  for (unsigned iEvent = nTrainingEvents; iEvent < nTrainingEvents + nValidationEvents; ++iEvent) {
169  training_data.loadEvent(iEvent);
170  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
171  valid_data->input[iEvent - nTrainingEvents][iFeature] = training_data.m_input[iFeature];
172  }
173  valid_data->output[iEvent - nTrainingEvents][0] = training_data.m_target;
174  }
175 
176  // testing set
177  struct fann_train_data* test_data =
178  fann_create_train(nTestingEvents, numberOfFeatures, 1);
179  for (unsigned iEvent = nTrainingEvents + nValidationEvents; iEvent < numberOfEvents; ++iEvent) {
180  training_data.loadEvent(iEvent);
181  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
182  test_data->input[iEvent - nTrainingEvents - nValidationEvents][iFeature] = training_data.m_input[iFeature];
183  }
184  test_data->output[iEvent - nTrainingEvents - nValidationEvents][0] = training_data.m_target;
185  }
186 
187  struct fann_train_data* data = fann_create_train(numberOfEvents, numberOfFeatures, 1);
188  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
189  training_data.loadEvent(iEvent);
190  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
191  data->input[iEvent][iFeature] = training_data.m_input[iFeature];
192  }
193  data->output[iEvent][0] = training_data.m_target;
194  }
195 
197  fann_set_input_scaling_params(ann, data, -1.0, 1.0);
198  }
199 
201  fann_set_output_scaling_params(ann, data, -1.0, 1.0);
202  }
203 
205  fann_scale_train(ann, data);
206  fann_scale_train(ann, train_data);
207  fann_scale_train(ann, valid_data);
208  fann_scale_train(ann, test_data);
209  }
210 
211  struct fann* bestANN = nullptr;
212  double bestRMS = 999.;
213  std::vector<double> bestTrainLog = {};
214  std::vector<double> bestValidLog = {};
215 
216  // repeat training several times with different random start weights
217  for (unsigned int iRun = 0; iRun < m_specific_options.m_random_seeds; ++iRun) {
218  double bestValid = 999.;
219  std::vector<double> trainLog = {};
220  std::vector<double> validLog = {};
221  trainLog.assign(m_specific_options.m_max_epochs, 0.);
222  validLog.assign(m_specific_options.m_max_epochs, 0.);
223  int breakEpoch = 0;
224  struct fann* iRunANN = nullptr;
225  fann_randomize_weights(ann, -0.1, 0.1);
226  for (unsigned int iEpoch = 1; iEpoch <= m_specific_options.m_max_epochs; ++iEpoch) {
227  double mse;
228 #ifdef HAS_OPENMP
229  if (m_specific_options.m_training_method != "FANN_TRAIN_INCREMENTAL") {
230  mse = trainingMethods[m_specific_options.m_training_method](ann, train_data, m_specific_options.m_number_of_threads);
231  } else {mse = parallel_fann::train_epoch_incremental_mod(ann, train_data);}
232 #else
233  fann_set_training_algorithm(ann, trainingMethods[m_specific_options.m_training_method]);
234  mse = fann_train_epoch(ann, train_data);
235 #endif
236  trainLog[iEpoch - 1] = mse;
237  // evaluate validation set
238  fann_reset_MSE(ann);
239 
240 #ifdef HAS_OPENMP
241  double valid_mse = parallel_fann::test_data_parallel(ann, valid_data, m_specific_options.m_number_of_threads);
242 #else
243  double valid_mse = fann_test_data(ann, valid_data);
244 #endif
245 
246  validLog[iEpoch - 1] = valid_mse;
247  // keep weights for lowest validation error
248  if (valid_mse < bestValid) {
249  bestValid = valid_mse;
250  iRunANN = fann_copy(ann);
251  }
252  // break when validation error increases
253  if (iEpoch > m_specific_options.m_test_rate && valid_mse > validLog[iEpoch - m_specific_options.m_test_rate]) {
255  B2INFO("Training stopped in iEpoch " << iEpoch);
256  B2INFO("Train error: " << mse << ", valid error: " << valid_mse <<
257  ", best valid: " << bestValid);
258  }
259  breakEpoch = iEpoch;
260  break;
261  }
262  // print current status
263  if (iEpoch == 1 || (iEpoch < 100 && iEpoch % 10 == 0) || iEpoch % 100 == 0) {
264  if (m_specific_options.m_verbose_mode) B2INFO("Epoch " << iEpoch << ": Train error = " << mse <<
265  ", valid error = " << valid_mse << ", best valid = " << bestValid);
266  }
267  }
268 
269  // test trained network
270 
271 #ifdef HAS_OPENMP
272  double test_mse = parallel_fann::test_data_parallel(iRunANN, test_data, m_specific_options.m_number_of_threads);
273 #else
274  double test_mse = fann_test_data(iRunANN, test_data);
275 #endif
276 
277  double RMS = sqrt(test_mse);
278 
279  if (RMS < bestRMS) {
280  bestRMS = RMS;
281  bestANN = fann_copy(iRunANN);
282  fann_destroy(iRunANN);
283  bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
284  bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
285  }
286  if (m_specific_options.m_verbose_mode) B2INFO("RMS on test samples: " << RMS << " (best: " << bestRMS << ")");
287  }
288 
289  fann_destroy_train(data);
290  fann_destroy_train(train_data);
291  fann_destroy_train(valid_data);
292  fann_destroy_train(test_data);
293  fann_destroy(ann);
294 
295  Weightfile weightfile;
296  std::string custom_weightfile = weightfile.generateFileName();
297 
298  fann_save(bestANN, custom_weightfile.c_str());
299  fann_destroy(bestANN);
300 
301  weightfile.addOptions(m_general_options);
302  weightfile.addOptions(m_specific_options);
303  weightfile.addFile("FANN_Weightfile", custom_weightfile);
304  weightfile.addVector("FANN_bestTrainLog", bestTrainLog);
305  weightfile.addVector("FANN_bestValidLog", bestValidLog);
306  weightfile.addSignalFraction(training_data.getSignalFraction());
307 
308  return weightfile;
309 
310  }
311 
313  {
314  if (m_ann) {
315  fann_destroy(m_ann);
316  }
317  }
318 
319  void FANNExpert::load(Weightfile& weightfile)
320  {
321 
322  std::string custom_weightfile = weightfile.generateFileName();
323  weightfile.getFile("FANN_Weightfile", custom_weightfile);
324 
325  if (m_ann) {
326  fann_destroy(m_ann);
327  }
328  m_ann = fann_create_from_file(custom_weightfile.c_str());
329 
330  weightfile.getOptions(m_specific_options);
331  }
332 
333  std::vector<float> FANNExpert::apply(Dataset& test_data) const
334  {
335 
336  std::vector<fann_type> input(test_data.getNumberOfFeatures());
337  std::vector<float> probabilities(test_data.getNumberOfEvents());
338  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
339  test_data.loadEvent(iEvent);
340  for (unsigned int iFeature = 0; iFeature < test_data.getNumberOfFeatures(); ++iFeature) {
341  input[iFeature] = test_data.m_input[iFeature];
342  }
343  if (m_specific_options.m_scale_features) fann_scale_input(m_ann, input.data());
344  probabilities[iEvent] = fann_run(m_ann, input.data())[0];
345  }
346  if (m_specific_options.m_scale_target) fann_descale_output(m_ann, probabilities.data());
347  return probabilities;
348  }
349 
350  }
352 }
Abstract base class of all Datasets given to the MVA interface The current event can always be access...
Definition: Dataset.h:33
virtual ~FANNExpert()
Destructor of FANN Expert.
Definition: FANN.cc:312
struct fann * m_ann
Pointer to FANN expert.
Definition: FANN.h:132
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FANN.cc:333
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FANN.cc:319
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:131
Options for the FANN MVA method.
Definition: FANN.h:29
double m_validation_fraction
Fraction of training sample used for validation in order to avoid overtraining.
Definition: FANN.h:69
bool m_scale_features
Scale features before training.
Definition: FANN.h:77
bool m_verbose_mode
Sets to report training status or not.
Definition: FANN.h:61
unsigned int m_random_seeds
Number of times the training is repeated with a new weight random seed.
Definition: FANN.h:70
std::string m_error_function
Loss function.
Definition: FANN.h:66
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:74
unsigned int m_test_rate
Error on validation is compared with the one before.
Definition: FANN.h:72
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:64
bool m_scale_target
Scale target before training.
Definition: FANN.h:78
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:93
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:67
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:65
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:60
FANNTeacher(const GeneralOptions &general_options, const FANNOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FANN.cc:26
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:102
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: FANN.cc:30
General options which are shared by all MVA trainings.
Definition: Options.h:62
Abstract base class of all Teachers Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
void addVector(const std::string &identifier, const std::vector< T > &vector)
Add a vector to the xml tree.
Definition: Weightfile.h:125
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
double sqrt(double a)
sqrt for double
Definition: beamHelpers.h:28
Abstract base class for different kinds of events.