9#include <gtest/gtest.h>
11#include <tracking/trackFindingVXD/trackSetEvaluator/HopfieldNetwork.h>
12#include <tracking/trackFindingVXD/trackSetEvaluator/OverlapResolverNodeInfo.h>
14#include <framework/logging/Logger.h>
43 vector<OverlapResolverNodeInfo> trackCandidateInfos;
44 unsigned int myNTrackCands = 200;
46 unsigned int nOverlaps = 6;
49 vector <vector <unsigned short> > competitorIDMatrix(myNTrackCands);
50 for (
unsigned int ii = 0; ii < myNTrackCands; ii++) {
51 for (
unsigned int jj = 0; jj < nOverlaps; jj++) {
52 unsigned short bkgIndex = (rand() % diff) +
myTrueTracks;
53 competitorIDMatrix[ii].push_back(bkgIndex);
54 competitorIDMatrix[bkgIndex].push_back(ii);
59 for (
unsigned int ii = 0; ii < myNTrackCands; ii++) {
60 float qualityIndicator = 0;
62 qualityIndicator =
static_cast<float>(rand() % 100) / 100.;
63 B2INFO(
"Track Index" << ii <<
", TrueQI: " << qualityIndicator);
65 qualityIndicator = 1 / (
static_cast<float>(rand() % 100) + 1.2);
66 B2INFO(
"Track Index" << ii <<
", FakeQI: " << qualityIndicator);
68 trackCandidateInfos.emplace_back(qualityIndicator, ii, competitorIDMatrix[ii], 0.8);
70 return trackCandidateInfos;
77 m_trackCandidateInfos = getInput();
78 bool finished = hopfieldNetwork.
doHopfield(m_trackCandidateInfos);
79 int countCorrectTracksSurvived = 0;
80 int countWrongTracksSurvived = 0;
81 for (
auto&& info : m_trackCandidateInfos) {
82 B2INFO(
"TrackIndex: " << info.trackIndex <<
", Neuron Value: " << info.activityState);
83 if (info.trackIndex < myTrueTracks && info.activityState > 0.7) countCorrectTracksSurvived++;
84 if (info.trackIndex >= myTrueTracks && info.activityState > 0.7) countWrongTracksSurvived++;
86 B2INFO(
"Correct survivors: " << countCorrectTracksSurvived <<
", FakeSurvivors: " << countWrongTracksSurvived);
88 EXPECT_EQ(finished,
true);
Hopfield algorithm with number-based inputs.
unsigned short doHopfield(std::vector< OverlapResolverNodeInfo > &overlapResolverNodeInfos, unsigned short nIterations=20)
Performs the actual algorithm.
Test of the HopfieldNetwork class.
vector< OverlapResolverNodeInfo > m_trackCandidateInfos
Container on which the Hopfield Algorithm runs.
unsigned int myTrueTracks
Number of true tracks.
vector< OverlapResolverNodeInfo > m_qiTrackOverlap
Container on which the Greedy Algorithm (Scrooge) runs.
vector< OverlapResolverNodeInfo > getInput()
Creates a sample input for the test.
Abstract base class for different kinds of events.