Belle II Software  release-08-01-10
NNWaveFitter.h
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #pragma once
10 #ifndef _SVD_RECONSTRUCTION_NNwAVEFITTER_H
11 #define _SVD_RECONSTRUCTION_NNWAVEFITTER_H
12 
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <svd/reconstruction/NNWaveFitTool.h>
#include <svd/simulation/SVDSimulationTools.h>
#include <Eigen/Dense>
22 
23 namespace Belle2 {
28  namespace SVD {
29 
30 // ==============================================================================
31 // The NNWaveFitter class - neural network fitter of APV25 waveforms
32 // ------------------------------------------------------------------------------
61  class NNWaveFitter {
62 
63  public:
64 
65 
70  typedef std::pair<double, double> nnBoundsType;
71 
76  NNWaveFitter(std::string xmlData = "");
77 
81  void setNetwrok(const std::string& xmlData);
82 
92  std::shared_ptr<nnFitterBinData> getFit(const apvSamples& samples, double tau);
93 
98  const NNWaveFitTool& getFitTool() const { return *m_fitTool; }
99 
104  const nnFitterBinData& getBinCenters() const { return m_binCenters; }
105 
109  const nnFitterBins& getBins() const { return m_bins; }
110 
117 
122 
127 
131  bool isValid() const { return m_isValid; }
132 
138  bool checkCoefficients(const std::string& dumpname, double tol = 1.0e-10);
139 
140  private:
141 
/** Activation functions type. */
typedef std::function<double(double)> activationType;

/** Rectifier activation: max(0, x). */
activationType relu =
  [](double x) -> double { return std::max(double(0.0), x); };

/**
 * Sigmoid activation 1 / (1 + exp(-x)).
 * This form stays finite for any x: for large positive x, exp(-x) -> 0 and
 * the result is 1.0 (the form exp(x) / (1 + exp(x)) would overflow to
 * inf / inf = NaN); for large negative x, exp(-x) -> inf and the result is 0.0.
 */
activationType sigmoid =
  [](double x) -> double { return 1.0 / (1.0 + std::exp(-x)); };
150 
155  Eigen::VectorXd softmax(const Eigen::VectorXd& input)
156  {
157  auto output = input.array().unaryExpr(
158  [](double x)->double { return ((x > 0.0) ? exp(x) : 0.0); }
159  );
160  double norm = output.sum();
161  return output / norm;
162  }
163 
168  int readNetworkData(const std::string& xmlFileName);
169 
173  typedef std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > >
175 
177  typedef std::vector<Eigen::VectorXd> layerStatesType;
178 
179  // Data members
180  bool m_isValid;
181  std::size_t m_nLayers;
185  std::vector<std::size_t> m_layerSizes;
193  std::shared_ptr<NNWaveFitTool> m_fitTool;
195  }; // class NNWaveFitter
196 
197  } // namespace SVD
199 } // namespace Belle2
200 
201 #endif
The class holds arrays of bins and bin centers, and a wave generator object containing information on...
Definition: NNWaveFitTool.h:91
The class uses a neural network to find a probability distribution of arrival times for a sextet of A...
Definition: NNWaveFitter.h:61
nnBoundsType getTimeShiftBounds() const
Get time shift bounds.
Definition: NNWaveFitter.h:126
NNWaveFitter(std::string xmlData="")
Constructor: builds the wave fitter from the network data in an XML file.
std::size_t m_nLayers
number of NN layers, read from xml
Definition: NNWaveFitter.h:181
std::vector< Eigen::VectorXd > layerStatesType
Storage for internal states.
Definition: NNWaveFitter.h:177
void setNetwrok(const std::string &xmlData)
Set proper network definition file.
TauEncoder m_tauCoder
Tau encoder class instance to scale tau values.
Definition: NNWaveFitter.h:191
layerStatesType m_layerStates
vectors of layer states
Definition: NNWaveFitter.h:184
const nnFitterBinData & getBinCenters() const
Get bin times of the network output.
Definition: NNWaveFitter.h:104
int readNetworkData(const std::string &xmlFileName)
The method that actually reads the xml file.
Definition: NNWaveFitter.cc:36
nnBoundsType m_waveWidthBounds
Waveform width range of the network.
Definition: NNWaveFitter.h:189
activationType relu
Rectifier activation.
Definition: NNWaveFitter.h:145
const NNWaveFitTool & getFitTool() const
Get a handle to the NNWaveFitTool object.
Definition: NNWaveFitter.h:98
nnFitterBins m_bins
NN time bin boundaries.
Definition: NNWaveFitter.h:183
std::pair< double, double > nnBoundsType
Bounds type used to hold network parameter bounds used in training the network.
Definition: NNWaveFitter.h:70
bool isValid() const
Is this fitter working? Return false if the fitter was not initialized properly.
Definition: NNWaveFitter.h:131
activationType m_activation
Network activation function.
Definition: NNWaveFitter.h:187
std::function< double(double)> activationType
Activation functions type.
Definition: NNWaveFitter.h:143
nnBoundsType getWaveWidthBounds() const
Get width bounds. Width bounds are the minimum and maximum widths used in training the network.
Definition: NNWaveFitter.h:116
std::vector< std::size_t > m_layerSizes
NN layer sizes.
Definition: NNWaveFitter.h:185
activationType sigmoid
Sigmoid activation.
Definition: NNWaveFitter.h:148
std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > > networkWeightsType
We use a map to store network layers since we don't know whether we will be reading them in the correct order.
Definition: NNWaveFitter.h:174
networkWeightsType m_networkCoefs
NN weights and intercepts.
Definition: NNWaveFitter.h:186
bool checkCoefficients(const std::string &dumpname, double tol=1.0e-10)
Check NN data against a dump from Python.
const nnFitterBins & getBins() const
Get bins of network output.
Definition: NNWaveFitter.h:109
std::shared_ptr< NNWaveFitTool > m_fitTool
FitterTool object allowing calculations on network fits.
Definition: NNWaveFitter.h:193
nnBoundsType m_amplitudeBounds
Amplitude range of the network.
Definition: NNWaveFitter.h:188
nnFitterBinData m_binCenters
centers of NN time bins
Definition: NNWaveFitter.h:182
bool m_isValid
true if fitter was properly initialized
Definition: NNWaveFitter.h:180
Eigen::VectorXd softmax(const Eigen::VectorXd &input)
Softmax function, normalization for the network's output layer.
Definition: NNWaveFitter.h:155
nnBoundsType getAmplitudeBounds() const
Get amplitude bounds.
Definition: NNWaveFitter.h:121
WaveformShape m_wave
Wave function used in training the network.
Definition: NNWaveFitter.h:192
nnBoundsType m_timeShiftBounds
Time shift range of the network.
Definition: NNWaveFitter.h:190
std::shared_ptr< nnFitterBinData > getFit(const apvSamples &samples, double tau)
Fitting method. Send data and get the result structure.
Encoder/decoder for neural network tau values.
std::array< apvSampleBaseType, nAPVSamples > apvSamples
Array of apvSampleBaseType objects.
std::vector< double > nnFitterBins
Vector of bin edges; its size is nnFitterBinData.size() + 1.
Definition: NNWaveFitTool.h:33
std::function< double(double)> WaveformShape
WaveformShape type.
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
Definition: NNWaveFitTool.h:30
Abstract base class for different kinds of events.