33 unsigned int numberOfFeatures = training_data.getNumberOfFeatures();
34 unsigned int numberOfEvents = training_data.getNumberOfEvents();
36 std::vector<unsigned int> hiddenLayers =
m_specific_options.getHiddenLayerNeurons(numberOfFeatures);
37 unsigned int number_of_layers = 2;
38 for (
unsigned int hiddenLayer : hiddenLayers) {
39 if (hiddenLayer > 0) {
44 auto layers = std::unique_ptr<unsigned int[]>(
new unsigned int[number_of_layers]);
45 layers[0] = numberOfFeatures;
46 for (
unsigned int i = 0; i < hiddenLayers.size(); ++i) {
47 if (hiddenLayers[i] > 0) {
48 layers[i + 1] = hiddenLayers[i];
51 layers[number_of_layers - 1] = 1;
53 struct fann* ann = fann_create_standard_array(number_of_layers, layers.get());
55 std::map<std::string, enum fann_activationfunc_enum> activationFunctions;
57 for (
auto& name : FANN_ACTIVATIONFUNC_NAMES) {
58 activationFunctions[name] = fann_activationfunc_enum(i);
63 typedef float (*FnPtr)(
struct fann * ann,
struct fann_train_data * data,
const unsigned int threadnumb);
64 std::map<std::string, FnPtr> trainingMethods;
65 trainingMethods[
"FANN_TRAIN_RPROP"] = parallel_fann::train_epoch_irpropm_parallel;
66 trainingMethods[
"FANN_TRAIN_BATCH"] = parallel_fann::train_epoch_batch_parallel;
67 trainingMethods[
"FANN_TRAIN_QUICKPROP"] = parallel_fann::train_epoch_quickprop_parallel;
68 trainingMethods[
"FANN_TRAIN_SARPROP"] = parallel_fann::train_epoch_sarprop_parallel;
69 trainingMethods[
"FANN_TRAIN_INCREMENTAL"] =
nullptr;
71 std::map<std::string, enum fann_train_enum> trainingMethods;
73 for (
auto& name : FANN_TRAIN_NAMES) {
74 trainingMethods[name] = fann_train_enum(i);
79 std::map<std::string, enum fann_errorfunc_enum> errorFunctions;
81 for (
auto& name : FANN_ERRORFUNC_NAMES) {
82 errorFunctions[name] = fann_errorfunc_enum(i);
86 if (activationFunctions.find(
m_specific_options.m_hidden_activiation_function) == activationFunctions.end()) {
87 B2ERROR(
"Coulnd't find activation function named " <<
m_specific_options.m_hidden_activiation_function);
88 throw std::runtime_error(
"Coulnd't find activation function named " +
m_specific_options.m_hidden_activiation_function);
91 if (activationFunctions.find(
m_specific_options.m_output_activiation_function) == activationFunctions.end()) {
92 B2ERROR(
"Coulnd't find activation function named " <<
m_specific_options.m_output_activiation_function);
93 throw std::runtime_error(
"Coulnd't find activation function named " +
m_specific_options.m_output_activiation_function);
96 if (errorFunctions.find(
m_specific_options.m_error_function) == errorFunctions.end()) {
97 B2ERROR(
"Coulnd't find training method function named " <<
m_specific_options.m_error_function);
98 throw std::runtime_error(
"Coulnd't find training method function named " +
m_specific_options.m_error_function);
101 if (trainingMethods.find(
m_specific_options.m_training_method) == trainingMethods.end()) {
102 B2ERROR(
"Coulnd't find training method function named " <<
m_specific_options.m_training_method);
103 throw std::runtime_error(
"Coulnd't find training method function named " +
m_specific_options.m_training_method);
108 throw std::runtime_error(
"m_max_epochs should be larger than 0. The given value is " + std::to_string(
113 B2ERROR(
"m_random_seeds should be larger than 0 " <<
m_specific_options.m_random_seeds);
114 throw std::runtime_error(
"m_random_seeds should be larger than 0. The given value is " + std::to_string(
120 throw std::runtime_error(
"m_test_rate should be larger than 0. The given value is " + std::to_string(
125 B2ERROR(
"m_number_of_threads should be larger than 0. The given value is " <<
m_specific_options.m_number_of_threads);
126 throw std::runtime_error(
"m_number_of_threads should be larger than 0. The given value is " +
131 fann_set_activation_function_hidden(ann, activationFunctions[
m_specific_options.m_hidden_activiation_function]);
132 fann_set_activation_function_output(ann, activationFunctions[
m_specific_options.m_output_activiation_function]);
133 fann_set_train_error_function(ann, errorFunctions[
m_specific_options.m_error_function]);
136 double nTestingAndValidationEvents = numberOfEvents *
m_specific_options.m_validation_fraction;
137 unsigned int nTestingEvents = int(nTestingAndValidationEvents * 0.5);
138 unsigned int nValidationEvents = int(nTestingAndValidationEvents * 0.5);
139 unsigned int nTrainingEvents = numberOfEvents - nValidationEvents - nTestingEvents;
141 if (nTestingAndValidationEvents < 1) {
142 B2ERROR(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
144 ". The total number of events is " << numberOfEvents <<
". numberOfEvents * m_validation_fraction has to be larger than one");
145 throw std::runtime_error(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * m_validation_fraction has to be larger than one");
148 if (nTrainingEvents < 1) {
149 B2ERROR(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). The given value is " <<
151 ". The total number of events is " << numberOfEvents <<
". numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
152 throw std::runtime_error(
"m_validation_fraction should be a number between 0 and 1 (0 < x < 1). numberOfEvents * (1 - m_validation_fraction) has to be larger than one");
156 struct fann_train_data* train_data =
157 fann_create_train(nTrainingEvents, numberOfFeatures, 1);
158 for (
unsigned iEvent = 0; iEvent < nTrainingEvents; ++iEvent) {
159 training_data.loadEvent(iEvent);
160 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
161 train_data->input[iEvent][iFeature] = training_data.m_input[iFeature];
163 train_data->output[iEvent][0] = training_data.m_target;
166 struct fann_train_data* valid_data =
167 fann_create_train(nValidationEvents, numberOfFeatures, 1);
168 for (
unsigned iEvent = nTrainingEvents; iEvent < nTrainingEvents + nValidationEvents; ++iEvent) {
169 training_data.loadEvent(iEvent);
170 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
171 valid_data->input[iEvent - nTrainingEvents][iFeature] = training_data.m_input[iFeature];
173 valid_data->output[iEvent - nTrainingEvents][0] = training_data.m_target;
177 struct fann_train_data* test_data =
178 fann_create_train(nTestingEvents, numberOfFeatures, 1);
179 for (
unsigned iEvent = nTrainingEvents + nValidationEvents; iEvent < numberOfEvents; ++iEvent) {
180 training_data.loadEvent(iEvent);
181 for (
unsigned iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
182 test_data->input[iEvent - nTrainingEvents - nValidationEvents][iFeature] = training_data.m_input[iFeature];
184 test_data->output[iEvent - nTrainingEvents - nValidationEvents][0] = training_data.m_target;
187 struct fann_train_data* data = fann_create_train(numberOfEvents, numberOfFeatures, 1);
188 for (
unsigned int iEvent = 0; iEvent < numberOfEvents; ++iEvent) {
189 training_data.loadEvent(iEvent);
190 for (
unsigned int iFeature = 0; iFeature < numberOfFeatures; ++iFeature) {
191 data->input[iEvent][iFeature] = training_data.m_input[iFeature];
193 data->output[iEvent][0] = training_data.m_target;
197 fann_set_input_scaling_params(ann, data, -1.0, 1.0);
201 fann_set_output_scaling_params(ann, data, -1.0, 1.0);
205 fann_scale_train(ann, data);
206 fann_scale_train(ann, train_data);
207 fann_scale_train(ann, valid_data);
208 fann_scale_train(ann, test_data);
211 struct fann* bestANN =
nullptr;
212 double bestRMS = 999.;
213 std::vector<double> bestTrainLog = {};
214 std::vector<double> bestValidLog = {};
218 double bestValid = 999.;
219 std::vector<double> trainLog = {};
220 std::vector<double> validLog = {};
224 struct fann* iRunANN =
nullptr;
225 fann_randomize_weights(ann, -0.1, 0.1);
231 }
else {mse = parallel_fann::train_epoch_incremental_mod(ann, train_data);}
233 fann_set_training_algorithm(ann, trainingMethods[
m_specific_options.m_training_method]);
234 mse = fann_train_epoch(ann, train_data);
236 trainLog[iEpoch - 1] = mse;
241 double valid_mse = parallel_fann::test_data_parallel(ann, valid_data,
m_specific_options.m_number_of_threads);
243 double valid_mse = fann_test_data(ann, valid_data);
246 validLog[iEpoch - 1] = valid_mse;
248 if (valid_mse < bestValid) {
249 bestValid = valid_mse;
250 iRunANN = fann_copy(ann);
255 B2INFO(
"Training stopped in iEpoch " << iEpoch);
256 B2INFO(
"Train error: " << mse <<
", valid error: " << valid_mse <<
257 ", best valid: " << bestValid);
263 if (iEpoch == 1 || (iEpoch < 100 && iEpoch % 10 == 0) || iEpoch % 100 == 0) {
264 if (
m_specific_options.m_verbose_mode) B2INFO(
"Epoch " << iEpoch <<
": Train error = " << mse <<
265 ", valid error = " << valid_mse <<
", best valid = " << bestValid);
272 double test_mse = parallel_fann::test_data_parallel(iRunANN, test_data,
m_specific_options.m_number_of_threads);
274 double test_mse = fann_test_data(iRunANN, test_data);
277 double RMS =
sqrt(test_mse);
281 bestANN = fann_copy(iRunANN);
282 fann_destroy(iRunANN);
283 bestTrainLog.assign(trainLog.begin(), trainLog.begin() + breakEpoch);
284 bestValidLog.assign(validLog.begin(), validLog.begin() + breakEpoch);
286 if (
m_specific_options.m_verbose_mode) B2INFO(
"RMS on test samples: " << RMS <<
" (best: " << bestRMS <<
")");
289 fann_destroy_train(data);
290 fann_destroy_train(train_data);
291 fann_destroy_train(valid_data);
292 fann_destroy_train(test_data);
298 fann_save(bestANN, custom_weightfile.c_str());
299 fann_destroy(bestANN);
303 weightfile.
addFile(
"FANN_Weightfile", custom_weightfile);
304 weightfile.
addVector(
"FANN_bestTrainLog", bestTrainLog);
305 weightfile.
addVector(
"FANN_bestValidLog", bestValidLog);