Belle II Software development
hopfieldNetwork.cc
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8
#include <gtest/gtest.h>

#include <tracking/trackFindingVXD/trackSetEvaluator/HopfieldNetwork.h>
#include <tracking/trackFindingVXD/trackSetEvaluator/OverlapResolverNodeInfo.h>

#include <framework/logging/Logger.h>

#include <algorithm>
#include <stdlib.h>
#include <vector>

using namespace std;
using namespace Belle2;
20
21
23class HopfieldNetworkTest : public ::testing::Test {
24protected:
25
29 vector <OverlapResolverNodeInfo> m_trackCandidateInfos;
30
34 vector <OverlapResolverNodeInfo> m_qiTrackOverlap;
35
36 unsigned int myTrueTracks = 10;
41 vector<OverlapResolverNodeInfo> getInput()
42 {
43 vector<OverlapResolverNodeInfo> trackCandidateInfos;
44 unsigned int myNTrackCands = 200;
45 unsigned int diff = myNTrackCands - myTrueTracks;
46 unsigned int nOverlaps = 6;
47
48 //Create competitor IDs
49 vector <vector <unsigned short> > competitorIDMatrix(myNTrackCands);
50 for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
51 for (unsigned int jj = 0; jj < nOverlaps; jj++) {
52 unsigned short bkgIndex = (rand() % diff) + myTrueTracks;
53 competitorIDMatrix[ii].push_back(bkgIndex);
54 competitorIDMatrix[bkgIndex].push_back(ii);
55 }
56 }
57
58 //Make the actual OverlapResolverNodeInfos
59 for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
60 float qualityIndicator = 0;
61 if (ii < myTrueTracks) {
62 qualityIndicator = static_cast<float>(rand() % 100) / 100.;
63 B2INFO("Track Index" << ii << ", TrueQI: " << qualityIndicator);
64 } else {
65 qualityIndicator = 1 / (static_cast<float>(rand() % 100) + 1.2);
66 B2INFO("Track Index" << ii << ", FakeQI: " << qualityIndicator);
67 }
68 trackCandidateInfos.emplace_back(qualityIndicator, ii, competitorIDMatrix[ii], 0.8);
69 }
70 return trackCandidateInfos;
71 }
72};
73
74TEST_F(HopfieldNetworkTest, TestPerformance)
75{
76 HopfieldNetwork hopfieldNetwork;
77 m_trackCandidateInfos = getInput();
78 bool finished = hopfieldNetwork.doHopfield(m_trackCandidateInfos);
79 int countCorrectTracksSurvived = 0;
80 int countWrongTracksSurvived = 0;
81 for (auto&& info : m_trackCandidateInfos) {
82 B2INFO("TrackIndex: " << info.trackIndex << ", Neuron Value: " << info.activityState);
83 if (info.trackIndex < myTrueTracks && info.activityState > 0.7) countCorrectTracksSurvived++;
84 if (info.trackIndex >= myTrueTracks && info.activityState > 0.7) countWrongTracksSurvived++;
85 }
86 B2INFO("Correct survivors: " << countCorrectTracksSurvived << ", FakeSurvivors: " << countWrongTracksSurvived);
87
88 EXPECT_EQ(finished, true);
89}
90
91//Let's compare the approach with the Scrooge approach
92TEST_F(HopfieldNetworkTest, TestScrooge)
93{
94 /*
95 m_trackCandidateInfos = getInput();
96 for (auto const && info : m_trackCandidateInfos) {
97 m_qiTrackOverlap.emplace_back({});
98 }
99 */
100}
Hopfield algorithm with number-based inputs.
unsigned short doHopfield(std::vector< OverlapResolverNodeInfo > &overlapResolverNodeInfos, unsigned short nIterations=20)
Performance of the actual algorithm.
Test of the HopfieldNetwork class.
vector< OverlapResolverNodeInfo > m_trackCandidateInfos
Container on which the Hopfield Algorithm runs.
unsigned int myTrueTracks
Number of true tracks.
vector< OverlapResolverNodeInfo > m_qiTrackOverlap
Container on which the Greedy Algorithm (Scrooge) runs.
vector< OverlapResolverNodeInfo > getInput()
Create sample for test.
Abstract base class for different kinds of events.
STL namespace.