Belle II Software prerelease-11-00-00a
GRLMLP.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#ifndef GRLMLP_H
10#define GRLMLP_H
11
#include <vector>

#include <TObject.h>
#include <framework/logging/Logger.h>
14
15namespace Belle2 {
21 class GRLMLP : public TObject {
22
23 // weights etc. are set only by the trainer
24 friend class GRLNeuroTrainerModule;
25
26 public:
28 GRLMLP();
29
31 GRLMLP(std::vector<unsigned short>& nodes);
32
34 ~GRLMLP() { }
35
37 bool is_trained() const { return m_trained; }
39 unsigned get_number_of_layers() const { return m_n_nodes.size(); }
41 unsigned get_number_of_nodes_layer(unsigned i_layer) const { return m_n_nodes[i_layer]; }
43 unsigned get_number_of_weights() const { return m_weights.size(); }
45 unsigned n_weights_cal() const;
47 unsigned n_bias_cal() const;
49 std::vector<float> get_weights() const { return m_weights; }
51 std::vector<float> get_bias() const { return m_bias; }
53 void set_weights(std::vector<float>& weights) { m_weights = weights; }
55 void set_bias(std::vector<float>& bias) { m_bias = bias; }
56
58 std::vector<float> get_nn_thres() const { return m_nn_thres; }
60 void set_nn_thres(const std::vector<float>& nn_thres) { m_nn_thres = nn_thres; }
61
63 std::vector<int> get_total_bit_bias() const { return m_total_bit_bias; }
64 std::vector<int> get_int_bit_bias() const { return m_int_bit_bias; }
65 std::vector<bool> get_is_signed_bias() const { return m_is_signed_bias; }
66 std::vector<int> get_rounding_bias() const { return m_rounding_bias; }
67 std::vector<int> get_saturation_bias() const { return m_saturation_bias; }
68 std::vector<int> get_total_bit_accum() const { return m_total_bit_accum; }
69 std::vector<int> get_int_bit_accum() const { return m_int_bit_accum; }
70 std::vector<bool> get_is_signed_accum() const { return m_is_signed_accum; }
71 std::vector<int> get_rounding_accum() const { return m_rounding_accum; }
72 std::vector<int> get_saturation_accum() const { return m_saturation_accum; }
73 std::vector<int> get_total_bit_weight() const { return m_total_bit_weight; }
74 std::vector<int> get_int_bit_weight() const { return m_int_bit_weight; }
75 std::vector<bool> get_is_signed_weight() const { return m_is_signed_weight; }
76 std::vector<int> get_rounding_weight() const { return m_rounding_weight; }
77 std::vector<int> get_saturation_weight() const { return m_saturation_weight; }
78 std::vector<int> get_total_bit_relu() const { return m_total_bit_relu; }
79 std::vector<int> get_int_bit_relu() const { return m_int_bit_relu; }
80 std::vector<bool> get_is_signed_relu() const { return m_is_signed_relu; }
81 std::vector<int> get_rounding_relu() const { return m_rounding_relu; }
82 std::vector<int> get_saturation_relu() const { return m_saturation_relu; }
83 std::vector<int> get_total_bit() const { return m_total_bit; }
84 std::vector<int> get_int_bit() const { return m_int_bit; }
85 std::vector<bool> get_is_signed() const { return m_is_signed; }
86 std::vector<int> get_rounding() const { return m_rounding; }
87 std::vector<int> get_saturation() const { return m_saturation; }
88 std::vector<std::vector<int>> get_W_input() const { return m_W_input; }
89 std::vector<std::vector<int>> get_I_input() const { return m_I_input; }
90
92 void set_total_bit_bias(const std::vector<int>& i) { m_total_bit_bias = i; }
93 void set_int_bit_bias(const std::vector<int>& i) { m_int_bit_bias = i; }
94 void set_is_signed_bias(const std::vector<bool>& i) { m_is_signed_bias = i; }
95 void set_rounding_bias(const std::vector<int>& i) { m_rounding_bias = i; }
96 void set_saturation_bias(const std::vector<int>& i) { m_saturation_bias = i; }
97 void set_total_bit_accum(const std::vector<int>& i) { m_total_bit_accum = i; }
98 void set_int_bit_accum(const std::vector<int>& i) { m_int_bit_accum = i; }
99 void set_is_signed_accum(const std::vector<bool>& i) { m_is_signed_accum = i; }
100 void set_rounding_accum(const std::vector<int>& i) { m_rounding_accum = i; }
101 void set_saturation_accum(const std::vector<int>& i) { m_saturation_accum = i; }
102 void set_total_bit_weight(const std::vector<int>& i) { m_total_bit_weight = i; }
103 void set_int_bit_weight(const std::vector<int>& i) { m_int_bit_weight = i; }
104 void set_is_signed_weight(const std::vector<bool>& i) { m_is_signed_weight = i; }
105 void set_rounding_weight(const std::vector<int>& i) { m_rounding_weight = i; }
106 void set_saturation_weight(const std::vector<int>& i) { m_saturation_weight = i; }
107 void set_total_bit_relu(const std::vector<int>& i) { m_total_bit_relu = i; }
108 void set_int_bit_relu(const std::vector<int>& i) { m_int_bit_relu = i; }
109 void set_is_signed_relu(const std::vector<bool>& i) { m_is_signed_relu = i; }
110 void set_rounding_relu(const std::vector<int>& i) { m_rounding_relu = i; }
111 void set_saturation_relu(const std::vector<int>& i) { m_saturation_relu = i; }
112 void set_total_bit(const std::vector<int>& i) { m_total_bit = i; }
113 void set_int_bit(const std::vector<int>& i) { m_int_bit = i; }
114 void set_is_signed(const std::vector<bool>& i) { m_is_signed = i; }
115 void set_rounding(const std::vector<int>& i) { m_rounding = i; }
116 void set_saturation(const std::vector<int>& i) { m_saturation = i; }
117 void set_W_input(const std::vector<std::vector<int>>& i) { m_W_input = i; }
118 void set_I_input(const std::vector<std::vector<int>>& i) { m_I_input = i; }
119
121 void Trained(bool trained) { m_trained = trained; }
122
123 private:
125 std::vector<unsigned short> m_n_nodes;
127 std::vector<float> m_nn_thres;
129 std::vector<float> m_weights;
131 std::vector<float> m_bias;
133 std::vector<int> m_total_bit_bias;
134 std::vector<int> m_int_bit_bias;
135 std::vector<bool> m_is_signed_bias;
136 std::vector<int> m_rounding_bias;
137 std::vector<int> m_saturation_bias;
138 std::vector<int> m_total_bit_accum;
139 std::vector<int> m_int_bit_accum;
140 std::vector<bool> m_is_signed_accum;
141 std::vector<int> m_rounding_accum;
142 std::vector<int> m_saturation_accum;
143 std::vector<int> m_total_bit_weight;
144 std::vector<int> m_int_bit_weight;
145 std::vector<bool> m_is_signed_weight;
146 std::vector<int> m_rounding_weight;
147 std::vector<int> m_saturation_weight;
148 std::vector<int> m_total_bit_relu;
149 std::vector<int> m_int_bit_relu;
150 std::vector<bool> m_is_signed_relu;
151 std::vector<int> m_rounding_relu;
152 std::vector<int> m_saturation_relu;
153 std::vector<int> m_total_bit;
154 std::vector<int> m_int_bit;
155 std::vector<bool> m_is_signed;
156 std::vector<int> m_rounding;
157 std::vector<int> m_saturation;
158 std::vector<std::vector<int>> m_W_input;
159 std::vector<std::vector<int>> m_I_input;
160
164
167 };
168
169}
170#endif
unsigned get_number_of_layers() const
get number of layers
Definition GRLMLP.h:39
std::vector< float > m_weights
Weights of the network.
Definition GRLMLP.h:129
GRLMLP()
default constructor.
Definition GRLMLP.cc:14
ClassDef(GRLMLP, 6)
Needed to make the ROOT object storable.
unsigned get_number_of_weights() const
get number of weights from length of weights vector
Definition GRLMLP.h:43
std::vector< int > m_total_bit_bias
bit width etc.
Definition GRLMLP.h:133
std::vector< unsigned short > m_n_nodes
Number of nodes in each layer, not including bias nodes.
Definition GRLMLP.h:125
void set_total_bit_bias(const std::vector< int > &i)
set bit width etc.
Definition GRLMLP.h:92
void set_weights(std::vector< float > &weights)
set weights vector
Definition GRLMLP.h:53
void Trained(bool trained)
set whether the weights are default values (false) or have been set by some trainer (true)
Definition GRLMLP.h:121
std::vector< float > get_bias() const
get bias vector
Definition GRLMLP.h:51
unsigned n_weights_cal() const
calculate number of weights from number of nodes
Definition GRLMLP.cc:30
std::vector< float > get_nn_thres() const
get output threshold vector
Definition GRLMLP.h:58
unsigned n_bias_cal() const
calculate number of bias values from number of nodes
Definition GRLMLP.cc:43
std::vector< float > get_weights() const
get weights vector
Definition GRLMLP.h:49
std::vector< float > m_bias
bias of the network.
Definition GRLMLP.h:131
bool is_trained() const
check if weights are default values or set by some trainer
Definition GRLMLP.h:37
std::vector< float > m_nn_thres
threshold of output
Definition GRLMLP.h:127
unsigned get_number_of_nodes_layer(unsigned i_layer) const
get number of nodes in a layer
Definition GRLMLP.h:41
void set_bias(std::vector< float > &bias)
set bias vector
Definition GRLMLP.h:55
std::vector< int > get_total_bit_bias() const
get bit width etc.
Definition GRLMLP.h:63
~GRLMLP()
destructor, empty because we don't allocate memory anywhere.
Definition GRLMLP.h:34
bool m_trained
Indicator whether the weights are just default values or have been set by some trainer (set to true w...
Definition GRLMLP.h:163
void set_nn_thres(const std::vector< float > &nn_thres)
set output threshold vector
Definition GRLMLP.h:60
Abstract base class for different kinds of events.