Belle II Software development
FANN.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
#include <mva/methods/FANN.h>

#include <framework/logging/Logger.h>

#ifdef HAS_OPENMP
#include <parallel_fann.hpp>
#else
#include <fann.h>
#endif

#include <cmath>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
18
19namespace Belle2 {
24 namespace MVA {
25
26 FANNTeacher::FANNTeacher(const GeneralOptions& general_options, const FANNOptions& specific_options) : Teacher(general_options),
27 m_specific_options(specific_options) { }
28
29
31 {
32
33 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
34 unsigned int numberOfEvents = training_data.getNumberOfEvents();
35
36 std::vector<unsigned int> hiddenLayers = m_specific_options.getHiddenLayerNeurons(numberOfFeatures);
37 unsigned int number_of_layers = 2;
38 for (unsigned int hiddenLayer : hiddenLayers) {
39 if (hiddenLayer > 0) {
40 number_of_layers++;
41 }
42 }
43
44 auto layers = std::unique_ptr<unsigned int[]>(new unsigned int[number_of_layers]);
45 layers[0] = numberOfFeatures;
46 for (unsigned int i = 0; i < hiddenLayers.size(); ++i) {
47 if (hiddenLayers[i] > 0) {
48 layers[i + 1] = hiddenLayers[i];
49 }
50 }
51 layers[number_of_layers - 1] = 1;
52
53 struct fann* ann = fann_create_standard_array(number_of_layers, layers.get());
54
55 std::map<std::string, enum fann_activationfunc_enum> activationFunctions;
56 unsigned int i = 0;
57 for (auto& name : FANN_ACTIVATIONFUNC_NAMES) {
58 activationFunctions[name] = fann_activationfunc_enum(i);
59 i++;
60 }
61
62#ifdef HAS_OPENMP
63 typedef float (*FnPtr)(struct fann * ann, struct fann_train_data * data, const unsigned int threadnumb);
64 std::map<std::string, FnPtr> trainingMethods;
65 trainingMethods["FANN_TRAIN_RPROP"] = parallel_fann::train_epoch_irpropm_parallel;
66 trainingMethods["FANN_TRAIN_BATCH"] = parallel_fann::train_epoch_batch_parallel;
67 trainingMethods["FANN_TRAIN_QUICKPROP"] = parallel_fann::train_epoch_quickprop_parallel;
68 trainingMethods["FANN_TRAIN_SARPROP"] = parallel_fann::train_epoch_sarprop_parallel;
69 trainingMethods["FANN_TRAIN_INCREMENTAL"] = nullptr;
70#else
71 std::map<std::string, enum fann_train_enum> trainingMethods;
72 i = 0;
73 for (auto& name : FANN_TRAIN_NAMES) {
74 trainingMethods[name] = fann_train_enum(i);
75 i++;
76 }
77#endif
78
79 std::map<std::string, enum fann_errorfunc_enum> errorFunctions;
80 i = 0;
81 for (auto& name : FANN_ERRORFUNC_NAMES) {
82 errorFunctions[name] = fann_errorfunc_enum(i);
83 i++;
84 }
85
86 if (activationFunctions.find(m_specific_options.m_hidden_activiation_function) == activationFunctions.end()) {
87 B2ERROR("Coulnd't find activation function named " << m_specific_options.m_hidden_activiation_function);
88 throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_hidden_activiation_function);
89 }
90
91 if (activationFunctions.find(m_specific_options.m_output_activiation_function) == activationFunctions.end()) {
92 B2ERROR("Coulnd't find activation function named " << m_specific_options.m_output_activiation_function);
93 throw std::runtime_error("Coulnd't find activation function named " + m_specific_options.m_output_activiation_function);
94 }
95
96 if (errorFunctions.find(m_specific_options.m_error_function) == errorFunctions.end()) {
97 B2ERROR("Coulnd't find training method function named " << m_specific_options.m_error_function);
98 throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_error_function);
99 }
100
101 if (trainingMethods.find(m_specific_options.m_training_method) == trainingMethods.end()) {
102 B2ERROR("Coulnd't find training method function named " << m_specific_options.m_training_method);
103 throw std::runtime_error("Coulnd't find training method function named " + m_specific_options.m_training_method);
104 }
105
107 B2ERROR("m_max_epochs should be larger than 0 " << m_specific_options.m_max_epochs);
108 throw std::runtime_error("m_max_epochs should be larger than 0. The given value is " + std::to_string(
110 }
111
113 B2ERROR("m_random_seeds should be larger than 0 " << m_specific_options.m_random_seeds);
114 throw std::runtime_error("m_random_seeds should be larger than 0. The given value is " + std::to_string(
116 }
117
119 B2ERROR("m_test_rate should be larger than 0 " << m_specific_options.m_test_rate);
120 throw std::runtime_error("m_test_rate should be larger than 0. The given value is " + std::to_string(
122 }
123
125 B2ERROR("m_number_of_threads should be larger than 0. The given value is " << m_specific_options.m_number_of_threads);
126 throw std::runtime_error("m_number_of_threads should be larger than 0. The given value is " +
128 }
129
130 // set network parameters
131 fann_set_activation_function_hidden(ann, activationFunctions[m_specific_options.m_hidden_activiation_function]);
132 fann_set_activation_function_output(ann, activationFunctions[m_specific_options.m_output_activiation_function]);
133 fann_set_train_error_function(ann, errorFunctions[m_specific_options.m_error_function]);
134
135
136 double nTestingAndValidationEvents = numberOfEvents * m_specific_options.m_validation_fraction;
137 unsigned int nTestingEvents = int(nTestingAndValidationEvents * 0.5); // Number of events in the test sample.
138 unsigned int nValidationEvents = int(nTestingAndValidationEvents * 0.5);
139 unsigned int nTrainingEvents = numberOfEvents - nValidationEvents - nTestingEvents;
140
141 if (nTestingAndValidationEvents < 1) {
142 B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
144 ". The total number of events is " << numberOfEvents << ". numberOfEvents * m_validation_fraction has to be larger than one");
145 throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * m_validation_fraction has to be larger than one");
146 }
147
148 if (nTrainingEvents < 1) {
149 B2ERROR("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
151 ". The total number of events is " << numberOfEvents << ". numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
152 throw std::runtime_error("m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
153 }
154
155 // training set
156 struct fann_train_data* train_data =
157 fann_create_train(nTrainingEvents, numberOfFeatures, 1);
158 for (unsigned iEvent = 0; iEvent < nTrainingEvents; ++iEvent) {
159 training_data.loadEvent(iEvent);
160 for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
161 train_data->input[iEvent][iFeature] = training_data.m_input[iFeature];
162 }
163 train_data->output[iEvent][0] = training_data.m_target;
164 }
165 // validation set
166 struct fann_train_data* valid_data =
167 fann_create_train(nValidationEvents, numberOfFeatures, 1);
168 for (unsigned iEvent = nTrainingEvents; iEvent < nTrainingEvents + nValidationEvents; ++iEvent) {
169 training_data.loadEvent(iEvent);
170 for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
171 valid_data->input[iEvent - nTrainingEvents][iFeature] = training_data.m_input[iFeature];
172 }
173 valid_data->output[iEvent - nTrainingEvents][0] = training_data.m_target;
174 }
175
176 // testing set
177 struct fann_train_data* test_data =
178 fann_create_train(nTestingEvents, numberOfFeatures, 1);
179 for (unsigned iEvent = nTrainingEvents + nValidationEvents; iEvent < numberOfEvents; ++iEvent) {
180 training_data.loadEvent(iEvent);
181 for (unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
182 test_data->input[iEvent - nTrainingEvents - nValidationEvents][iFeature] = training_data.m_input[iFeature];
183 }
184 test_data->output[iEvent - nTrainingEvents - nValidationEvents][0] = training_data.m_target;
185 }
186
187 struct fann_train_data* data = fann_create_train(numberOfEvents, numberOfFeatures, 1);
188 for (unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
189 training_data.loadEvent(iEvent);
190 for (unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
191 data->input[iEvent][iFeature] = training_data.m_input[iFeature];
192 }
193 data->output[iEvent][0] = training_data.m_target;
194 }
195
197 fann_set_input_scaling_params(ann, data, -1.0, 1.0);
198 }
199
201 fann_set_output_scaling_params(ann, data, -1.0, 1.0);
202 }
203
205 fann_scale_train(ann, data);
206 fann_scale_train(ann, train_data);
207 fann_scale_train(ann, valid_data);
208 fann_scale_train(ann, test_data);
209 }
210
211 struct fann* bestANN = nullptr;
212 double bestRMS = 999.;
213 std::vector<double> bestTrainLog = {};
214 std::vector<double> bestValidLog = {};
215
216 // repeat training several times with different random start weights
217 for (unsigned int iRun = 0; iRun < m_specific_options.m_random_seeds; ++iRun) {
218 double bestValid = 999.;
219 std::vector<double> trainLog = {};
220 std::vector<double> validLog = {};
221 trainLog.assign(m_specific_options.m_max_epochs, 0.);
222 validLog.assign(m_specific_options.m_max_epochs, 0.);
223 int breakEpoch = 0;
224 struct fann* iRunANN = nullptr;
225 fann_randomize_weights(ann, -0.1, 0.1);
226 for (unsigned int iEpoch = 1; iEpoch <= m_specific_options.m_max_epochs; ++iEpoch) {
227 double mse;
228#ifdef HAS_OPENMP
229 if (m_specific_options.m_training_method != "FANN_TRAIN_INCREMENTAL") {
230 mse = trainingMethods[m_specific_options.m_training_method](ann, train_data, m_specific_options.m_number_of_threads);
231 } else {mse = parallel_fann::train_epoch_incremental_mod(ann, train_data);}
232#else
233 fann_set_training_algorithm(ann, trainingMethods[m_specific_options.m_training_method]);
234 mse = fann_train_epoch(ann, train_data);
235#endif
236 trainLog[iEpoch - 1] = mse;
237 // evaluate validation set
238 fann_reset_MSE(ann);
239
240#ifdef HAS_OPENMP
241 double valid_mse = parallel_fann::test_data_parallel(ann, valid_data, m_specific_options.m_number_of_threads);
242#else
243 double valid_mse = fann_test_data(ann, valid_data);
244#endif
245
246 validLog[iEpoch - 1] = valid_mse;
247 // keep weights for lowest validation error
248 if (valid_mse < bestValid) {
249 bestValid = valid_mse;
250 iRunANN = fann_copy(ann);
251 }
252 // break when validation error increases
253 if (iEpoch > m_specific_options.m_test_rate && valid_mse > validLog[iEpoch - m_specific_options.m_test_rate]) {
255 B2INFO("Training stopped in iEpoch " << iEpoch);
256 B2INFO("Train error: " << mse << ", valid error: " << valid_mse <<
257 ", best valid: " << bestValid);
258 }
259 breakEpoch = iEpoch;
260 break;
261 }
262 // print current status
263 if (iEpoch == 1 || (iEpoch < 100 && iEpoch % 10 == 0) || iEpoch % 100 == 0) {
264 if (m_specific_options.m_verbose_mode) B2INFO("Epoch " << iEpoch << ": Train error = " << mse <<
265 ", valid error = " << valid_mse << ", best valid = " << bestValid);
266 }
267 }
268
269 // test trained network
270
271#ifdef HAS_OPENMP
272 double test_mse = parallel_fann::test_data_parallel(iRunANN, test_data, m_specific_options.m_number_of_threads);
273#else
274 double test_mse = fann_test_data(iRunANN, test_data);
275#endif
276
277 double RMS = sqrt(test_mse);
278
279 if (RMS < bestRMS) {
280 bestRMS = RMS;
281 bestANN = fann_copy(iRunANN);
282 fann_destroy(iRunANN);
283 bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
284 bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
285 }
286 if (m_specific_options.m_verbose_mode) B2INFO("RMS on test samples: " << RMS << " (best: " << bestRMS << ")");
287 }
288
289 fann_destroy_train(data);
290 fann_destroy_train(train_data);
291 fann_destroy_train(valid_data);
292 fann_destroy_train(test_data);
293 fann_destroy(ann);
294
295 Weightfile weightfile;
296 std::string custom_weightfile = weightfile.generateFileName();
297
298 fann_save(bestANN, custom_weightfile.c_str());
299 fann_destroy(bestANN);
300
301 weightfile.addOptions(m_general_options);
302 weightfile.addOptions(m_specific_options);
303 weightfile.addFile("FANN_Weightfile", custom_weightfile);
304 weightfile.addVector("FANN_bestTrainLog", bestTrainLog);
305 weightfile.addVector("FANN_bestValidLog", bestValidLog);
306 weightfile.addSignalFraction(training_data.getSignalFraction());
307
308 return weightfile;
309
310 }
311
313 {
314 if (m_ann) {
315 fann_destroy(m_ann);
316 }
317 }
318
319 void FANNExpert::load(Weightfile& weightfile)
320 {
321
322 std::string custom_weightfile = weightfile.generateFileName();
323 weightfile.getFile("FANN_Weightfile", custom_weightfile);
324
325 if (m_ann) {
326 fann_destroy(m_ann);
327 }
328 m_ann = fann_create_from_file(custom_weightfile.c_str());
329
330 weightfile.getOptions(m_specific_options);
331 }
332
333 std::vector<float> FANNExpert::apply(Dataset& test_data) const
334 {
335
336 std::vector<fann_type> input(test_data.getNumberOfFeatures());
337 std::vector<float> probabilities(test_data.getNumberOfEvents());
338 for (unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
339 test_data.loadEvent(iEvent);
340 for (unsigned int iFeature = 0; iFeature < test_data.getNumberOfFeatures(); ++iFeature) {
341 input[iFeature] = test_data.m_input[iFeature];
342 }
343 if (m_specific_options.m_scale_features) fann_scale_input(m_ann, input.data());
344 probabilities[iEvent] = fann_run(m_ann, input.data())[0];
345 }
346 if (m_specific_options.m_scale_target) fann_descale_output(m_ann, probabilities.data());
347 return probabilities;
348 }
349
350 }
352}
Abstract base class of all Datasets given to the MVA interface. The current event can always be accessed...
Definition: Dataset.h:33
virtual ~FANNExpert()
Destructor of FANN Expert.
Definition: FANN.cc:312
struct fann * m_ann
Pointer to FANN expert.
Definition: FANN.h:132
virtual std::vector< float > apply(Dataset &test_data) const override
Apply this expert onto a dataset.
Definition: FANN.cc:333
virtual void load(Weightfile &weightfile) override
Load the expert from a Weightfile.
Definition: FANN.cc:319
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:131
Options for the FANN MVA method.
Definition: FANN.h:29
double m_validation_fraction
Fraction of training sample used for validation in order to avoid overtraining.
Definition: FANN.h:69
bool m_scale_features
Scale features before training.
Definition: FANN.h:77
bool m_verbose_mode
Sets to report training status or not.
Definition: FANN.h:61
unsigned int m_random_seeds
Number of times the training is repeated with a new weight random seed.
Definition: FANN.h:70
std::string m_error_function
Loss function.
Definition: FANN.h:66
unsigned int m_number_of_threads
Number of threads for parallel training.
Definition: FANN.h:74
unsigned int m_test_rate
Error on validation is compared with the one before.
Definition: FANN.h:72
std::string m_hidden_activiation_function
Activation function in hidden layer.
Definition: FANN.h:64
bool m_scale_target
Scale target before training.
Definition: FANN.h:78
std::vector< unsigned int > getHiddenLayerNeurons(unsigned int nf) const
Returns the internal vector parameter with the number of hidden neurons per layer.
Definition: FANNOptions.cc:93
std::string m_training_method
Training method for back propagation.
Definition: FANN.h:67
std::string m_output_activiation_function
Activation function in output layer.
Definition: FANN.h:65
unsigned int m_max_epochs
Maximum number of epochs.
Definition: FANN.h:60
FANNTeacher(const GeneralOptions &general_options, const FANNOptions &specific_options)
Constructs a new teacher using the GeneralOptions and specific options of this training.
Definition: FANN.cc:26
FANNOptions m_specific_options
Method specific options.
Definition: FANN.h:102
virtual Weightfile train(Dataset &training_data) const override
Train a mva method using the given dataset returning a Weightfile.
Definition: FANN.cc:30
General options which are shared by all MVA trainings.
Definition: Options.h:62
Abstract base class of all Teachers. Each MVA library has its own implementation of this class,...
Definition: Teacher.h:29
GeneralOptions m_general_options
GeneralOptions containing all shared options.
Definition: Teacher.h:49
The Weightfile class serializes all information about a training into an xml tree.
Definition: Weightfile.h:38
void addFile(const std::string &identifier, const std::string &custom_weightfile)
Add a file (mostly a weightfile from a MVA library) to our Weightfile.
Definition: Weightfile.cc:115
void addOptions(const Options &options)
Add an Option object to the xml tree.
Definition: Weightfile.cc:62
void getOptions(Options &options) const
Fills an Option object from the xml tree.
Definition: Weightfile.cc:67
void addSignalFraction(float signal_fraction)
Saves the signal fraction in the xml tree.
Definition: Weightfile.cc:95
void addVector(const std::string &identifier, const std::vector< T > &vector)
Add a vector to the xml tree.
Definition: Weightfile.h:125
std::string generateFileName(const std::string &suffix="")
Returns a temporary filename with the given suffix.
Definition: Weightfile.cc:105
void getFile(const std::string &identifier, const std::string &custom_weightfile)
Creates a file from our weightfile (mostly this will be a weightfile of an MVA library)
Definition: Weightfile.cc:138
double sqrt(double a)
sqrt for double
Definition: beamHelpers.h:28
Abstract base class for different kinds of events.