Skip to content

Commit 847e798

Browse files
committed
Eliminate sqrt from distance comparisons, use ternary for branchless codegen
1 parent 607417e commit 847e798

4 files changed

Lines changed: 46 additions & 28 deletions

File tree

include/dbscan/algo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ int DBSCAN(intT n, floatT* PF, double epsilon, intT minPts, bool* coreFlagOut, i
5757
intT count = 0;
5858
auto isCore = [&] (pointT *p) {
5959
if(count >= minPts) return true;
60-
if(p->distSqr(P[i]) <= epsSqr) {//todo sqrt opt
60+
if(p->distSqr(P[i]) <= epsSqr) {
6161
count ++;}
6262
return false;};
6363
G->nghPointMap(P[i].coordinate(), isCore);

include/dbscan/coreBccp.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,18 @@
2929
#include "pbbs/parallel.h"
3030
#include "pbbs/utils.h"
3131

32+
// r holds squared distance; using distSqr and nodeDistanceSqr avoids sqrt in hot path
3233
template<class nodeT, class objT>
3334
inline void compBcpCoreHSerial(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, objT* P) {
34-
if (n1->nodeDistance(n2) > *r) return;
35+
if (n1->nodeDistanceSqr(n2) > *r) return;
3536

3637
if (n1->isLeaf() && n2->isLeaf()) {//basecase
3738
for (intT i=0; i<n1->size(); ++i) {
3839
for (intT j=0; j<n2->size(); ++j) {
3940
auto pi = n1->getItem(i);
4041
auto pj = n2->getItem(j);
4142
if (coreFlag[pi - P] && coreFlag[pj - P]) {
42-
floatT dist = pi->dist(*pj);
43+
floatT dist = pi->distSqr(*pj);
4344
r[0] = min(r[0], dist);
4445
}
4546
}
@@ -78,30 +79,31 @@ inline void compBcpCoreHSerial(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag,
7879

7980
template<class nodeT, class objT>
8081
inline void compBcpCoreHBase(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, objT* P) {
81-
if (n1->nodeDistance(n2) > *r) return;
82+
if (n1->nodeDistanceSqr(n2) > *r) return;
8283

8384
if (n1->isLeaf() && n2->isLeaf()) {//basecase
8485
for (intT i=0; i<n1->size(); ++i) {
8586
for (intT j=0; j<n2->size(); ++j) {
8687
auto pi = n1->getItem(i);
8788
auto pj = n2->getItem(j);
8889
if (coreFlag[pi - P] && coreFlag[pj - P]) {
89-
floatT dist = pi->dist(*pj);
90+
floatT dist = pi->distSqr(*pj);
9091
utils::writeMin(r, dist);
9192
}
9293
}
9394
}
94-
} else {//recursive, todo consider call order, might help
95+
} else {//recursive
9596
if (n1->isLeaf()) {
96-
if (n1->nodeDistance(n2->L()) < n1->nodeDistance(n2->R())) {
97+
// nodeDistanceSqr avoids sqrt; monotonicity preserves ordering
98+
if (n1->nodeDistanceSqr(n2->L()) < n1->nodeDistanceSqr(n2->R())) {
9799
compBcpCoreH(n1, n2->L(), r, coreFlag, P);
98100
compBcpCoreH(n1, n2->R(), r, coreFlag, P);
99101
} else {
100102
compBcpCoreH(n1, n2->R(), r, coreFlag, P);
101103
compBcpCoreH(n1, n2->L(), r, coreFlag, P);
102104
}
103105
} else if (n2->isLeaf()) {
104-
if (n2->nodeDistance(n1->L()) < n2->nodeDistance(n1->R())) {
106+
if (n2->nodeDistanceSqr(n1->L()) < n2->nodeDistanceSqr(n1->R())) {
105107
compBcpCoreH(n2, n1->L(), r, coreFlag, P);
106108
compBcpCoreH(n2, n1->R(), r, coreFlag, P);
107109
} else {
@@ -115,7 +117,7 @@ inline void compBcpCoreHBase(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, ob
115117
ordering[2] = make_pair(n2->L(), n1->R());
116118
ordering[3] = make_pair(n2->R(), n1->R());
117119
auto bbd = [&](pair<nodeT*,nodeT*> p1, pair<nodeT*,nodeT*> p2) {
118-
return p1.first->nodeDistance(p1.second) < p2.first->nodeDistance(p2.second);};
120+
return p1.first->nodeDistanceSqr(p1.second) < p2.first->nodeDistanceSqr(p2.second);};
119121
quickSortSerial(ordering, 4, bbd);
120122
for (intT o=0; o<4; ++o) {
121123
compBcpCoreH(ordering[o].first, ordering[o].second, r, coreFlag, P);}
@@ -125,21 +127,21 @@ inline void compBcpCoreHBase(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, ob
125127

126128
template<class nodeT, class objT>
127129
inline void compBcpCoreH(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, objT* P) {
128-
if (n1->nodeDistance(n2) > *r) return;
130+
if (n1->nodeDistanceSqr(n2) > *r) return;
129131

130132
if ((n1->isLeaf() && n2->isLeaf()) || (n1->size()+n2->size() < 2000)) {
131133
return compBcpCoreHBase(n1, n2, r, coreFlag, P);
132-
} else {//recursive, todo consider call order, might help
134+
} else {//recursive
133135
if (n1->isLeaf()) {
134-
if (n1->nodeDistance(n2->L()) < n1->nodeDistance(n2->R())) {
136+
if (n1->nodeDistanceSqr(n2->L()) < n1->nodeDistanceSqr(n2->R())) {
135137
par_do([&](){compBcpCoreH(n1, n2->L(), r, coreFlag, P);},
136138
[&](){compBcpCoreH(n1, n2->R(), r, coreFlag, P);});
137139
} else {
138140
par_do([&](){compBcpCoreH(n1, n2->R(), r, coreFlag, P);},
139141
[&](){compBcpCoreH(n1, n2->L(), r, coreFlag, P);});
140142
}
141143
} else if (n2->isLeaf()) {
142-
if (n2->nodeDistance(n1->L()) < n2->nodeDistance(n1->R())) {
144+
if (n2->nodeDistanceSqr(n1->L()) < n2->nodeDistanceSqr(n1->R())) {
143145
par_do([&](){compBcpCoreH(n2, n1->L(), r, coreFlag, P);},
144146
[&](){compBcpCoreH(n2, n1->R(), r, coreFlag, P);});
145147
} else {
@@ -153,7 +155,7 @@ inline void compBcpCoreH(nodeT* n1, nodeT* n2, floatT* r, intT* coreFlag, objT*
153155
ordering[2] = make_pair(n2->L(), n1->R());
154156
ordering[3] = make_pair(n2->R(), n1->R());
155157
auto bbd = [&](pair<nodeT*,nodeT*> p1, pair<nodeT*,nodeT*> p2) {
156-
return p1.first->nodeDistance(p1.second) < p2.first->nodeDistance(p2.second);};
158+
return p1.first->nodeDistanceSqr(p1.second) < p2.first->nodeDistanceSqr(p2.second);};
157159
quickSortSerial(ordering, 4, bbd);
158160
parallel_for (0, 4, [&](intT o) {
159161
compBcpCoreH(ordering[o].first, ordering[o].second, r, coreFlag, P);}, 1);
@@ -179,11 +181,11 @@ inline bool hasEdge(intT n1, intT n2, intT* coreFlag, objT* P, floatT epsilon, c
179181

180182
if (!trees[n1])
181183
trees[n1] = new treeT(cells[n1].getItem(), cells[n1].size(), false);//todo allocation, parallel
182-
if (!trees[n2])
184+
if (!trees[n2])
183185
trees[n2] = new treeT(cells[n2].getItem(), cells[n2].size(), false);//todo allocation, parallel
184186
floatT r = floatMax();
185187
compBcpCoreH(trees[n1]->rootNode(), trees[n2]->rootNode(), &r, coreFlag, P);
186-
return r <= epsilon;
188+
return r <= epsilon * epsilon; // r holds squared distance now
187189
}
188190

189191
#endif

include/dbscan/kdNode.h

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,20 @@ class kdNode {
264264
return 0; // intersect
265265
}
266266

267+
// squared bb distance — avoids sqrt, safe to compare against rSqr thresholds
268+
// ternary instead of std::max lets compiler emit maxss/fmax without NaN concerns
269+
inline floatT nodeDistanceSqr(nodeT* n2) {
270+
floatT rsqr = 0;
271+
for (int d = 0; d < dim; ++d) {
272+
floatT gapA = pMin[d] - n2->pMax[d];
273+
floatT gapB = n2->pMin[d] - pMax[d];
274+
floatT gap = gapA > gapB ? gapA : gapB;
275+
gap = gap > 0 ? gap : 0;
276+
rsqr += gap * gap;
277+
}
278+
return rsqr;
279+
}
280+
267281
//return the far bb distance between n1 and n2
268282
inline floatT nodeFarDistance(nodeT* n2) {
269283
floatT result = 0;
@@ -296,44 +310,45 @@ class kdNode {
296310
}
297311

298312
//vecT need to be vector<objT*>
313+
// rSqr is the squared query radius; use distSqr to avoid sqrt per point
299314
template<class vecT>
300-
void rangeNeighbor(pointT queryPt, floatT r, pointT pMin1, pointT pMax1, vecT* accum) {
315+
void rangeNeighbor(pointT queryPt, floatT rSqr, pointT pMin1, pointT pMax1, vecT* accum) {
301316
int relation = boxCompare(pMin1, pMax1, pMin, pMax);
302317
if (relation == boxInclude) {
303318
for(intT i=0; i<n; ++i) {
304-
if (items[i]->getCoordObj()->dist(queryPt) <= r)
319+
if (items[i]->getCoordObj()->distSqr(queryPt) <= rSqr)
305320
accum->push_back(items[i]);
306321
}
307322
} else if (relation == boxOverlap) {
308323
if (isLeaf()) {
309324
for(intT i=0; i<n; ++i) {
310-
if (items[i]->getCoordObj()->dist(queryPt) <= r &&
325+
if (items[i]->getCoordObj()->distSqr(queryPt) <= rSqr &&
311326
itemInBox(pMin1, pMax1, items[i])) accum->push_back(items[i]);
312327
}
313328
} else {
314-
left->rangeNeighbor(queryPt, r, pMin1, pMax1, accum);
315-
right->rangeNeighbor(queryPt, r, pMin1, pMax1, accum);}
329+
left->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, accum);
330+
right->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, accum);}
316331
}
317332
}
318333

319334
template<class func, class func2>
320-
void rangeNeighbor(pointT queryPt, floatT r, pointT pMin1, pointT pMax1, func term, func2 doTerm) {
335+
void rangeNeighbor(pointT queryPt, floatT rSqr, pointT pMin1, pointT pMax1, func term, func2 doTerm) {
321336
if (term()) return;
322337
int relation = boxCompare(pMin1, pMax1, pMin, pMax);
323338
if (relation == boxInclude) {
324339
for(intT i=0; i<n; ++i) {
325-
if (items[i]->getCoordObj()->dist(queryPt) <= r &&
340+
if (items[i]->getCoordObj()->distSqr(queryPt) <= rSqr &&
326341
doTerm(items[i])) break;
327342
}
328343
} else if (relation == boxOverlap) {
329344
if (isLeaf()) {
330345
for(intT i=0; i<n; ++i) {
331-
if (items[i]->getCoordObj()->dist(queryPt) <= r &&
346+
if (items[i]->getCoordObj()->distSqr(queryPt) <= rSqr &&
332347
doTerm(items[i])) break;
333348
}
334349
} else {
335-
left->rangeNeighbor(queryPt, r, pMin1, pMax1, term, doTerm);
336-
right->rangeNeighbor(queryPt, r, pMin1, pMax1, term, doTerm);}
350+
left->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, term, doTerm);
351+
right->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, term, doTerm);}
337352
}
338353
}
339354

include/dbscan/kdTree.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,16 @@ class kdTree {
8484
queryPt.updateX(i, center[i]);
8585
pMin1.updateX(i, center[i]-r);
8686
pMax1.updateX(i, center[i]+r);}
87+
floatT rSqr = r * r;
8788
if(cache) {
8889
if(!accum) accum = new vecT();
89-
root->rangeNeighbor(queryPt, r, pMin1, pMax1, accum);
90+
root->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, accum);
9091
for (auto accum_i : *accum) {
9192
if(doTerm(accum_i)) break;
9293
}
9394
return accum;
9495
} else {
95-
root->rangeNeighbor(queryPt, r, pMin1, pMax1, term, doTerm);
96+
root->rangeNeighbor(queryPt, rSqr, pMin1, pMax1, term, doTerm);
9697
return NULL;
9798
}
9899
}

0 commit comments

Comments
 (0)