Belle II Software  release-05-02-19
NNWaveFitter.h
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2010 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Peter Kvasnicka *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #pragma once
12 #ifndef _SVD_RECONSTRUCTION_NNwAVEFITTER_H
13 #define _SVD_RECONSTRUCTION_NNWAVEFITTER_H
14 
15 #include <cmath>
16 #include <string>
17 #include <vector>
18 #include <map>
19 #include <memory>
20 #include <functional>
21 #include <svd/reconstruction/NNWaveFitTool.h>
22 #include <svd/simulation/SVDSimulationTools.h>
23 #include <Eigen/Dense>
24 
25 namespace Belle2 {
30  namespace SVD {
31 
32 // ==============================================================================
33 // The NNWaveFitter class - neural network fitter of APV25 waveforms
34 // ------------------------------------------------------------------------------
63  class NNWaveFitter {
64 
65  public:
66 
67 
/// Bounds of a network parameter used in training the network
/// (presumably stored as (minimum, maximum) — confirm against NNWaveFitter.cc).
using nnBoundsType = std::pair<double, double>;
73 
/// Constructs the wave fitter from a network definition.
/// @param xmlData network definition source, presumably the XML file name
///        (cf. setNetwrok/readNetworkData) — confirm in NNWaveFitter.cc.
///        Defaults to an empty string; how "" is handled is not visible here.
78  NNWaveFitter(std::string xmlData = "");
79 
/**
 * Set proper network definition file.
 * @param xmlData network definition, presumably the XML file name
 *        (cf. readNetworkData) — confirm in NNWaveFitter.cc.
 * NOTE: the name is misspelled; it is kept unchanged for backward compatibility.
 * New code can call the correctly spelled setNetwork().
 */
void setNetwrok(const std::string& xmlData);

/**
 * Correctly spelled alias of setNetwrok(); simply forwards to it.
 * @param xmlData passed through unchanged to setNetwrok()
 */
void setNetwork(const std::string& xmlData) { setNetwrok(xmlData); }
84 
/// Fitting method: send data and get the result structure.
/// @param samples APV25 samples to fit (apvSamples)
/// @param tau waveform width used for the fit (units not visible here — confirm)
/// @return shared pointer to per-bin data (nnFitterBinData); presumably the
///         probability distribution over the network's time bins — confirm.
94  std::shared_ptr<nnFitterBinData> getFit(const apvSamples& samples, double tau);
95 
100  const NNWaveFitTool& getFitTool() const { return *m_fitTool; }
101 
106  const nnFitterBinData& getBinCenters() const { return m_binCenters; }
107 
111  const nnFitterBins& getBins() const { return m_bins; }
112 
119 
124 
129 
133  bool isValid() const { return m_isValid; }
134 
/// Check NN data against a dump from Python.
/// @param dumpname name of the dump to compare with (presumably a file path — confirm)
/// @param tol comparison tolerance; defaults to 1.0e-10
/// @return presumably true when the coefficients agree within tol — confirm in NNWaveFitter.cc
140  bool checkCoefficients(const std::string& dumpname, double tol = 1.0e-10);
141 
142  private:
143 
/// Activation function type: a scalar function double -> double applied to neuron inputs.
using activationType = std::function<double(double)>;
148  [](double x) -> double { return std::max(double(0.0), x); };
151  [](double x) -> double { double e = std::exp(x); return e / (1.0 + e); };
152 
157  Eigen::VectorXd softmax(const Eigen::VectorXd& input)
158  {
159  auto output = input.array().unaryExpr(
160  [](double x)->double { return ((x > 0.0) ? exp(x) : 0.0); }
161  );
162  double norm = output.sum();
163  return output / norm;
164  }
165 
/// The method that actually reads the xml file.
/// @param xmlFileName name of the XML network-definition file
/// @return int status code; its meaning is not visible in this header — see NNWaveFitter.cc
170  int readNetworkData(const std::string& xmlFileName);
171 
175  typedef std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > >
177 
179  typedef std::vector<Eigen::VectorXd> layerStatesType;
180 
181  // Data members
/// true if the fitter was properly initialized (returned by isValid())
182  bool m_isValid;
/// number of NN layers, read from xml
183  std::size_t m_nLayers;
/// NN layer sizes
187  std::vector<std::size_t> m_layerSizes;
/// FitterTool object allowing calculations on network fits (dereferenced by getFitTool())
195  std::shared_ptr<NNWaveFitTool> m_fitTool;
197  }; // class NNWaveFitter
198 
199  } // namespace SVD
201 } // namespace Belle2
202 
203 #endif
Belle2::SVD::NNWaveFitter::getFitTool
const NNWaveFitTool & getFitTool() const
Get a handle to the NNWaveFitTool object.
Definition: NNWaveFitter.h:100
Belle2::SVD::NNWaveFitter::m_tauCoder
TauEncoder m_tauCoder
Tau encoder class instance to scale tau values.
Definition: NNWaveFitter.h:193
Belle2::SVD::waveFunction
std::function< double(double)> waveFunction
Waveform function type.
Definition: SVDSimulationTools.h:56
Belle2::SVD::TauEncoder
Encoder/decoder for neural network tau values.
Definition: SVDSimulationTools.h:162
Belle2::SVD::NNWaveFitter::m_activation
activationType m_activation
Network activation function.
Definition: NNWaveFitter.h:189
Belle2::SVD::NNWaveFitter::getBinCenters
const nnFitterBinData & getBinCenters() const
Get bin times of the network output.
Definition: NNWaveFitter.h:106
Belle2::SVD::NNWaveFitter::m_layerSizes
std::vector< std::size_t > m_layerSizes
NN layer sizes.
Definition: NNWaveFitter.h:187
Belle2::SVD::NNWaveFitter::m_networkCoefs
networkWeightsType m_networkCoefs
NN weights and intercepts.
Definition: NNWaveFitter.h:188
Belle2::SVD::NNWaveFitter::networkWeightsType
std::map< size_t, std::pair< Eigen::MatrixXd, Eigen::VectorXd > > networkWeightsType
We use a map to store network layers since we don't know whether we'll read them in the correct order.
Definition: NNWaveFitter.h:176
Belle2::SVD::NNWaveFitter::nnBoundsType
std::pair< double, double > nnBoundsType
Bounds type used to hold network parameter bounds used in training the network.
Definition: NNWaveFitter.h:72
Belle2::SVD::NNWaveFitter::m_bins
nnFitterBins m_bins
NN time bin boundaries.
Definition: NNWaveFitter.h:185
Belle2::SVD::NNWaveFitter::softmax
Eigen::VectorXd softmax(const Eigen::VectorXd &input)
Softmax function, normalization for the network's output layer.
Definition: NNWaveFitter.h:157
Belle2::SVD::NNWaveFitter::readNetworkData
int readNetworkData(const std::string &xmlFileName)
The method that actually reads the xml file.
Definition: NNWaveFitter.cc:38
Belle2::SVD::NNWaveFitter::layerStatesType
std::vector< Eigen::VectorXd > layerStatesType
Storage for internal states.
Definition: NNWaveFitter.h:179
Belle2::SVD::NNWaveFitter
The class uses a neural network to find a probability distribution of arrival times for a sextet of A...
Definition: NNWaveFitter.h:63
Belle2::SVD::NNWaveFitter::getAmplitudeBounds
nnBoundsType getAmplitudeBounds() const
Get amplitude bounds.
Definition: NNWaveFitter.h:123
Belle2::SVD::NNWaveFitter::checkCoefficients
bool checkCoefficients(const std::string &dumpname, double tol=1.0e-10)
Check NN data against a dump from Python.
Definition: NNWaveFitter.cc:194
Belle2::SVD::NNWaveFitter::getWaveWidthBounds
nnBoundsType getWaveWidthBounds() const
Get width bounds. Width bounds are the minimum and maximum widths used in training the network.
Definition: NNWaveFitter.h:118
Belle2::SVD::NNWaveFitter::m_wave
waveFunction m_wave
Wave function used in training the network.
Definition: NNWaveFitter.h:194
Belle2::SVD::nnFitterBins
std::vector< double > nnFitterBins
Vector of bin edges; its size is nnFitterBinData.size() + 1.
Definition: NNWaveFitTool.h:35
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::SVD::apvSamples
std::array< apvSampleBaseType, nAPVSamples > apvSamples
Array of apvSampleBaseType objects.
Definition: SVDSimulationTools.h:41
Belle2::SVD::NNWaveFitter::m_layerStates
layerStatesType m_layerStates
vectors of layer states
Definition: NNWaveFitter.h:186
Belle2::SVD::NNWaveFitter::m_nLayers
std::size_t m_nLayers
number of NN layers, read from xml
Definition: NNWaveFitter.h:183
Belle2::SVD::NNWaveFitter::sigmoid
activationType sigmoid
Sigmoid activation.
Definition: NNWaveFitter.h:150
Belle2::SVD::NNWaveFitter::m_isValid
bool m_isValid
true if fitter was properly initialized
Definition: NNWaveFitter.h:182
Belle2::SVD::NNWaveFitter::setNetwrok
void setNetwrok(const std::string &xmlData)
Set proper network definition file.
Definition: NNWaveFitter.cc:253
Belle2::SVD::NNWaveFitter::m_binCenters
nnFitterBinData m_binCenters
centers of NN time bins
Definition: NNWaveFitter.h:184
Belle2::SVD::NNWaveFitter::getBins
const nnFitterBins & getBins() const
Get bins of the network output.
Definition: NNWaveFitter.h:111
Belle2::SVD::NNWaveFitter::getTimeShiftBounds
nnBoundsType getTimeShiftBounds() const
Get time shift bounds.
Definition: NNWaveFitter.h:128
Belle2::SVD::NNWaveFitter::activationType
std::function< double(double)> activationType
Activation functions type.
Definition: NNWaveFitter.h:145
Belle2::SVD::NNWaveFitTool
The class holds arrays of bins and bin centers, and a wave generator object containing information on...
Definition: NNWaveFitTool.h:93
Belle2::SVD::nnFitterBinData
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
Definition: NNWaveFitTool.h:32
Belle2::SVD::NNWaveFitter::m_fitTool
std::shared_ptr< NNWaveFitTool > m_fitTool
FitterTool object allowing calculations on network fits.
Definition: NNWaveFitter.h:195
Belle2::SVD::NNWaveFitter::m_timeShiftBounds
nnBoundsType m_timeShiftBounds
Time shift range of the network.
Definition: NNWaveFitter.h:192
Belle2::SVD::NNWaveFitter::m_waveWidthBounds
nnBoundsType m_waveWidthBounds
Waveform width range of the network.
Definition: NNWaveFitter.h:191
Belle2::SVD::NNWaveFitter::relu
activationType relu
Rectifier activation.
Definition: NNWaveFitter.h:147
Belle2::SVD::NNWaveFitter::m_amplitudeBounds
nnBoundsType m_amplitudeBounds
Amplitude range of the network.
Definition: NNWaveFitter.h:190
Belle2::SVD::NNWaveFitter::NNWaveFitter
NNWaveFitter(std::string xmlData="")
Constructs the wave fitter from data in an XML file.
Definition: NNWaveFitter.cc:262
Belle2::SVD::NNWaveFitter::isValid
bool isValid() const
Is this fitter working? Return false if the fitter was not initialized properly.
Definition: NNWaveFitter.h:133
Belle2::SVD::NNWaveFitter::getFit
std::shared_ptr< nnFitterBinData > getFit(const apvSamples &samples, double tau)
Fitting method: send data and get the result structure.
Definition: NNWaveFitter.cc:269