Belle II Software  release-08-01-10
NNFitterTest.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <svd/simulation/SVDSimulationTools.h>
10 #include <svd/reconstruction/NNWaveFitter.h>
11 #include <svd/reconstruction/NNWaveFitTool.h>
12 #include <gtest/gtest.h>
13 #include <iostream>
14 #include <fstream>
15 #include <sstream>
16 #include <string>
17 #include <tuple>
18 
19 using namespace std;
20 
21 namespace Belle2 {
26  namespace SVD {
27 
35  TEST(NNTimeFitter, DISABLED_CompareNetworkCoefficient)
36  {
37  // Create an instance of the NN fitter
38  NNWaveFitter fitter("svd/data/SVDTimeNet.xml");
39  EXPECT_TRUE(fitter.checkCoefficients("svd/data/classifier.txt", 1.0e-6));
40  }
41 
53  TEST(NNTimeFitter, DISABLED_CompareFits)
54  {
55  const size_t max_lines = 100; // maximum number of lines to be read
56 
57  // Create an instance of the NN fitter and the fitter tool.
58  NNWaveFitter fitter("SVDTimeNet_6samples");
59  auto fitTool = fitter.getFitTool();
60  size_t nProbs = fitTool.getBinCenters().size();
61 
62  ifstream infile("svd/data/test_sample.csv");
63 
64  // Read the rows one by one and compare results
65  string line;
66  getline(infile, line);
67 
68  for (size_t i_line = 0; i_line < max_lines; i_line++) {
69 
70  getline(infile, line);
71  if (line.size() < 10) break;
72  istringstream sline(line);
73 
74  // Parse header. We want the dimennsion of the probability array.
75  // not needed
76  string cell;
77  getline(sline, cell, ','); // index
78  getline(sline, cell, ','); // test
79 
80  // true values. Read from the file, though not used.
81  getline(sline, cell, ',');
82  [[maybe_unused]] double true_amp = stod(cell);
83  getline(sline, cell, ',');
84  [[maybe_unused]] double true_t0 = stod(cell);
85  getline(sline, cell, ',');
86  [[maybe_unused]] double width = stod(cell);
87  getline(sline, cell, ',');
88  [[maybe_unused]] double noise = stod(cell);
89 
90  // normalized samples
91  apvSamples normedSamples;
92  for (size_t iSample = 0; iSample < nAPVSamples; ++iSample) {
93  getline(sline, cell, ',');
94  normedSamples[iSample] = stod(cell);
95  }
96 
97  // not needed
98  getline(sline, cell, ',');
99  getline(sline, cell, ',');
100  getline(sline, cell, ',');
101 
102  // probabilities
103  nnFitterBinData ProbsPy(nProbs);
104  for (size_t iSample = 0; iSample < nProbs; ++iSample) {
105  getline(sline, cell, ',');
106  ProbsPy[iSample] = stod(cell);
107  }
108 
109  // fit results
110  getline(sline, cell, ',');
111  double fitPy_amp = stod(cell);
112  getline(sline, cell, ',');
113  double fitPy_ampSigma = stod(cell);
114  getline(sline, cell, ',');
115  [[maybe_unused]] double fitPy_chi2 = stod(cell);
116  getline(sline, cell, ',');
117  double fitPy_t0 = stod(cell);
118  getline(sline, cell, ',');
119  double fitPy_t0Sigma = stod(cell);
120 
121  // now do the Cpp fit
122  const shared_ptr<nnFitterBinData> ProbsCpp = fitter.getFit(normedSamples, width);
123  for (size_t iBin = 0; iBin < nProbs; ++iBin)
124  EXPECT_NEAR((*ProbsCpp)[iBin], ProbsPy[iBin], 5.0e-3);
125 
126  double t0_cpp, t0_err_cpp;
127  tie(t0_cpp, t0_err_cpp) = fitTool.getTimeShift(*ProbsCpp);
128  EXPECT_NEAR(t0_cpp, fitPy_t0, 5);
129  EXPECT_NEAR(t0_err_cpp, fitPy_t0Sigma, 2);
130 
131  double amp_cpp, amp_err_cpp, chi2_cpp;
132  tie(amp_cpp, amp_err_cpp, chi2_cpp) = fitTool.getAmplitudeChi2(normedSamples, t0_cpp, width);
133  EXPECT_NEAR(amp_cpp, fitPy_amp, 1.0);
134  EXPECT_NEAR(amp_err_cpp, fitPy_ampSigma, 0.1);
135  // FIXME: This is calculated slightly differently in Python, and it shows.
136  // EXPECT_NEAR(chi2_cpp, fitPy_chi2, 1);
137  }
138  }
139 
140  } // namespace SVD
142 } // namespace Belle2
The class uses a neural network to find a probability distribution of arrival times for a sextet of A...
Definition: NNWaveFitter.h:61
TEST(TestgetDetectorRegion, TestgetDetectorRegion)
Test Constructors.
std::array< apvSampleBaseType, nAPVSamples > apvSamples
vector of apvSampleBaseType objects
const std::size_t nAPVSamples
Number of APV samples.
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
Definition: NNWaveFitTool.h:30
Abstract base class for different kinds of events.