11 #include <gtest/gtest.h>
13 #include <tracking/trackFindingVXD/trackSetEvaluator/HopfieldNetwork.h>
14 #include <tracking/trackFindingVXD/trackSetEvaluator/OverlapResolverNodeInfo.h>
16 #include <framework/logging/Logger.h>
// Number of "true" track candidates. Candidate indices [0, myTrueTracks) are treated as true
// tracks and the remaining indices as fakes, both when the input is generated (random overlap
// partners are drawn only from the fake range) and when survivors are counted after doHopfield.
38 unsigned int myTrueTracks = 10;
// Builds a synthetic set of track candidates (OverlapResolverNodeInfo) for the Hopfield test.
// NOTE(review): the enclosing function signature is not visible in this extract; the test calls
// getInput() and assigns the result to m_trackCandidateInfos, presumably this code - confirm.
//
// Indices [0, myTrueTracks) are "true" tracks, [myTrueTracks, myNTrackCands) are fakes.
45 vector<OverlapResolverNodeInfo> trackCandidateInfos;
46 unsigned int myNTrackCands = 200;
// Size of the fake-index range that overlap partners are drawn from.
47 unsigned int diff = myNTrackCands - myTrueTracks;
// Number of random overlap partners drawn per candidate.
48 unsigned int nOverlaps = 6;
// competitorIDMatrix[i] lists the indices of candidates that overlap candidate i.
// Every random pair is recorded in both directions, so the relation is symmetric.
// NOTE(review): bkgIndex lies in [myTrueTracks, myNTrackCands), so for a fake candidate ii it can
// equal ii (a candidate listed as its own competitor), and the same partner can be drawn more
// than once. No srand() call appears in this extract, so rand() runs from its default seed.
51 vector <vector <unsigned short> > competitorIDMatrix(myNTrackCands);
52 for (
unsigned int ii = 0; ii < myNTrackCands; ii++) {
53 for (
unsigned int jj = 0; jj < nOverlaps; jj++) {
54 unsigned short bkgIndex = (rand() % diff) + myTrueTracks;
55 competitorIDMatrix[ii].push_back(bkgIndex);
56 competitorIDMatrix[bkgIndex].push_back(ii);
// Assign a quality indicator (QI) to every candidate and wrap it in an OverlapResolverNodeInfo.
61 for (
unsigned int ii = 0; ii < myNTrackCands; ii++) {
62 float qualityIndicator = 0;
// True tracks: QI = k / 100 with k in [0, 99], i.e. 0.00 .. 0.99.
63 if (ii < myTrueTracks) {
64 qualityIndicator =
static_cast<float>(rand() % 100) / 100.;
65 B2INFO(
"Track Index" << ii <<
", TrueQI: " << qualityIndicator);
// Fake tracks: QI = 1 / (k + 1.2) with k in [0, 99], i.e. roughly 0.0100 .. 0.833.
// NOTE(review): the line separating this branch from the true-track branch (original line 66,
// presumably `} else {`) and the closing braces of the loops are not present in this extract.
67 qualityIndicator = 1 / (
static_cast<float>(rand() % 100) + 1.2);
68 B2INFO(
"Track Index" << ii <<
", FakeQI: " << qualityIndicator);
// Arguments: QI, track index, competitor list, and 0.8 (presumably the initial activity /
// neuron state - confirm against the OverlapResolverNodeInfo constructor).
70 trackCandidateInfos.emplace_back(qualityIndicator, ii, competitorIDMatrix[ii], 0.8);
72 return trackCandidateInfos;
// Runs the Hopfield network on the synthetic candidates, then logs how many true and fake
// candidates end with a neuron activity above 0.7.
// NOTE(review): the enclosing TEST/TEST_F header and its closing brace are not visible here.
79 m_trackCandidateInfos = getInput();
// The loop below reads info.activityState after this call, so doHopfield presumably updates it
// in place; the returned bool is presumably a convergence flag - confirm against HopfieldNetwork.
80 bool finished = hopfieldNetwork.
doHopfield(m_trackCandidateInfos);
81 int countCorrectTracksSurvived = 0;
82 int countWrongTracksSurvived = 0;
// A candidate "survives" when its activity exceeds 0.7; indices below myTrueTracks are true tracks.
83 for (
auto && info : m_trackCandidateInfos) {
84 B2INFO(
"TrackIndex: " << info.trackIndex <<
", Neuron Value: " << info.activityState);
85 if (info.trackIndex < myTrueTracks && info.activityState > 0.7) countCorrectTracksSurvived++;
86 if (info.trackIndex >= myTrueTracks && info.activityState > 0.7) countWrongTracksSurvived++;
// NOTE(review): in this extract the survivor counts are only logged, never asserted, so the
// outcome depends solely on `finished`, not on how well true and fake candidates are separated.
88 B2INFO(
"Correct survivors: " << countCorrectTracksSurvived <<
", FakeSurvivors: " << countWrongTracksSurvived);
// Requires doHopfield to return true; EXPECT_TRUE(finished) would express this more directly.
90 EXPECT_EQ(finished,
true);