Belle II Software development
WeightedFastHoughTree.h
1/**************************************************************************
2 * basf2 (Belle II Analysis Software Framework) *
3 * Author: The Belle II Collaboration *
4 * *
5 * See git log for contributors and copyright holders. *
6 * This file is licensed under LGPL-3.0, see LICENSE.md. *
7 **************************************************************************/
8#pragma once
9
10#include <tracking/trackFindingCDC/hough/trees/DynTree.h>
11#include <tracking/trackFindingCDC/hough/baseelements/WithWeightedItems.h>
12#include <tracking/trackFindingCDC/hough/baseelements/WithSharedMark.h>
13
#include <cassert>
#include <cfloat>
#include <cmath>
#include <deque>
#include <memory>
#include <utility>
#include <vector>
19
20namespace Belle2 {
25 namespace TrackFindingCDC {
26
28 template<class T, class ADomain, class ADomainDivsion>
30 public DynTree< WithWeightedItems<ADomain, T>, ADomainDivsion> {
31 private:
34
35 public:
37 WeightedParititioningDynTree(ADomain topDomain, ADomainDivsion domainDivsion) :
38 Super(WithWeightedItems<ADomain, T>(std::move(topDomain)), std::move(domainDivsion))
39 {}
40 };
41
48 template<class T, class ADomain, class ADomainDivsion>
50 public WeightedParititioningDynTree<WithSharedMark<T>, ADomain, ADomainDivsion> {
51
52 private:
54 using Super = WeightedParititioningDynTree<WithSharedMark<T>, ADomain, ADomainDivsion>;
55
56 public:
58 using WeightedParititioningDynTree<WithSharedMark<T>, ADomain, ADomainDivsion>::WeightedParititioningDynTree;
59
61 using Node = typename Super::Node;
62
63 public:
65 template<class Ts>
66 void seed(const Ts& items)
67 {
68 this->fell();
69 Node& topNode = this->getTopNode();
70 for (auto&& item : items) {
71 m_marks.push_back(false);
72 bool& markOfItem = m_marks.back();
73 Weight weight = DBL_MAX;
74 topNode.insert(WithSharedMark<T>(T(item), &markOfItem), weight);
75 }
76 }
77
79 template <class AItemInDomainMeasure>
80 std::vector<std::pair<ADomain, std::vector<T>>>
81 findHeavyLeavesDisjoint(AItemInDomainMeasure& weightItemInDomain,
82 int maxLevel,
83 double minWeight)
84 {
85 auto skipLowWeightNode = [minWeight](const Node * node) {
86 return not(node->getWeight() >= minWeight);
87 };
88 return findLeavesDisjoint(weightItemInDomain, maxLevel, skipLowWeightNode);
89 }
90
92 template <class AItemInDomainMeasure, class ASkipNodePredicate>
93 std::vector<std::pair<ADomain, std::vector<T>>>
94 findLeavesDisjoint(AItemInDomainMeasure& weightItemInDomain,
95 int maxLevel,
96 ASkipNodePredicate& skipNode)
97 {
98 std::vector<std::pair<ADomain, std::vector<T> > > found;
99 auto isLeaf = [&found, &skipNode, maxLevel](Node * node) {
100 // Skip the expansion and the filling of the children
101 if (skipNode(node)) {
102 return true;
103 }
104
105 // Node is a leaf at the maximum level
106 // Save its content
107 // Do not walk children
108 if (node->getLevel() >= maxLevel) {
109 const ADomain* domain = node;
110 found.emplace_back(*domain, std::vector<T>(node->begin(), node->end()));
111 for (WithSharedMark<T>& markableItem : *node) {
112 markableItem.mark();
113 }
114 return true;
115 }
116
117 // Else to node has enough weight and is not at the lowest level
118 // Signal that it is not a leaf
119 // Continue to create and fill children.
120 return false;
121 };
122 fillWalk(weightItemInDomain, isLeaf);
123 return found;
124 }
125
133 template <class AItemInDomainMeasure>
134 std::vector<std::pair<ADomain, std::vector<T>>>
135 findHeaviestLeafRepeated(AItemInDomainMeasure& weightItemInDomain,
136 int maxLevel,
137 const Weight minWeight = NAN)
138 {
139 auto skipLowWeightNode = [minWeight](const Node * node) {
140 return not(node->getWeight() >= minWeight);
141 };
142 return findHeaviestLeafRepeated(weightItemInDomain, maxLevel, skipLowWeightNode);
143 }
144
152 template <class AItemInDomainMeasure, class ASkipNodePredicate>
153 std::vector<std::pair<ADomain, std::vector<T>>>
154 findHeaviestLeafRepeated(AItemInDomainMeasure& weightItemInDomain,
155 int maxLevel,
156 ASkipNodePredicate& skipNode)
157 {
158 std::vector<std::pair<ADomain, std::vector<T> > > found;
159 Node* node = findHeaviestLeaf(weightItemInDomain, maxLevel, skipNode);
160 while (node) {
161 const ADomain* domain = node;
162 found.emplace_back(*domain, std::vector<T>(node->begin(), node->end()));
163 for (WithSharedMark<T>& markableItem : *node) {
164 markableItem.mark();
165 }
166 node = findHeaviestLeaf(weightItemInDomain, maxLevel, skipNode);
167 }
168 return found;
169 }
170
176 template <class AItemInDomainMeasure, class ASkipNodePredicate>
177 std::unique_ptr<std::pair<ADomain, std::vector<T>>>
178 findHeaviestLeafSingle(AItemInDomainMeasure& weightItemInDomain,
179 int maxLevel,
180 ASkipNodePredicate& skipNode)
181 {
182 using Result = std::pair<ADomain, std::vector<T> >;
183 std::unique_ptr<Result> found = nullptr;
184 Node* node = findHeaviestLeaf(weightItemInDomain, maxLevel, skipNode);
185 if (node) {
186 const ADomain* domain = node;
187 found.reset(new Result(*domain, std::vector<T>(node->begin(), node->end())));
188 for (WithSharedMark<T>& markableItem : *node) {
189 markableItem.mark();
190 }
191 }
192 return found;
193 }
194
200 template <class AItemInDomainMeasure, class ASkipNodePredicate>
201 Node* findHeaviestLeaf(AItemInDomainMeasure& weightItemInDomain,
202 int maxLevel,
203 ASkipNodePredicate& skipNode)
204 {
205 Node* heaviestNode = nullptr;
206 Weight heighestWeigth = NAN;
207 auto isLeaf = [&heaviestNode, &heighestWeigth, maxLevel, &skipNode](Node * node) {
208 // Skip the expansion and the filling of the children
209 if (skipNode(node)) {
210 return true;
211 }
212
213 Weight nodeWeight = node->getWeight();
214 // Skip the expansion and filling of the children if the node has not enough weight
215 if (not std::isnan(heighestWeigth) and not(nodeWeight > heighestWeigth)) {
216 return true;
217 }
218
219 // Node is a leaf at the maximum level and is heavier than everything seen before.
220 // Save its content
221 // Do not walk children
222 if (node->getLevel() >= maxLevel) {
223 heaviestNode = node;
224 heighestWeigth = nodeWeight;
225 return true;
226 }
227 return false;
228 };
229 fillWalk(weightItemInDomain, isLeaf);
230 return heaviestNode;
231 }
232
233 public:
238 template<class AItemInDomainMeasure, class AIsLeafPredicate>
239 void fillWalk(AItemInDomainMeasure& weightItemInDomain,
240 AIsLeafPredicate& isLeaf)
241 {
242 auto walker = [&weightItemInDomain, &isLeaf](Node * node) {
243 // Check if node is a leaf
244 // Do not create children in this case
245 if (isLeaf(node)) {
246 // Do not walk children.
247 return false;
248 }
249
250 // Node is not a leaf.
251 // Check if it has children.
252 // If children have not been created, create and fill them.
253 typename Node::Children* children = node->getChildren();
254 if (not children) {
255 node->createChildren();
256 children = node->getChildren();
257 for (Node& childNode : *children) {
258 assert(childNode.getChildren() == nullptr);
259 assert(childNode.size() == 0);
260 auto measure =
261 [&childNode, &weightItemInDomain](WithSharedMark<T>& markableItem) -> Weight {
262 // Weighting function should not see the mark, but only the item itself.
263 T & item(markableItem);
264 return weightItemInDomain(item, &childNode);
265 };
266 childNode.insert(*node, measure);
267 }
268 }
269 // Continue to walk the children.
270 return true;
271 };
272 walkHeighWeightFirst(walker);
273 }
274
276 template<class ATreeWalker>
277 void walkHeighWeightFirst(ATreeWalker& walker)
278 {
279 auto priority = [](Node * node) -> float {
281 auto isMarked = [](const WithSharedMark<T>& markableItem) -> bool {
282 return markableItem.isMarked();
283 };
284 node->eraseIf(isMarked);
285 return node->getWeight();
286 };
287
288 this->walk(walker, priority);
289 }
290
292 void fell()
293 {
294 this->getTopNode().clear();
295 m_marks.clear();
296 Super::fell();
297 }
298
300 void raze()
301 {
302 this->fell();
303 Super::raze();
304 m_marks.shrink_to_fit();
305 }
306
307 private:
309 std::deque<bool> m_marks;
310 // Note: Have to use a deque here because std::vector<bool> is special
311 // std::vector<bool> m_marks;
312 };
313 }
315}
Class for a node in the tree.
Definition: DynTree.h:48
This is the base class for all hough trees.
Definition: DynTree.h:35
void fell()
Fell the tree, meaning all child nodes are deleted from the tree. Keeps the top node.
Definition: DynTree.h:334
void raze()
Like fell, but also releases all memory the tree has acquired during long executions.
Definition: DynTree.h:349
Node & getTopNode()
Getter for the top node of the tree.
Definition: DynTree.h:237
void walk(AWalker &walker)
Forward walk to the top node.
Definition: DynTree.h:316
Dynamic tree structure with weighted items in each node which are markable throughout the tree.
void fell()
Fell the tree, meaning all child nodes are deleted from the tree. Keeps the top node.
void seed(const Ts &items)
Take the item set and insert its items into the top node of the Hough space.
Node * findHeaviestLeaf(AItemInDomainMeasure &weightItemInDomain, int maxLevel, ASkipNodePredicate &skipNode)
Go through all children until the maxLevel is reached and find the leaf with the highest weight.
std::vector< std::pair< ADomain, std::vector< T > > > findLeavesDisjoint(AItemInDomainMeasure &weightItemInDomain, int maxLevel, ASkipNodePredicate &skipNode)
Find all child nodes at the maximum level and add them to the result list. Skip nodes if skipNode returns true for them.
void raze()
Like fell, but also releases all memory the tree has acquired during long executions.
std::deque< bool > m_marks
Memory of the used marks of the items.
std::unique_ptr< std::pair< ADomain, std::vector< T > > > findHeaviestLeafSingle(AItemInDomainMeasure &weightItemInDomain, int maxLevel, ASkipNodePredicate &skipNode)
Go through all children until the maxLevel is reached and find the leaf with the highest weight.
void fillWalk(AItemInDomainMeasure &weightItemInDomain, AIsLeafPredicate &isLeaf)
Walk through the children and fill them if necessary until isLeaf returns true.
void walkHeighWeightFirst(ATreeWalker &walker)
Walk the tree investigating the heaviest children with priority.
std::vector< std::pair< ADomain, std::vector< T > > > findHeavyLeavesDisjoint(AItemInDomainMeasure &weightItemInDomain, int maxLevel, double minWeight)
Find all child nodes at the maximum level and add them to the result list. Skip nodes if their weight is below minWeight.
std::vector< std::pair< ADomain, std::vector< T > > > findHeaviestLeafRepeated(AItemInDomainMeasure &weightItemInDomain, int maxLevel, const Weight minWeight=NAN)
Go through all children until maxLevel is reached and find the heaviest leaves.
std::vector< std::pair< ADomain, std::vector< T > > > findHeaviestLeafRepeated(AItemInDomainMeasure &weightItemInDomain, int maxLevel, ASkipNodePredicate &skipNode)
Go through all children until maxLevel is reached and find the heaviest leaves.
typename Super::Node Node
Type of the node in the tree.
Type of tree for partitioning the Hough space.
WeightedParititioningDynTree(ADomain topDomain, ADomainDivsion domainDivsion)
Constructor attaching a vector of the weighted items to the topmost node domain.
Mixin class to attach a mark that is shared among many instances.
A mixin class to attach a set of weighted items to a class.
Abstract base class for different kinds of events.
STL namespace.