Belle II Software development
CDCTriggerHoughMLP.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10
#include <array>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <TObject.h>
#include <TFile.h>
#include "framework/logging/Logger.h"
19
20namespace Belle2 {
25
26 // Full set of parameters to describe a network
27 struct NeuroParametersHough : public TObject {
28 // Number of input nodes
29 size_t nInput;
30 // Number of output nodes
31 size_t nOutput;
32 // Number of nodes for each hidden layer
33 std::vector<size_t> nHidden;
34 // Output scale/range of each output node
35 std::vector<float> outputScale;
36
37 ClassDef(NeuroParametersHough, 1)
38 };
39
40 // Class to represent a network for the 3DNeuro Trigger (3DFinder input)
41 class CDCTriggerHoughMLP : public TObject {
42
43 public:
44 // Default constructor
45 CDCTriggerHoughMLP() = default;
46 // Constructor to set the network architecture
47 CDCTriggerHoughMLP(const NeuroParametersHough& neuroParametersHough);
48 // Default destructor
49 ~CDCTriggerHoughMLP() = default;
50
51 // Setter methods for the network
52 void setFloatWeights(const std::vector<float>& weights) { m_floatWeights = weights; }
53
54 // Getter methods for the network
55 size_t getNumberOfLayers() const { return m_nodes.size(); }
56 size_t getNumberOfNodes(const size_t layerIdx) const { return m_nodes[layerIdx]; }
57 size_t getNumberOfWeights() const;
58 const std::vector<float>& getFloatWeights() const { return m_floatWeights; }
59 const NeuroParametersHough& getNeuroParameters() const { return m_neuroParametersHough; }
60
61 // Save the current class instance to a root file
62 void saveMLPToFile(const std::string& fileName, const std::string& objName) const;
63 // Load MLP from file
64 template <typename T>
65 static T loadMLPFromFile(const std::string& fileName, const std::string& key)
66 {
67 TFile datafile(fileName.c_str(), "READ");
68 if (!datafile.IsOpen()) {
69 B2ERROR("Could not open file " << fileName);
70 throw std::runtime_error("Could not open file " + fileName);
71 }
72 T* network = dynamic_cast<T*>(datafile.Get(key.c_str()));
73 if (!network) {
74 throw std::runtime_error("File " + fileName + " does not contain key " + key + " of requested type");
75 }
76 T result = *network;
77 datafile.Close();
78 return result;
79 }
80
81 // Load config parameters from a plain text config file
82 static NeuroParametersHough loadConfigFromFile(const std::string& fileName);
83
84 private:
85 // Read a one dimensional array from a plain text config file
86 template <typename T>
87 static std::vector<T> readArray(const std::string& rawString);
88
89 // Network configuration
90 NeuroParametersHough m_neuroParametersHough;
91 // Number of nodes in each layer, not including bias nodes.
92 std::vector<size_t> m_nodes;
93 // Weights of the network.
94 std::vector<float> m_floatWeights;
95
96 ClassDef(CDCTriggerHoughMLP, 1);
97 };
98
99 class CDCTrigger3DHMLP : public CDCTriggerHoughMLP {
100 public:
101 using CDCTriggerHoughMLP::CDCTriggerHoughMLP;
102 static CDCTrigger3DHMLP loadFromFile(const std::string& fileName, const std::string& key)
103 {
104 return loadMLPFromFile<CDCTrigger3DHMLP>(fileName, key);
105 }
106 ClassDefOverride(CDCTrigger3DHMLP, 1);
107 };
108
109 class CDCTriggerDVTMLP : public CDCTriggerHoughMLP {
110 public:
111 using CDCTriggerHoughMLP::CDCTriggerHoughMLP;
112 static CDCTriggerDVTMLP loadFromFile(const std::string& fileName, const std::string& key)
113 {
114 return loadMLPFromFile<CDCTriggerDVTMLP>(fileName, key);
115 }
116 ClassDefOverride(CDCTriggerDVTMLP, 1);
117 };
118
119}
Abstract base class for different kinds of events.