Belle II Software  release-08-01-10
hopfieldNetwork.cc
1 /**************************************************************************
2  * basf2 (Belle II Analysis Software Framework) *
3  * Author: The Belle II Collaboration *
4  * *
5  * See git log for contributors and copyright holders. *
6  * This file is licensed under LGPL-3.0, see LICENSE.md. *
7  **************************************************************************/
8 
9 #include <gtest/gtest.h>
10 
11 #include <tracking/trackFindingVXD/trackSetEvaluator/HopfieldNetwork.h>
12 #include <tracking/trackFindingVXD/trackSetEvaluator/OverlapResolverNodeInfo.h>
13 
14 #include <framework/logging/Logger.h>
15 
16 #include <stdlib.h>
17 
18 using namespace std;
19 using namespace Belle2;
20 
21 
23 class HopfieldNetworkTest : public ::testing::Test {
24 protected:
25 
29  vector <OverlapResolverNodeInfo> m_trackCandidateInfos;
30 
34  vector <OverlapResolverNodeInfo> m_qiTrackOverlap;
35 
36  unsigned int myTrueTracks = 10;
41  vector<OverlapResolverNodeInfo> getInput()
42  {
43  vector<OverlapResolverNodeInfo> trackCandidateInfos;
44  unsigned int myNTrackCands = 200;
45  unsigned int diff = myNTrackCands - myTrueTracks;
46  unsigned int nOverlaps = 6;
47 
48  //Create competitor IDs
49  vector <vector <unsigned short> > competitorIDMatrix(myNTrackCands);
50  for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
51  for (unsigned int jj = 0; jj < nOverlaps; jj++) {
52  unsigned short bkgIndex = (rand() % diff) + myTrueTracks;
53  competitorIDMatrix[ii].push_back(bkgIndex);
54  competitorIDMatrix[bkgIndex].push_back(ii);
55  }
56  }
57 
58  //Make the actual OverlapResolverNodeInfos
59  for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
60  float qualityIndicator = 0;
61  if (ii < myTrueTracks) {
62  qualityIndicator = static_cast<float>(rand() % 100) / 100.;
63  B2INFO("Track Index" << ii << ", TrueQI: " << qualityIndicator);
64  } else {
65  qualityIndicator = 1 / (static_cast<float>(rand() % 100) + 1.2);
66  B2INFO("Track Index" << ii << ", FakeQI: " << qualityIndicator);
67  }
68  trackCandidateInfos.emplace_back(qualityIndicator, ii, competitorIDMatrix[ii], 0.8);
69  }
70  return trackCandidateInfos;
71  }
72 };
73 
74 TEST_F(HopfieldNetworkTest, TestPerformance)
75 {
76  HopfieldNetwork hopfieldNetwork;
77  m_trackCandidateInfos = getInput();
78  bool finished = hopfieldNetwork.doHopfield(m_trackCandidateInfos);
79  int countCorrectTracksSurvived = 0;
80  int countWrongTracksSurvived = 0;
81  for (auto&& info : m_trackCandidateInfos) {
82  B2INFO("TrackIndex: " << info.trackIndex << ", Neuron Value: " << info.activityState);
83  if (info.trackIndex < myTrueTracks && info.activityState > 0.7) countCorrectTracksSurvived++;
84  if (info.trackIndex >= myTrueTracks && info.activityState > 0.7) countWrongTracksSurvived++;
85  }
86  B2INFO("Correct survivors: " << countCorrectTracksSurvived << ", FakeSurvivors: " << countWrongTracksSurvived);
87 
88  EXPECT_EQ(finished, true);
89 }
90 
91 //Let's compare the approach with the Scrooge approach
92 TEST_F(HopfieldNetworkTest, TestScrooge)
93 {
94  /*
95  m_trackCandidateInfos = getInput();
96  for (auto const && info : m_trackCandidateInfos) {
97  m_qiTrackOverlap.emplace_back({});
98  }
99  */
100 }
Hopfield Algorithm with number based inputs.
unsigned short doHopfield(std::vector< OverlapResolverNodeInfo > &overlapResolverNodeInfos, unsigned short nIterations=20)
Performance of the actual algorithm.
Test of HopfieldNetwork Class.
vector< OverlapResolverNodeInfo > m_trackCandidateInfos
Container on which the Hopfield Algorithm runs.
vector< OverlapResolverNodeInfo > getInput()
Create sample for test.
vector< OverlapResolverNodeInfo > m_qiTrackOverlap
Container on which the Greedy Algorithm (Scrooge) runs.
TEST_F(GlobalLabelTest, LargeNumberOfTimeDependentParameters)
Test large number of time-dep params for registration and retrieval.
Definition: globalLabel.cc:72
Abstract base class for different kinds of events.