// Belle II Software  release-05-01-25
// mva/methods/FANN.cc
/**************************************************************************
 * BASF2 (Belle Analysis Framework 2)                                     *
 * Copyright(C) 2016 - Belle II Collaboration                             *
 *                                                                        *
 * Author: The Belle II Collaboration                                     *
 * Contributors: Thomas Keck and Fernando Abudinen                        *
 *                                                                        *
 * This software is provided "as is" without any warranty.                *
 **************************************************************************/
10 
#include <mva/methods/FANN.h>

#include <framework/logging/Logger.h>

#ifdef HAS_OPENMP
#include <parallel_fann.hpp>
#else
#include <fann.h>
#endif

#include <cmath>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
20 
21 namespace Belle2 {
26  namespace MVA {
27 
28  FANNTeacher::FANNTeacher(const GeneralOptions& general_options, const FANNOptions& specific_options) : Teacher(general_options),
29  m_specific_options(specific_options) { }
30 
31 
32  Weightfile FANNTeacher::train(Dataset& training_data) const
33  {
34 
35  unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
36  unsigned int numberOfEvents = training_data.getNumberOfEvents();
37 
38  std::vector<unsigned int> hiddenLayers = m_specific_options.getHiddenLayerNeurons(numberOfFeatures);
39  unsigned int number_of_layers = 2;
40  for (unsigned int hiddenLayer : hiddenLayers) {
41  if (hiddenLayer > 0) {
42  number_of_layers++;
43  }
44  }
45 
46  auto layers = std::unique_ptr<unsigned int[]>(new unsigned int[number_of_layers]);
47  layers[0] = numberOfFeatures;
48  for (unsigned int i = 0; i < hiddenLayers.size(); ++i) {
49  if (hiddenLayers[i] > 0) {
50  layers[i + 1] = hiddenLayers[i];
51  }
52  }
53  layers[number_of_layers - 1] = 1;
54 
55  struct fann* ann = fann_create_standard_array(number_of_layers, layers.get());
56 
57  std::map<std::string, enum fann_activationfunc_enum> activationFunctions;
58  unsigned int i = 0;
59  for (auto& name : FANN_ACTIVATIONFUNC_NAMES) {
60  activationFunctions[name] = fann_activationfunc_enum(i);
61  i++;
62  }
63 
64 #ifdef HAS_OPENMP
65  typedef float (*FnPtr)(struct fann * ann, struct fann_train_data * data, const unsigned int threadnumb);
66  std::map<std::string, FnPtr> trainingMethods;
67  trainingMethods["FANN_TRAIN_RPROP"] = parallel_fann::train_epoch_irpropm_parallel;
68  trainingMethods["FANN_TRAIN_BATCH"] = parallel_fann::train_epoch_batch_parallel;
69  trainingMethods["FANN_TRAIN_QUICKPROP"] = parallel_fann::train_epoch_quickprop_parallel;
70  trainingMethods["FANN_TRAIN_SARPROP"] = parallel_fann::train_epoch_sarprop_parallel;
71  trainingMethods["FANN_TRAIN_INCREMENTAL"] = nullptr;
72 #else
73  std::map<std::string, enum fann_train_enum> trainingMethods;
74  i = 0;
75  for (auto& name : FANN_TRAIN_NAMES) {
76  trainingMethods[name] = fann_train_enum(i);
77  i++;
78  }
79 #endif
80 
81  std::map<std::string, enum fann_errorfunc_enum> errorFunctions;
82  i = 0;
83  for (auto& name : FANN_ERRORFUNC_NAMES) {
84  errorFunctions[name] = fann_errorfunc_enum(i);
85  i++;
86  }
87 
88  if (activationFunctions.find(m_specific_options.m_hidden_activiation_function) == activationFunctions.end()) {
89  B2ERROR("Coulnd't find activation function named " << m_specific_options.m_hidden_activiation_function);
90  throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_hidden_activiation_function);
91  }
92 
93  if (activationFunctions.find(m_specific_options.m_output_activiation_function) == activationFunctions.end()) {
94  B2ERROR("Coulnd't find activation function named " << m_specific_options.m_output_activiation_function);
95  throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_output_activiation_function);
96  }
97 
98  if (errorFunctions.find(m_specific_options.m_error_function) == errorFunctions.end()) {
99  B2ERROR("Coulnd't find training method function named " << m_specific_options.m_error_function);
100  throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_error_function);
101  }
102 
103  if (trainingMethods.find(m_specific_options.m_training_method) == trainingMethods.end()) {
104  B2ERROR("Coulnd't find training method function named " << m_specific_options.m_training_method);
105  throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_training_method);
106  }
107 
109  B2ERROR("m_max_epochs should be larger than 0 " << m_specific_options.m_max_epochs);
110  throw std::runtime_error("m_max_epochs should be larger than 0. The given value is " + std::to_string(
112  }
113 
115  B2ERROR("m_random_seeds should be larger than 0 " << m_specific_options.m_random_seeds);
116  throw std::runtime_error("m_random_seeds should be larger than 0. The given value is " + std::to_string(
118  }
119 
121  B2ERROR("m_test_rate should be larger than 0 " << m_specific_options.m_test_rate);
122  throw std::runtime_error("m_test_rate should be larger than 0. The given value is " + std::to_string(
124  }
125 
127  B2ERROR("m_number_of_threads should be larger than 0. The given value is " << m_specific_options.m_number_of_threads);
128  throw std::runtime_error("m_number_of_threads should be larger than 0. The given value is " +
129  std::to_string(m_specific_options.m_number_of_threads));
130  }
131 
132  // set network parameters
133  fann_set_activation_function_hidden(ann, activationFunctions[m_specific_options.m_hidden_activiation_function]);
134  fann_set_activation_function_output(ann, activationFunctions[m_specific_options.m_output_activiation_function]);
135  fann_set_train_error_function(ann, errorFunctions[m_specific_options.m_error_function]);
136 
137 
138  double nTestingAndValidationEvents = numberOfEvents * m_specific_options.m_validation_fraction;
139  unsigned int nTestingEvents = int(nTestingAndValidationEvents * 0.5); // Number of events in the test sample.
140  unsigned int nValidationEvents = int(nTestingAndValidationEvents * 0.5);
141  unsigned int nTrainingEvents = numberOfEvents - nValidationEvents - nTestingEvents;
142 
143  if (nTestingAndValidationEvents < 1) {
144  B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
146  ". The total number of events is " << numberOfEvents << ". numberOfEvents * m_validation_fraction has to be larger than one");
147  throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * m_validation_fraction has to be larger than one");
148  }
149 
150  if (nTrainingEvents < 1) {
151  B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
153  ". The total number of events is " << numberOfEvents << ". numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
154  throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
155  }
156 
157  // training set
158  struct fann_train_data* train_data =
159  fann_create_train(nTrainingEvents, numberOfFeatures, 1);
160  for (unsigned iEvent = 0; iEvent < nTrainingEvents; ++iEvent) {
161  training_data.loadEvent(iEvent);
162  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
163  train_data->input[iEvent][iFeature] = training_data.m_input[iFeature];
164  }
165  train_data->output[iEvent][0] = training_data.m_target;
166  }
167  // validation set
168  struct fann_train_data* valid_data =
169  fann_create_train(nValidationEvents, numberOfFeatures, 1);
170  for (unsigned iEvent = nTrainingEvents; iEvent < nTrainingEvents + nValidationEvents; ++iEvent) {
171  training_data.loadEvent(iEvent);
172  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
173  valid_data->input[iEvent - nTrainingEvents][iFeature] = training_data.m_input[iFeature];
174  }
175  valid_data->output[iEvent - nTrainingEvents][0] = training_data.m_target;
176  }
177 
178  // testing set
179  struct fann_train_data* test_data =
180  fann_create_train(nTestingEvents, numberOfFeatures, 1);
181  for (unsigned iEvent = nTrainingEvents + nValidationEvents; iEvent < numberOfEvents; ++iEvent) {
182  training_data.loadEvent(iEvent);
183  for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
184  test_data->input[iEvent - nTrainingEvents - nValidationEvents][iFeature] = training_data.m_input[iFeature];
185  }
186  test_data->output[iEvent - nTrainingEvents - nValidationEvents][0] = training_data.m_target;
187  }
188 
189  struct fann_train_data* data = fann_create_train(numberOfEvents, numberOfFeatures, 1);
190  for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
191  training_data.loadEvent(iEvent);
192  for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
193  data->input[iEvent][iFeature] = training_data.m_input[iFeature];
194  }
195  data->output[iEvent][0] = training_data.m_target;
196  }
197 
199  fann_set_input_scaling_params(ann, data, -1.0, 1.0);
200  }
201 
203  fann_set_output_scaling_params(ann, data, -1.0, 1.0);
204  }
205 
207  fann_scale_train(ann, data);
208  fann_scale_train(ann, train_data);
209  fann_scale_train(ann, valid_data);
210  fann_scale_train(ann, test_data);
211  }
212 
213  struct fann* bestANN = nullptr;
214  double bestRMS = 999.;
215  std::vector<double> bestTrainLog = {};
216  std::vector<double> bestValidLog = {};
217 
218  // repeat training several times with different random start weights
219  for (unsigned int iRun = 0; iRun < m_specific_options.m_random_seeds; ++iRun) {
220  double bestValid = 999.;
221  std::vector<double> trainLog = {};
222  std::vector<double> validLog = {};
223  trainLog.assign(m_specific_options.m_max_epochs, 0.);
224  validLog.assign(m_specific_options.m_max_epochs, 0.);
225  int breakEpoch = 0;
226  struct fann* iRunANN = nullptr;
227  fann_randomize_weights(ann, -0.1, 0.1);
228  for (unsigned int iEpoch = 1; iEpoch <= m_specific_options.m_max_epochs; ++iEpoch) {
229  double mse;
230 #ifdef HAS_OPENMP
231  if (m_specific_options.m_training_method != "FANN_TRAIN_INCREMENTAL") {
232  mse = trainingMethods[m_specific_options.m_training_method](ann, train_data, m_specific_options.m_number_of_threads);
233  } else {mse = parallel_fann::train_epoch_incremental_mod(ann, train_data);}
234 #else
235  fann_set_training_algorithm(ann, trainingMethods[m_specific_options.m_training_method]);
236  mse = fann_train_epoch(ann, train_data);
237 #endif
238  trainLog[iEpoch - 1] = mse;
239  // evaluate validation set
240  fann_reset_MSE(ann);
241 
242 #ifdef HAS_OPENMP
243  double valid_mse = parallel_fann::test_data_parallel(ann, valid_data, m_specific_options.m_number_of_threads);
244 #else
245  double valid_mse = fann_test_data(ann, valid_data);
246 #endif
247 
248  validLog[iEpoch - 1] = valid_mse;
249  // keep weights for lowest validation error
250  if (valid_mse < bestValid) {
251  bestValid = valid_mse;
252  iRunANN = fann_copy(ann);
253  }
254  // break when validation error increases
255  if (iEpoch > m_specific_options.m_test_rate && valid_mse > validLog[iEpoch - m_specific_options.m_test_rate]) {
257  B2INFO("Training stopped in iEpoch " << iEpoch);
258  B2INFO("Train error: " << mse << ", valid error: " << valid_mse <<
259  ", best valid: " << bestValid);
260  }
261  breakEpoch = iEpoch;
262  break;
263  }
264  // print current status
265  if (iEpoch == 1 || (iEpoch < 100 && iEpoch % 10 == 0) || iEpoch % 100 == 0) {
266  if (m_specific_options.m_verbose_mode) B2INFO("Epoch " << iEpoch << ": Train error = " << mse <<
267  ", valid error = " << valid_mse << ", best valid = " << bestValid);
268  }
269  }
270 
271  // test trained network
272 
273 #ifdef HAS_OPENMP
274  double test_mse = parallel_fann::test_data_parallel(iRunANN, test_data, m_specific_options.m_number_of_threads);
275 #else
276  double test_mse = fann_test_data(iRunANN, test_data);
277 #endif
278 
279  double RMS = sqrt(test_mse);
280 
281  if (RMS < bestRMS) {
282  bestRMS = RMS;
283  bestANN = fann_copy(iRunANN);
284  fann_destroy(iRunANN);
285  bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
286  bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
287  }
288  if (m_specific_options.m_verbose_mode) B2INFO("RMS on test samples: " << RMS << " (best: " << bestRMS << ")");
289  }
290 
291  fann_destroy_train(data);
292  fann_destroy_train(train_data);
293  fann_destroy_train(valid_data);
294  fann_destroy_train(test_data);
295  fann_destroy(ann);
296 
297  Weightfile weightfile;
298  std::string custom_weightfile = weightfile.generateFileName();
299 
300  fann_save(bestANN, custom_weightfile.c_str());
301  fann_destroy(bestANN);
302 
303  weightfile.addOptions(m_general_options);
304  weightfile.addOptions(m_specific_options);
305  weightfile.addFile("FANN_Weightfile", custom_weightfile);
306  weightfile.addVector("FANN_bestTrainLog", bestTrainLog);
307  weightfile.addVector("FANN_bestValidLog", bestValidLog);
308  weightfile.addSignalFraction(training_data.getSignalFraction());
309 
310  return weightfile;
311 
312  }
313 
315  {
316  if (m_ann) {
317  fann_destroy(m_ann);
318  }
319  }
320 
321  void FANNExpert::load(Weightfile& weightfile)
322  {
323 
324  std::string custom_weightfile = weightfile.generateFileName();
325  weightfile.getFile("FANN_Weightfile", custom_weightfile);
326 
327  if (m_ann) {
328  fann_destroy(m_ann);
329  }
330  m_ann = fann_create_from_file(custom_weightfile.c_str());
331 
332  weightfile.getOptions(m_specific_options);
333  }
334 
335  std::vector<float> FANNExpert::apply(Dataset& test_data) const
336  {
337 
338  std::vector<fann_type> input(test_data.getNumberOfFeatures());
339  std::vector<float> probabilities(test_data.getNumberOfEvents());
340  for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
341  test_data.loadEvent(iEvent);
342  for (unsigned int iFeature = 0; iFeature < test_data.getNumberOfFeatures(); ++iFeature) {
343  input[iFeature] = test_data.m_input[iFeature];
344  }
345  if (m_specific_options.m_scale_features) fann_scale_input(m_ann, input.data());
346  probabilities[iEvent] = fann_run(m_ann, input.data())[0];
347  }
348  if (m_specific_options.m_scale_target) fann_descale_output(m_ann, probabilities.data());
349  return probabilities;
350  }
351 
352  }
354 }
Belle2::MVA::FANNOptions::m_random_seeds
unsigned int m_random_seeds
Number of times the training is repeated with a new weight random seed.
Definition: FANN.h:72
Belle2::MVA::FANNTeacher::FANNTeacher
FANNTeacher(const GeneralOptions &general_options, const FANNOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FANN.cc:36
Belle2::MVA::FANNExpert::~FANNExpert
virtual ~FANNExpert()
Destructor of FANN Expert.
Definition: FANN.cc:322
Belle2::MVA::FANNOptions::m_training_method
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:69
Belle2::MVA::FANNExpert::load
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FANN.cc:329
Belle2::MVA::FANNOptions::m_hidden_activiation_function
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:66
Belle2::MVA::FANNOptions::m_validation_fraction
double m_validation_fraction
Fraction of training sample used for validation in order to avoid overtraining.
Definition: FANN.h:71
Belle2::MVA::FANNOptions::m_test_rate
unsigned int m_test_rate
Error on validation is compared with the one before.
Definition: FANN.h:74
Belle2::MVA::FANNOptions::m_verbose_mode
bool m_verbose_mode
Sets to report training status or not.
Definition: FANN.h:63
Belle2::MVA::Teacher::m_general_options
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:51
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::MVA::FANNOptions::m_number_of_threads
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:76
Belle2::MVA::FANNOptions::m_output_activiation_function
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:67
Belle2::MVA::FANNOptions::m_scale_features
bool m_scale_features
Scale features before training.
Definition: FANN.h:79
Belle2::MVA::FANNOptions::m_scale_target
bool m_scale_target
Scale target before training.
Definition: FANN.h:80
Belle2::MVA::FANNExpert::apply
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FANN.cc:343
Belle2::MVA::FANNOptions::m_max_epochs
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:62
Belle2::MVA::FANNTeacher::train
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: FANN.cc:40
Belle2::MVA::FANNOptions::m_error_function
std::string m_error_function
Loss function.
Definition: FANN.h:68
Belle2::MVA::FANNExpert::m_specific_options
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:133
Belle2::MVA::FANNExpert::m_ann
struct fann * m_ann
Pointer to FANN expert.
Definition: FANN.h:134
Belle2::MVA::FANNTeacher::m_specific_options
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:104
Belle2::MVA::FANNOptions::getHiddenLayerNeurons
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:103