11 #include <gtest/gtest.h>
12 #include <tracking/trackFindingCDC/findlets/generic/TreeTraversal.h>
13 #include <tracking/trackFindingCDC/findlets/generic/WeightedTreeTraversal.h>
15 #include <tracking/trackFindingCDC/findlets/base/Findlet.h>
17 #include <tracking/trackFindingCDC/numerics/Weight.h>
18 #include <tracking/trackFindingCDC/numerics/WithWeight.h>
20 #include <tracking/trackFindingCDC/utilities/Functional.h>
21 #include <tracking/trackFindingCDC/utilities/Algorithms.h>
// Pull the TrackFindingCDC names (Findlet, WithWeight, TreeTraversal, ...) into
// scope for the test helpers and test cases below.
26 using namespace TrackFindingCDC;
// Result type handed to TreeTraversal as its third template argument: one path,
// stored as a sequence of pointers to constant int states.
// NOTE(review): `State` is used as a template argument further down but its
// definition is not visible in this listing - confirm it is declared nearby.
30 using Result = std::vector<const int*>;
// Test findlet plugged into TreeTraversal as its state-rejecter template argument.
// It overrides Findlet::apply with the current path and the list of candidate
// next states; both parameters are annotated as unused.
// NOTE(review): the body of apply is not visible in this listing; judging by the
// class name and the unused annotations it presumably leaves nextStates
// untouched (i.e. rejects nothing) - confirm against the full file.
32 class AcceptAll :
public Findlet<const int* const, int*> {
34 void apply(
const std::vector<const int*>& currentPath __attribute__((unused)),
35 std::vector<int*>& nextStates __attribute__((unused)))
override
// Test findlet plugged into WeightedTreeTraversal as its state-rejecter template
// argument. Its apply receives the current path and the weighted candidate next
// states, and filters the candidates in place.
// NOTE(review): nextStates carries __attribute__((unused)) although it is passed
// to erase_remove_if below; the annotation is misleading and could be dropped.
41 class AcceptHighWeight :
public Findlet<const int* const, WithWeight<int*> > {
43 void apply(
const std::vector<const int*>& currentPath __attribute__((unused)),
44 std::vector<WithWeight<int*>>& nextStates __attribute__((unused)))
// Presumably removes every candidate whose weight is below 0.5 - the exact
// semantics of erase_remove_if / GetWeight come from the project utilities.
47 erase_remove_if(nextStates, GetWeight() < 0.5);
// Exercises TreeTraversal with the AcceptAll rejecter on a three-state graph and
// checks the paths it reports when started from the first state.
53 TEST(TrackFindingCDCTest, Findlet_generic_TreeTraversal)
55 TreeTraversal<AcceptAll, State, Result> testTreeTraversal;
// The states are plain ints; relations and paths refer to them by address.
58 std::vector<int> states({1, 2, 3});
// Relations between the states: (2, 3), (1, 3) and (1, 2).
// Presumably each pair reads "from -> to" - confirm against Relation.
61 std::vector<Relation<int> > stateRelations;
62 stateRelations.push_back({&states[1], &states[2]});
63 stateRelations.push_back({&states[0], &states[2]});
64 stateRelations.push_back({&states[0], &states[1]});
// The relations are sorted before they are handed to the traversal.
65 std::sort(stateRelations.begin(), stateRelations.end());
// Single seed: the state holding the value 1.
68 std::vector<const int*> seedStates;
69 seedStates.push_back(&states[0]);
72 std::vector<std::vector<const int*>> results;
73 testTreeTraversal.apply(seedStates, stateRelations, results);
// Two paths are expected. ASSERT_* is used for the sizes so that the indexed
// EXPECT_* checks below never dereference a missing element.
75 ASSERT_EQ(2, results.size());
// First path: 1 -> 2 -> 3.
77 ASSERT_EQ(3, results[0].size());
78 EXPECT_EQ(1, *results[0][0]);
79 EXPECT_EQ(2, *results[0][1]);
80 EXPECT_EQ(3, *results[0][2]);
// Second path: 1 -> 3.
82 ASSERT_EQ(2, results[1].size());
83 EXPECT_EQ(1, *results[1][0]);
84 EXPECT_EQ(3, *results[1][1]);
// Exercises WeightedTreeTraversal with the AcceptHighWeight rejecter on the same
// three-state graph, this time with a weight attached to every relation.
87 TEST(TrackFindingCDCTest, Findlet_generic_WeightedTreeTraversal)
89 WeightedTreeTraversal<AcceptHighWeight, State> testWeightedTreeTraversal;
// The states are plain ints; relations and paths refer to them by address.
92 std::vector<int> states({1, 2, 3});
// Weighted relations: (2, 3) with weight 0.1, (1, 3) and (1, 2) with weight 1.
// Presumably each triple reads "from, weight, to" - confirm against WeightedRelation.
95 std::vector<WeightedRelation<int> > stateRelations;
96 stateRelations.push_back({&states[1], 0.1, &states[2]});
97 stateRelations.push_back({&states[0], 1, &states[2]});
98 stateRelations.push_back({&states[0], 1, &states[1]});
// The relations are sorted before they are handed to the traversal.
99 std::sort(stateRelations.begin(), stateRelations.end());
// Single seed: the state holding the value 1.
102 std::vector<const int*> seedStates;
103 seedStates.push_back(&states[0]);
106 std::vector<std::vector<const int*>> results;
107 testWeightedTreeTraversal.apply(seedStates, stateRelations, results);
// Two paths are expected, each of length two. This is consistent with the
// 0.1-weighted (2, 3) relation being dropped by AcceptHighWeight's 0.5 cut, so
// the path through state 2 stops there - presumably; verify against the traversal.
109 ASSERT_EQ(2, results.size());
// First path: 1 -> 2.
111 ASSERT_EQ(2, results[0].size());
112 EXPECT_EQ(1, *results[0][0]);
113 EXPECT_EQ(2, *results[0][1]);
// Second path: 1 -> 3.
115 ASSERT_EQ(2, results[1].size());
116 EXPECT_EQ(1, *results[1][0]);
117 EXPECT_EQ(3, *results[1][1]);