Belle II Software development
NNFitterTest.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
9#include <svd/simulation/SVDSimulationTools.h>
10#include <svd/reconstruction/NNWaveFitter.h>
11#include <svd/reconstruction/NNWaveFitTool.h>
12#include <gtest/gtest.h>
13#include <iostream>
14#include <fstream>
15#include <sstream>
16#include <string>
17#include <tuple>
18
19using namespace std;
20
21namespace Belle2 {
26 namespace SVD {
27
35 TEST(NNTimeFitter, DISABLED_CompareNetworkCoefficient)
36 {
37 // Create an instance of the NN fitter
38 NNWaveFitter fitter("svd/data/SVDTimeNet.xml");
39 EXPECT_TRUE(fitter.checkCoefficients("svd/data/classifier.txt", 1.0e-6));
40 }
41
53 TEST(NNTimeFitter, DISABLED_CompareFits)
54 {
55 const size_t max_lines = 100; // maximum number of lines to be read
56
57 // Create an instance of the NN fitter and the fitter tool.
58 NNWaveFitter fitter("SVDTimeNet_6samples");
59 auto fitTool = fitter.getFitTool();
60 size_t nProbs = fitTool.getBinCenters().size();
61
62 ifstream infile("svd/data/test_sample.csv");
63
64 // Read the rows one by one and compare results
65 string line;
66 getline(infile, line);
67
68 for (size_t i_line = 0; i_line < max_lines; i_line++) {
69
70 getline(infile, line);
71 if (line.size() < 10) break;
72 istringstream sline(line);
73
74 // Parse header. We want the dimennsion of the probability array.
75 // not needed
76 string cell;
77 getline(sline, cell, ','); // index
78 getline(sline, cell, ','); // test
79
80 // true values. Read from the file, though not used.
81 getline(sline, cell, ',');
82 [[maybe_unused]] double true_amp = stod(cell);
83 getline(sline, cell, ',');
84 [[maybe_unused]] double true_t0 = stod(cell);
85 getline(sline, cell, ',');
86 [[maybe_unused]] double width = stod(cell);
87 getline(sline, cell, ',');
88 [[maybe_unused]] double noise = stod(cell);
89
90 // normalized samples
91 apvSamples normedSamples;
92 for (size_t iSample = 0; iSample < nAPVSamples; ++iSample) {
93 getline(sline, cell, ',');
94 normedSamples[iSample] = stod(cell);
95 }
96
97 // not needed
98 getline(sline, cell, ',');
99 getline(sline, cell, ',');
100 getline(sline, cell, ',');
101
102 // probabilities
103 nnFitterBinData ProbsPy(nProbs);
104 for (size_t iSample = 0; iSample < nProbs; ++iSample) {
105 getline(sline, cell, ',');
106 ProbsPy[iSample] = stod(cell);
107 }
108
109 // fit results
110 getline(sline, cell, ',');
111 double fitPy_amp = stod(cell);
112 getline(sline, cell, ',');
113 double fitPy_ampSigma = stod(cell);
114 getline(sline, cell, ',');
115 [[maybe_unused]] double fitPy_chi2 = stod(cell);
116 getline(sline, cell, ',');
117 double fitPy_t0 = stod(cell);
118 getline(sline, cell, ',');
119 double fitPy_t0Sigma = stod(cell);
120
121 // now do the Cpp fit
122 const shared_ptr<nnFitterBinData> ProbsCpp = fitter.getFit(normedSamples, width);
123 for (size_t iBin = 0; iBin < nProbs; ++iBin)
124 EXPECT_NEAR((*ProbsCpp)[iBin], ProbsPy[iBin], 5.0e-3);
125
126 double t0_cpp, t0_err_cpp;
127 tie(t0_cpp, t0_err_cpp) = fitTool.getTimeShift(*ProbsCpp);
128 EXPECT_NEAR(t0_cpp, fitPy_t0, 5);
129 EXPECT_NEAR(t0_err_cpp, fitPy_t0Sigma, 2);
130
131 double amp_cpp, amp_err_cpp, chi2_cpp;
132 tie(amp_cpp, amp_err_cpp, chi2_cpp) = fitTool.getAmplitudeChi2(normedSamples, t0_cpp, width);
133 EXPECT_NEAR(amp_cpp, fitPy_amp, 1.0);
134 EXPECT_NEAR(amp_err_cpp, fitPy_ampSigma, 0.1);
135 // FIXME: This is calculated slightly differently in Python, and it shows.
136 // EXPECT_NEAR(chi2_cpp, fitPy_chi2, 1);
137 }
138 }
139
140 } // namespace SVD
142} // namespace Belle2
The class uses a neural network to find a probability distribution of arrival times for a sextet of A...
Definition: NNWaveFitter.h:61
std::array< apvSampleBaseType, nAPVSamples > apvSamples
vector of apvSampleBaseType objects
const std::size_t nAPVSamples
Number of APV samples.
std::vector< double > nnFitterBinData
Vector of values defined for bins, such as bin times or bin probabilities.
Definition: NNWaveFitTool.h:30
Abstract base class for different kinds of events.
STL namespace.