Belle II Software  release-05-01-25
hopfieldNetwork.cc
1 /**************************************************************************
2  * BASF2 (Belle Analysis Framework 2) *
3  * Copyright(C) 2015 - Belle II Collaboration *
4  * *
5  * Author: The Belle II Collaboration *
6  * Contributors: Martin Heck *
7  * *
8  * This software is provided "as is" without any warranty. *
9  **************************************************************************/
10 
11 #include <gtest/gtest.h>
12 
13 #include <tracking/trackFindingVXD/trackSetEvaluator/HopfieldNetwork.h>
14 #include <tracking/trackFindingVXD/trackSetEvaluator/OverlapResolverNodeInfo.h>
15 
16 #include <framework/logging/Logger.h>
17 
18 #include <stdlib.h>
19 
20 using namespace std;
21 using namespace Belle2;
22 
23 
25 class HopfieldNetworkTest : public ::testing::Test {
26 protected:
27 
31  vector <OverlapResolverNodeInfo> m_trackCandidateInfos;
32 
36  vector <OverlapResolverNodeInfo> m_qiTrackOverlap;
37 
38  unsigned int myTrueTracks = 10;
43  vector<OverlapResolverNodeInfo> getInput()
44  {
45  vector<OverlapResolverNodeInfo> trackCandidateInfos;
46  unsigned int myNTrackCands = 200;
47  unsigned int diff = myNTrackCands - myTrueTracks;
48  unsigned int nOverlaps = 6;
49 
50  //Create competitor IDs
51  vector <vector <unsigned short> > competitorIDMatrix(myNTrackCands);
52  for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
53  for (unsigned int jj = 0; jj < nOverlaps; jj++) {
54  unsigned short bkgIndex = (rand() % diff) + myTrueTracks;
55  competitorIDMatrix[ii].push_back(bkgIndex);
56  competitorIDMatrix[bkgIndex].push_back(ii);
57  }
58  }
59 
60  //Make the actual OverlapResolverNodeInfos
61  for (unsigned int ii = 0; ii < myNTrackCands; ii++) {
62  float qualityIndicator = 0;
63  if (ii < myTrueTracks) {
64  qualityIndicator = static_cast<float>(rand() % 100) / 100.;
65  B2INFO("Track Index" << ii << ", TrueQI: " << qualityIndicator);
66  } else {
67  qualityIndicator = 1 / (static_cast<float>(rand() % 100) + 1.2);
68  B2INFO("Track Index" << ii << ", FakeQI: " << qualityIndicator);
69  }
70  trackCandidateInfos.emplace_back(qualityIndicator, ii, competitorIDMatrix[ii], 0.8);
71  }
72  return trackCandidateInfos;
73  }
74 };
75 
76 TEST_F(HopfieldNetworkTest, TestPerformance)
77 {
78  HopfieldNetwork hopfieldNetwork;
79  m_trackCandidateInfos = getInput();
80  bool finished = hopfieldNetwork.doHopfield(m_trackCandidateInfos);
81  int countCorrectTracksSurvived = 0;
82  int countWrongTracksSurvived = 0;
83  for (auto && info : m_trackCandidateInfos) {
84  B2INFO("TrackIndex: " << info.trackIndex << ", Neuron Value: " << info.activityState);
85  if (info.trackIndex < myTrueTracks && info.activityState > 0.7) countCorrectTracksSurvived++;
86  if (info.trackIndex >= myTrueTracks && info.activityState > 0.7) countWrongTracksSurvived++;
87  }
88  B2INFO("Correct survivors: " << countCorrectTracksSurvived << ", FakeSurvivors: " << countWrongTracksSurvived);
89 
90  EXPECT_EQ(finished, true);
91 }
92 
93 //Let's compare the approach with the Scrooge approach
94 TEST_F(HopfieldNetworkTest, TestScrooge)
95 {
96  /*
97  m_trackCandidateInfos = getInput();
98  for (auto const && info : m_trackCandidateInfos) {
99  m_qiTrackOverlap.emplace_back({});
100  }
101  */
102 }
HopfieldNetworkTest::m_qiTrackOverlap
vector< OverlapResolverNodeInfo > m_qiTrackOverlap
Container on which the Greedy Algorithm (Scrooge) runs.
Definition: hopfieldNetwork.cc:36
Belle2
Abstract base class for different kinds of events.
Definition: MillepedeAlgorithm.h:19
Belle2::HopfieldNetwork
Hopfield Algorithm with number based inputs.
Definition: HopfieldNetwork.h:46
Belle2::TEST_F
TEST_F(GlobalLabelTest, LargeNumberOfTimeDependentParameters)
Test large number of time-dep params for registration and retrieval.
Definition: globalLabel.cc:65
Belle2::HopfieldNetwork::doHopfield
unsigned short doHopfield(std::vector< OverlapResolverNodeInfo > &overlapResolverNodeInfos, unsigned short nIterations=20)
Performance of the actual algorithm.
Definition: HopfieldNetwork.cc:21
HopfieldNetworkTest::getInput
vector< OverlapResolverNodeInfo > getInput()
Create sample for test.
Definition: hopfieldNetwork.cc:43
HopfieldNetworkTest::m_trackCandidateInfos
vector< OverlapResolverNodeInfo > m_trackCandidateInfos
Container on which the Hopfield Algorithm runs.
Definition: hopfieldNetwork.cc:31
HopfieldNetworkTest
Test of HopfieldNetwork Class.
Definition: hopfieldNetwork.cc:25