Belle II Software development
CDCTriggerHoughMLP.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include "trg/cdc/dataobjects/CDCTriggerHoughMLP.h"
10
11#include <fstream>
12#include <string>
13#include <vector>
14#include <algorithm>
15
16#include "framework/logging/Logger.h"
17
18using namespace Belle2;
19
20CDCTriggerHoughMLP::CDCTriggerHoughMLP(const NeuroParametersHough& neuroParametersHough):
21 m_neuroParametersHough(neuroParametersHough)
22{
23 std::vector<size_t> nHidden = m_neuroParametersHough.nHidden;
24 m_nodes = {static_cast<size_t>(m_neuroParametersHough.nInput)};
25 for (size_t hiddenLayerIdx = 0; hiddenLayerIdx < nHidden.size(); ++hiddenLayerIdx) {
26 m_nodes.push_back(nHidden[hiddenLayerIdx]);
27 }
28 m_nodes.push_back(m_neuroParametersHough.nOutput);
29 m_floatWeights.assign(getNumberOfWeights(), 0.0f);
30}
31
32// Get the number of weights
33size_t CDCTriggerHoughMLP::getNumberOfWeights() const
34{
35 size_t nWeights = 0;
36 size_t nLayers = getNumberOfLayers();
37 for (size_t i = 1; i < nLayers; ++i) {
38 // +1 for bias node
39 nWeights += (m_nodes[i - 1] + 1) * m_nodes[i];
40 }
41 return nWeights;
42}
43
44// Save the current class instance to a root file
45void CDCTriggerHoughMLP::saveMLPToFile(const std::string& fileName, const std::string& objName) const
46{
47 B2INFO(std::string("Saving network to file ") + fileName + ", object " + objName);
48 TFile datafile(fileName.c_str(), "UPDATE");
49 this->Write(objName.c_str(), TObject::kSingleKey | TObject::kOverwrite);
50 datafile.Close();
51}
52
53// STATIC: Load config parameters from a plain text config file
54NeuroParametersHough CDCTriggerHoughMLP::loadConfigFromFile(const std::string& fileName)
55{
56 NeuroParametersHough neuroParametersHough;
57 std::ifstream configFile(fileName);
58 if (!configFile.is_open()) {
59 B2ERROR("Could not open configuration file: " + fileName);
60 exit(EXIT_FAILURE);
61 }
62 std::string completeLine;
63 while (std::getline(configFile, completeLine)) {
64 std::size_t hashtag = completeLine.find('#'); // Remove comments
65 std::string line = completeLine.substr(0, hashtag);
66 std::string configParameter;
67 std::string parameterValue;
68 if (line.length() < 3 || line.find('=') == std::string::npos) {
69 continue;
70 }
71 line.erase(std::remove(line.begin(), line.end(), ' '), line.end()); // Remove whitspaces
72 size_t equalPosition = line.find('=');
73 configParameter = line.substr(0, equalPosition);
74 parameterValue = line.substr((equalPosition + 1), line.length() - equalPosition - 1);
75 if (configParameter == "nInput") {
76 neuroParametersHough.nInput = std::stoull(parameterValue);
77 } else if (configParameter == "nOutput") {
78 neuroParametersHough.nOutput = std::stoull(parameterValue);
79 } else if (configParameter == "nHidden") {
80 neuroParametersHough.nHidden = readArray<size_t>(parameterValue);
81 } else if (configParameter == "outputScale") {
82 neuroParametersHough.outputScale = readArray<float>(parameterValue);
83 } else {
84 B2WARNING("Unknown config parameter: " + configParameter);
85 }
86 }
87 return neuroParametersHough;
88}
89
90// PRIVATE STATIC: Read a one dimensional array from a plain text config file
91template <typename T>
92std::vector<T> CDCTriggerHoughMLP::readArray(const std::string& rawString)
93{
94 std::vector<T> configVector;
95 std::string strippedString = rawString.substr(1, rawString.size() - 2); // Strip brackets
96 std::stringstream strippedStream(strippedString);
97 std::string entry;
98 while (std::getline(strippedStream, entry, ',')) {
99 configVector.push_back(static_cast<T>(std::stod(entry)));
100 }
101 return configVector;
102}
Abstract base class for different kinds of events.