Belle II Software development
NNWaveFitter.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#pragma once
10#ifndef _SVD_RECONSTRUCTION_NNwAVEFITTER_H
11#define _SVD_RECONSTRUCTION_NNWAVEFITTER_H
12
13#include <cmath>
14#include <string>
15#include <vector>
16#include <map>
17#include <memory>
18#include <functional>
19#include <svd/reconstruction/NNWaveFitTool.h>
20#include <svd/simulation/SVDSimulationTools.h>
21#include <Eigen/Dense>
22
23namespace Belle2 {
28 namespace SVD {
29
30// ==============================================================================
31// The NNWaveFitter class - neural network fitter of APV25 waveforms
32// ------------------------------------------------------------------------------
62
63 public:
64
65
/// Bounds type holding the (minimum, maximum) of a network parameter range used in training.
using nnBoundsType = std::pair<double, double>;
71
76 NNWaveFitter(std::string xmlData = "");
77
81 void setNetwrok(const std::string& xmlData);
82
92 std::shared_ptr<nnFitterBinData> getFit(const apvSamples& samples, double tau);
93
98 const NNWaveFitTool& getFitTool() const { return *m_fitTool; }
99
104 const nnFitterBinData& getBinCenters() const { return m_binCenters; }
105
109 const nnFitterBins& getBins() const { return m_bins; }
110
117
122
127
131 bool isValid() const { return m_isValid; }
132
/**
 * Check NN data against a dump from Python.
 * @param dumpname name of the dump to compare against (presumably a file path — confirm)
 * @param tol comparison tolerance, 1.0e-10 by default
 * @return result of the check (presumably true when the coefficients agree — confirm)
 */
auto checkCoefficients(const std::string& dumpname, double tol = 1.0e-10) -> bool;
139
140 private:
141
/// Type of the scalar activation functions applied element-wise by the network.
using activationType = std::function<double(double)>;
146 [](double x) -> double { return std::max(double(0.0), x); };
149 [](double x) -> double { double e = std::exp(x); return e / (1.0 + e); };
150
155 Eigen::VectorXd softmax(const Eigen::VectorXd& input)
156 {
157 auto output = input.array().unaryExpr(
158 [](double x)->double { return ((x > 0.0) ? exp(x) : 0.0); }
159 );
160 double norm = output.sum();
161 return output / norm;
162 }
163
/**
 * The method that actually reads the XML network file.
 * @param xmlFileName name of the XML file holding the network definition
 * @return integer status code (meaning defined in the implementation — TODO confirm)
 */
auto readNetworkData(const std::string& xmlFileName) -> int;
169
173 typedef std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > >
175
177 typedef std::vector<Eigen::VectorXd> layerStatesType;
178
179 // Data members
/// Number of NN layers, read from the XML network file. Defaults to 0 so the
/// value is well-defined even before (or without) a successful read.
std::size_t m_nLayers{0};
/// NN layer sizes; starts out empty.
std::vector<std::size_t> m_layerSizes{};
193 std::shared_ptr<NNWaveFitTool> m_fitTool;
195 }; // class NNWaveFitter
196
197 } // namespace SVD
199} // namespace Belle2
200
201#endif
The class holds arrays of bins and bin centers, and a wave generator object containing information on...
Definition: NNWaveFitTool.h:91
The class uses a neural network to find a probability distribution of arrival times for a sextet of A...
Definition: NNWaveFitter.h:61
nnBoundsType getTimeShiftBounds() const
Get time shift bounds.
Definition: NNWaveFitter.h:126
std::size_t m_nLayers
number of NN layers, read from xml
Definition: NNWaveFitter.h:181
std::vector< Eigen::VectorXd > layerStatesType
Storage for internal states.
Definition: NNWaveFitter.h:177
void setNetwrok(const std::string &xmlData)
Set proper network definition file.
TauEncoder m_tauCoder
Tau encoder class instance to scale tau values.
Definition: NNWaveFitter.h:191
layerStatesType m_layerStates
vectors of layer states
Definition: NNWaveFitter.h:184
const NNWaveFitTool & getFitTool() const
Get a handle to the NNWaveFitTool object.
Definition: NNWaveFitter.h:98
int readNetworkData(const std::string &xmlFileName)
The method that actually reads the xml file.
Definition: NNWaveFitter.cc:36
nnBoundsType m_waveWidthBounds
Waveform width range of the network.
Definition: NNWaveFitter.h:189
activationType relu
Rectifier activation.
Definition: NNWaveFitter.h:145
nnFitterBins m_bins
NN time bin boundaries.
Definition: NNWaveFitter.h:183
std::pair< double, double > nnBoundsType
Bounds type used to hold network parameter bounds used in training the network.
Definition: NNWaveFitter.h:70
bool isValid() const
Is this fitter working? Return false if the fitter was not initialized properly.
Definition: NNWaveFitter.h:131
activationType m_activation
Network activation function.
Definition: NNWaveFitter.h:187
std::function< double(double)> activationType
Activation functions type.
Definition: NNWaveFitter.h:143
nnBoundsType getWaveWidthBounds() const
Get width bounds. Width bounds are the minimum and maximum widths used in training the network.
Definition: NNWaveFitter.h:116
std::vector< std::size_t > m_layerSizes
NN layer sizes.
Definition: NNWaveFitter.h:185
activationType sigmoid
Sigmoid activation.
Definition: NNWaveFitter.h:148
std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > > networkWeightsType
We use a map to store network layers since we don't know whether we'll be reading them in the correct order.
Definition: NNWaveFitter.h:174
networkWeightsType m_networkCoefs
NN weights and intercepts.
Definition: NNWaveFitter.h:186
bool checkCoefficients(const std::string &dumpname, double tol=1.0e-10)
Check NN data against a dump from Python.
const nnFitterBins & getBins() const
Get bins of network output.
Definition: NNWaveFitter.h:109
std::shared_ptr< NNWaveFitTool > m_fitTool
FitterTool object allowing calculations on network fits.
Definition: NNWaveFitter.h:193
nnBoundsType m_amplitudeBounds
Amplitude range of the network.
Definition: NNWaveFitter.h:188
const nnFitterBinData & getBinCenters() const
Get bin times of the network output.
Definition: NNWaveFitter.h:104
nnFitterBinData m_binCenters
centers of NN time bins
Definition: NNWaveFitter.h:182
bool m_isValid
true if fitter was properly initialized
Definition: NNWaveFitter.h:180
Eigen::VectorXd softmax(const Eigen::VectorXd &input)
Softmax function, normalization for the network's output layer.
Definition: NNWaveFitter.h:155
nnBoundsType getAmplitudeBounds() const
Get amplitude bounds.
Definition: NNWaveFitter.h:121
WaveformShape m_wave
Wave function used in training the network.
Definition: NNWaveFitter.h:192
nnBoundsType m_timeShiftBounds
Time shift range of the network.
Definition: NNWaveFitter.h:190
std::shared_ptr< nnFitterBinData > getFit(const apvSamples &samples, double tau)
Fitting method: send data and get the result structure.
Encoder/decoder for neural network tau values.
std::array< apvSampleBaseType, nAPVSamples > apvSamples
Array of apvSampleBaseType objects.
std::vector< double > nnFitterBins
Vector of bin edges; its size is nnFitterBinData.size() + 1.
Definition: NNWaveFitTool.h:33
std::function< double(double)> WaveformShape
WaveformShape type.
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
Definition: NNWaveFitTool.h:30
Abstract base class for different kinds of events.