11 #include <mva/methods/FANN.h>
13 #include <framework/logging/Logger.h>
16 #include <parallel_fann.hpp>
28 FANNTeacher::FANNTeacher(
const GeneralOptions& general_options,
const FANNOptions& specific_options) : Teacher(general_options),
29 m_specific_options(specific_options) { }
// NOTE(review): this fragment (presumably the body of FANNTeacher::train())
// has lines elided by the extraction — the embedded original line numbers
// jump (e.g. 36 -> 39, 41 -> 46) — so several loop/if bodies and closing
// braces are not visible here. Comments below describe only what the
// visible code shows; elided behavior is flagged as an assumption.

// Number of input variables and events provided by the training dataset.
35 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
36 unsigned int numberOfEvents = training_data.getNumberOfEvents();

// Count network layers: input + output (= 2) plus one per requested hidden
// layer; hidden layers with 0 neurons are skipped.
39 unsigned int number_of_layers = 2;
40 for (
unsigned int hiddenLayer : hiddenLayers) {
41 if (hiddenLayer > 0) {
// (the increment of number_of_layers is presumably in elided lines — verify)

// Layer sizes handed to FANN: [features, hidden layers..., 1 output].
46 auto layers = std::unique_ptr<unsigned int[]>(
new unsigned int[number_of_layers]);
47 layers[0] = numberOfFeatures;
48 for (
unsigned int i = 0; i < hiddenLayers.size(); ++i) {
49 if (hiddenLayers[i] > 0) {
50 layers[i + 1] = hiddenLayers[i];
// Single output neuron: one MVA response value per event.
53 layers[number_of_layers - 1] = 1;

// Create a fully connected standard feed-forward network with these sizes.
55 struct fann* ann = fann_create_standard_array(number_of_layers, layers.get());

// Translate FANN activation-function name strings into their enum values.
57 std::map<std::string, enum fann_activationfunc_enum> activationFunctions;
59 for (
auto& name : FANN_ACTIVATIONFUNC_NAMES) {
// NOTE(review): `i` is not declared in the visible lines — presumably a
// counter declared/incremented alongside `name` in elided lines; verify.
60 activationFunctions[name] = fann_activationfunc_enum(i);

// Signature of the multi-threaded training functions from parallel_fann.hpp.
65 typedef float (*FnPtr)(
struct fann * ann,
struct fann_train_data * data,
const unsigned int threadnumb);
66 std::map<std::string, FnPtr> trainingMethods;
67 trainingMethods[
"FANN_TRAIN_RPROP"] = parallel_fann::train_epoch_irpropm_parallel;
68 trainingMethods[
"FANN_TRAIN_BATCH"] = parallel_fann::train_epoch_batch_parallel;
69 trainingMethods[
"FANN_TRAIN_QUICKPROP"] = parallel_fann::train_epoch_quickprop_parallel;
70 trainingMethods[
"FANN_TRAIN_SARPROP"] = parallel_fann::train_epoch_sarprop_parallel;
// Incremental training has no parallel implementation here; it is handled
// by the serial train_epoch_incremental_mod branch further below.
71 trainingMethods[
"FANN_TRAIN_INCREMENTAL"] =
nullptr;

// NOTE(review): a second map with the same name `trainingMethods` but a
// different value type would not compile in one scope — in the full source
// this is presumably a differently named map (name -> fann_train_enum);
// verify against the original file.
73 std::map<std::string, enum fann_train_enum> trainingMethods;
75 for (
auto& name : FANN_TRAIN_NAMES) {
// NOTE(review): `i` again not declared in the visible lines — see above.
76 trainingMethods[name] = fann_train_enum(i);

// Same string -> enum translation for the FANN error functions.
81 std::map<std::string, enum fann_errorfunc_enum> errorFunctions;
83 for (
auto& name : FANN_ERRORFUNC_NAMES) {
84 errorFunctions[name] = fann_errorfunc_enum(i);

// Option validation: each of these settings must be strictly positive.
// (The enclosing if-conditions are in elided lines.)
110 throw std::runtime_error(
"m_max_epochs should be larger than 0. The given value is " + std::to_string(
116 throw std::runtime_error(
"m_random_seeds should be larger than 0. The given value is " + std::to_string(
122 throw std::runtime_error(
"m_test_rate should be larger than 0. The given value is " + std::to_string(
128 throw std::runtime_error(
"m_number_of_threads should be larger than 0. The given value is " +

// Split the sample: the held-out events are divided 50/50 into testing and
// validation; everything else is used for training.
139 unsigned int nTestingEvents = int(nTestingAndValidationEvents * 0.5);
140 unsigned int nValidationEvents = int(nTestingAndValidationEvents * 0.5);
141 unsigned int nTrainingEvents = numberOfEvents - nValidationEvents - nTestingEvents;

// Guard: the validation fraction must yield at least one held-out event.
143 if (nTestingAndValidationEvents < 1) {
144 B2ERROR(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
146 ". The total number of events is " << numberOfEvents <<
". numberOfEvents * m_validation_fraction has to be larger than one");
147 throw std::runtime_error(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * m_validation_fraction has to be larger than one");
// Guard: at least one event must remain for training.
150 if (nTrainingEvents < 1) {
151 B2ERROR(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
153 ". The total number of events is " << numberOfEvents <<
". numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
154 throw std::runtime_error(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * (1 - m_validation_fraction) has to be larger than one");

// Copy the first nTrainingEvents events into the FANN training set.
158 struct fann_train_data* train_data =
159 fann_create_train(nTrainingEvents, numberOfFeatures, 1);
160 for (
unsigned iEvent = 0; iEvent < nTrainingEvents; ++iEvent) {
161 training_data.loadEvent(iEvent);
162 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
163 train_data->input[iEvent][iFeature] = training_data.m_input[iFeature];
165 train_data->output[iEvent][0] = training_data.m_target;

// Next nValidationEvents events form the validation set; indices are
// shifted so the FANN arrays start at 0.
168 struct fann_train_data* valid_data =
169 fann_create_train(nValidationEvents, numberOfFeatures, 1);
170 for (
unsigned iEvent = nTrainingEvents; iEvent < nTrainingEvents + nValidationEvents; ++iEvent) {
171 training_data.loadEvent(iEvent);
172 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
173 valid_data->input[iEvent - nTrainingEvents][iFeature] = training_data.m_input[iFeature];
175 valid_data->output[iEvent - nTrainingEvents][0] = training_data.m_target;

// The remaining events form the independent test set.
179 struct fann_train_data* test_data =
180 fann_create_train(nTestingEvents, numberOfFeatures, 1);
181 for (
unsigned iEvent = nTrainingEvents + nValidationEvents; iEvent < numberOfEvents; ++iEvent) {
182 training_data.loadEvent(iEvent);
183 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
184 test_data->input[iEvent - nTrainingEvents - nValidationEvents][iFeature] = training_data.m_input[iFeature];
186 test_data->output[iEvent - nTrainingEvents - nValidationEvents][0] = training_data.m_target;

// A copy of the FULL sample, used below to derive the input/output scaling.
189 struct fann_train_data* data = fann_create_train(numberOfEvents, numberOfFeatures, 1);
190 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
191 training_data.loadEvent(iEvent);
192 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
193 data->input[iEvent][iFeature] = training_data.m_input[iFeature];
195 data->output[iEvent][0] = training_data.m_target;

// Derive linear scaling to [-1, 1] from the full sample and apply the SAME
// transformation to all four data sets.
199 fann_set_input_scaling_params(ann, data, -1.0, 1.0);
203 fann_set_output_scaling_params(ann, data, -1.0, 1.0);
207 fann_scale_train(ann, data);
208 fann_scale_train(ann, train_data);
209 fann_scale_train(ann, valid_data);
210 fann_scale_train(ann, test_data);

// Overall best network across runs, selected by test-set RMS, plus its
// epoch logs. 999. acts as an "infinite" starting value.
213 struct fann* bestANN =
nullptr;
214 double bestRMS = 999.;
215 std::vector<double> bestTrainLog = {};
216 std::vector<double> bestValidLog = {};
// Per-run bookkeeping: best validation error and epoch-by-epoch logs.
// (The enclosing per-run loop header is in elided lines.)
220 double bestValid = 999.;
221 std::vector<double> trainLog = {};
222 std::vector<double> validLog = {};
// Best network within the current run; weights re-randomized each run.
226 struct fann* iRunANN =
nullptr;
227 fann_randomize_weights(ann, -0.1, 0.1);
// NOTE(review): the if/for headers matching these braces are elided;
// incremental training falls back to the serial
// train_epoch_incremental_mod, otherwise fann_train_epoch is used.
233 }
else {mse = parallel_fann::train_epoch_incremental_mod(ann, train_data);}
236 mse = fann_train_epoch(ann, train_data);
// Record per-epoch training error (iEpoch is presumably 1-based — verify).
238 trainLog[iEpoch - 1] = mse;
// Evaluate on the validation set; keep a copy of the network whenever the
// validation error improves.
245 double valid_mse = fann_test_data(ann, valid_data);
248 validLog[iEpoch - 1] = valid_mse;
250 if (valid_mse < bestValid) {
251 bestValid = valid_mse;
252 iRunANN = fann_copy(ann);
// Early-stopping report for this run.
257 B2INFO(
"Training stopped in iEpoch " << iEpoch);
258 B2INFO(
"Train error: " << mse <<
", valid error: " << valid_mse <<
259 ", best valid: " << bestValid);
// Progress logging: epoch 1, every 10th epoch below 100, then every 100th.
265 if (iEpoch == 1 || (iEpoch < 100 && iEpoch % 10 == 0) || iEpoch % 100 == 0) {
267 ", valid error = " << valid_mse <<
", best valid = " << bestValid);
// Score this run's best network on the independent test set; keep it as the
// overall best when its RMS improves on previous runs, and trim the logs to
// the epoch where training stopped.
276 double test_mse = fann_test_data(iRunANN, test_data);
279 double RMS = sqrt(test_mse);
283 bestANN = fann_copy(iRunANN);
284 fann_destroy(iRunANN);
285 bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
286 bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
// Release all FANN data sets.
291 fann_destroy_train(data);
292 fann_destroy_train(train_data);
293 fann_destroy_train(valid_data);
294 fann_destroy_train(test_data);

// Persist the best network plus its training/validation history into a
// Weightfile, going through a temporary file on disk for fann_save().
297 Weightfile weightfile;
298 std::string custom_weightfile = weightfile.generateFileName();
300 fann_save(bestANN, custom_weightfile.c_str());
301 fann_destroy(bestANN);
305 weightfile.addFile(
"FANN_Weightfile", custom_weightfile);
306 weightfile.addVector(
"FANN_bestTrainLog", bestTrainLog);
307 weightfile.addVector(
"FANN_bestValidLog", bestValidLog);
308 weightfile.addSignalFraction(training_data.getSignalFraction());
// NOTE(review): fragment of an Expert load() method — the signature and
// surrounding lines are elided by the extraction.
// Extract the saved FANN network from the Weightfile into a temporary file,
// then reconstruct the network (m_ann) from it.
324 std::string custom_weightfile = weightfile.generateFileName();
325 weightfile.getFile(
"FANN_Weightfile", custom_weightfile);
330 m_ann = fann_create_from_file(custom_weightfile.c_str());
// NOTE(review): fragment of an Expert apply() method — the signature and
// some closing braces are elided by the extraction.
// Run the loaded network over every event in the dataset and return one
// response value per event.
338 std::vector<fann_type> input(test_data.getNumberOfFeatures());
339 std::vector<float> probabilities(test_data.getNumberOfEvents());
340 for (
unsigned int iEvent = 0; iEvent < test_data.getNumberOfEvents(); ++iEvent) {
341 test_data.loadEvent(iEvent);
// Copy this event's features into the reusable FANN input buffer.
342 for (
unsigned int iFeature = 0; iFeature < test_data.getNumberOfFeatures(); ++iFeature) {
343 input[iFeature] = test_data.m_input[iFeature];
// fann_run returns a pointer to the output layer; the network has a single
// output neuron, so element [0] is the response.
346 probabilities[iEvent] = fann_run(
m_ann, input.data())[0];
349 return probabilities;