Skip to content

Commit b31f019

Browse files
feat: add ST_KNN spatial join operator (in-memory, planar)
Adds a K-nearest neighbor spatial join to duckdb-spatial: - ST_KNN(geom1, geom2, k) — scalar stub recognized by the optimizer - Optimizer rewrites JOIN ON ST_KNN(...) into SPATIAL_KNN_JOIN - Hjaltason-Samet priority queue KNN on FlatRTree - Exact distance refinement with adaptive overfetch (2x, retry 8x) - INNER and LEFT JOIN support - JOO child swap handling via build_child_idx Stripped for review (follow-up PRs): - No out-of-core / spill-to-disk (tree always in memory) - No spheroid support (planar distance only) - No partitioning (single R-tree) The FlatRTree is compact (~24 bytes/entry), so in-memory handles up to ~100M entries (2.4GB tree) on commodity hardware.
1 parent 324000b commit b31f019

9 files changed

Lines changed: 1185 additions & 0 deletions

src/spatial/modules/main/spatial_functions_scalar.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "duckdb/planner/expression/bound_constant_expression.hpp"
1919

2020
#include "spatial/util/distance_extract.hpp"
21+
#include "spatial/util/knn_extract.hpp"
2122
#include "spatial/spatial_settings.hpp"
2223

2324
// Extra
@@ -9462,6 +9463,76 @@ constexpr const char *ST_X::NAME;
94629463
constexpr const char *ST_Y::NAME;
94639464
constexpr const char *ST_Z::NAME;
94649465

9466+
//======================================================================================================================
9467+
// ST_KNN
9468+
//======================================================================================================================
9469+
9470+
struct ST_KNN {
9471+
9472+
class BindData final : public FunctionData {
9473+
public:
9474+
int32_t k;
9475+
bool is_constant = false;
9476+
9477+
BindData(int32_t k_p) : k(k_p), is_constant(true) {
9478+
}
9479+
9480+
unique_ptr<FunctionData> Copy() const override {
9481+
return make_uniq<BindData>(k);
9482+
}
9483+
9484+
bool Equals(const FunctionData &other) const override {
9485+
auto &other_data = other.Cast<BindData>();
9486+
return is_constant == other_data.is_constant && k == other_data.k;
9487+
}
9488+
};
9489+
9490+
static unique_ptr<FunctionData> Bind3(ClientContext &context, ScalarFunction &bound_function,
9491+
vector<unique_ptr<Expression>> &arguments) {
9492+
// ST_KNN(geom1, geom2, k)
9493+
if (arguments[2]->IsFoldable()) {
9494+
const auto k_expr = ExpressionExecutor::EvaluateScalar(context, *arguments[2]);
9495+
const auto k_value = k_expr.GetValue<int32_t>();
9496+
if (k_value < 1) {
9497+
throw InvalidInputException("ST_KNN: k must be >= 1, got %d", k_value);
9498+
}
9499+
Function::EraseArgument(bound_function, arguments, 2);
9500+
return make_uniq<BindData>(k_value);
9501+
}
9502+
throw InvalidInputException("ST_KNN: k must be a constant expression");
9503+
}
9504+
9505+
static void Execute(DataChunk &args, ExpressionState &state, Vector &result) {
9506+
throw InvalidInputException("ST_KNN cannot be used outside of a JOIN ON clause");
9507+
}
9508+
9509+
static void Register(ExtensionLoader &loader) {
9510+
FunctionBuilder::RegisterScalar(loader, "ST_KNN", [](ScalarFunctionBuilder &func) {
9511+
// ST_KNN(geom1, geom2, k)
9512+
func.AddVariant([](ScalarFunctionVariantBuilder &variant) {
9513+
variant.AddParameter("geom1", LogicalType::GEOMETRY());
9514+
variant.AddParameter("geom2", LogicalType::GEOMETRY());
9515+
variant.AddParameter("k", LogicalType::INTEGER);
9516+
variant.SetReturnType(LogicalType::BOOLEAN);
9517+
variant.SetBind(Bind3);
9518+
variant.SetFunction(Execute);
9519+
});
9520+
func.SetDescription(R"(
9521+
K-nearest neighbor spatial join predicate.
9522+
Finds the k nearest geometries from geom2 for each geom1.
9523+
Must be used in a JOIN ON clause.
9524+
)");
9525+
func.SetExample(R"(
9526+
SELECT a.id, b.id
9527+
FROM table_a a
9528+
JOIN table_b b ON ST_KNN(a.geom, b.geom, 5);
9529+
)");
9530+
func.SetTag("ext", "spatial");
9531+
func.SetTag("category", "relation");
9532+
});
9533+
}
9534+
};
9535+
94659536
} // namespace
94669537

94679538
// Helper to access the constant distance from the bind data
@@ -9476,6 +9547,19 @@ bool ST_DWithinHelper::TryGetConstDistance(const unique_ptr<FunctionData> &bind_
94769547
return false;
94779548
}
94789549

9550+
// Helper to access the constant k from the bind data
9551+
bool ST_KNNHelper::TryGetConstK(const unique_ptr<FunctionData> &bind_data, int32_t &result) {
9552+
if (bind_data) {
9553+
const auto &data = bind_data->Cast<ST_KNN::BindData>();
9554+
if (data.is_constant) {
9555+
result = data.k;
9556+
return true;
9557+
}
9558+
}
9559+
return false;
9560+
}
9561+
9562+
94799563
//######################################################################################################################
94809564
// Register
94819565
//######################################################################################################################
@@ -9528,6 +9612,7 @@ void RegisterSpatialScalarFunctions(ExtensionLoader &loader) {
95289612
ST_InterpolatePoint::Register(loader);
95299613
ST_Intersects::Register(loader);
95309614
ST_Intersects_Extent::Register(loader);
9615+
ST_KNN::Register(loader);
95319616
ST_IsClosed::Register(loader);
95329617
ST_IsEmpty::Register(loader);
95339618
ST_Length::Register(loader);

src/spatial/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ set(EXTENSION_SOURCES
33
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_logical.cpp
44
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_physical.cpp
55
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_optimizer.cpp
6+
${CMAKE_CURRENT_SOURCE_DIR}/spatial_knn_join_logical.cpp
7+
${CMAKE_CURRENT_SOURCE_DIR}/spatial_knn_join_physical.cpp
68
${CMAKE_CURRENT_SOURCE_DIR}/spatial_operator_extension.cpp
79
PARENT_SCOPE)

src/spatial/operators/spatial_join_optimizer.cpp

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "spatial_join_optimizer.hpp"
22
#include "spatial_join_logical.hpp"
3+
#include "spatial/operators/spatial_knn_join_logical.hpp"
34
#include "spatial/util/distance_extract.hpp"
5+
#include "spatial/util/knn_extract.hpp"
46
#include "spatial/spatial_types.hpp"
57

68
#include "duckdb/main/database.hpp"
@@ -319,7 +321,168 @@ static void TrySwapAnyJoin(OptimizerExtensionInput &input, unique_ptr<LogicalOpe
319321
plan = std::move(spatial_join);
320322
}
321323

324+
//======================================================================================================================
325+
// KNN Join Detection
326+
//======================================================================================================================
327+
328+
static bool IsKNNJoinPredicate(const unique_ptr<Expression> &expr, const unordered_set<idx_t> &left_bindings,
329+
const unordered_set<idx_t> &right_bindings, bool &needs_flipping, int32_t &k_value) {
330+
331+
const auto total_side = JoinSide::GetJoinSide(*expr, left_bindings, right_bindings);
332+
if (total_side != JoinSide::BOTH) {
333+
return false;
334+
}
335+
336+
if (expr->type != ExpressionType::BOUND_FUNCTION) {
337+
return false;
338+
}
339+
340+
auto &func = expr->Cast<BoundFunctionExpression>();
341+
342+
if (!StringUtil::CIEquals(func.function.name, "ST_KNN")) {
343+
return false;
344+
}
345+
346+
// After bind, ST_KNN has 2 args (k was folded into bind_data)
347+
if (func.children.size() != 2) {
348+
return false;
349+
}
350+
351+
if (func.return_type != LogicalType::BOOLEAN) {
352+
return false;
353+
}
354+
355+
const auto left_side = JoinSide::GetJoinSide(*func.children[0], left_bindings, right_bindings);
356+
const auto right_side = JoinSide::GetJoinSide(*func.children[1], left_bindings, right_bindings);
357+
358+
if (left_side == JoinSide::BOTH || right_side == JoinSide::BOTH) {
359+
return false;
360+
}
361+
362+
needs_flipping = (left_side == JoinSide::RIGHT);
363+
364+
if (!ST_KNNHelper::TryGetConstK(func.bind_info, k_value)) {
365+
return false;
366+
}
367+
368+
return true;
369+
}
370+
371+
//! Tries to rewrite a LOGICAL_ANY_JOIN whose condition contains an ST_KNN predicate into a
//! LogicalSpatialKNNJoin (optionally wrapped in a LogicalFilter for extra AND-ed predicates).
//!
//! Returns true and replaces `plan` when the rewrite was applied. Returns false and leaves `plan`
//! untouched otherwise: every rejection happens before anything is moved out of the original join.
//!
//! The resulting operator always has the probe input (first ST_KNN argument) as children[0] and the
//! build input (second ST_KNN argument) as children[1]. Only INNER joins and joins that preserve the
//! probe side (LEFT after any child swap) are supported.
static bool TrySwapKNNAnyJoin(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
	auto &op = *plan;

	if (op.type != LogicalOperatorType::LOGICAL_ANY_JOIN) {
		return false;
	}

	auto &any_join = op.Cast<LogicalAnyJoin>();

	// KNN supports INNER and LEFT joins.
	// RIGHT may appear when JOO swaps a LEFT join — we convert it back to LEFT after child swap.
	// A RIGHT join that would preserve the build side is not supported (would need build-side outer emit).
	if (any_join.join_type != JoinType::INNER && any_join.join_type != JoinType::LEFT &&
	    any_join.join_type != JoinType::RIGHT) {
		return false;
	}

	auto &left_child = any_join.children[0];
	auto &right_child = any_join.children[1];
	unordered_set<idx_t> left_bindings;
	unordered_set<idx_t> right_bindings;
	LogicalJoin::GetTableReferences(*left_child, left_bindings);
	LogicalJoin::GetTableReferences(*right_child, right_bindings);

	// Split a copy of the join condition by AND; the original condition stays intact in case we bail out.
	vector<unique_ptr<Expression>> expressions;
	expressions.push_back(any_join.condition->Copy());
	LogicalFilter::SplitPredicates(expressions);

	unique_ptr<Expression> knn_pred_expr = nullptr;
	vector<unique_ptr<Expression>> extra_predicates;
	int32_t k_value = 1;

	// The first ST_KNN predicate becomes the join predicate; everything else is an extra predicate.
	for (auto &expr : expressions) {
		bool unused_flip = false;
		int32_t k_tmp = 0;
		if (!knn_pred_expr && IsKNNJoinPredicate(expr, left_bindings, right_bindings, unused_flip, k_tmp)) {
			knn_pred_expr = std::move(expr);
			k_value = k_tmp;
		} else if (expr) {
			extra_predicates.push_back(std::move(expr));
		}
	}

	if (!knn_pred_expr) {
		return false;
	}

	// For non-INNER joins, extra predicates can't be pushed as filters safely
	if (!extra_predicates.empty() && any_join.join_type != JoinType::INNER) {
		return false;
	}

	// ST_KNN(probe_geom, build_geom, k): arg0 = probe side, arg1 = build side.
	// DuckDB sinks children[1], so we need the build side at children[1].
	// Determine which child the build key (second arg) currently references.
	auto &knn_func = knn_pred_expr->Cast<BoundFunctionExpression>();
	const auto build_key_side = JoinSide::GetJoinSide(*knn_func.children[1], left_bindings, right_bindings);
	const bool needs_child_swap = (build_key_side == JoinSide::LEFT);

	// Swapping the children mirrors an outer join: LEFT <-> RIGHT. Only the probe-preserving (LEFT)
	// result is supported, so:
	//   RIGHT + swap    -> LEFT  (e.g. a LEFT join whose children were swapped by JOO)
	//   RIGHT + no swap -> stays RIGHT (build-side outer): rejected
	//   LEFT  + swap    -> would become RIGHT (build-side outer): rejected, otherwise the wrong
	//                      input's unmatched rows would be preserved
	//   LEFT  + no swap -> LEFT
	auto effective_join_type = any_join.join_type;
	if (needs_child_swap) {
		if (effective_join_type == JoinType::RIGHT) {
			effective_join_type = JoinType::LEFT;
		} else if (effective_join_type == JoinType::LEFT) {
			return false;
		}
	} else if (effective_join_type == JoinType::RIGHT) {
		return false;
	}

	// From here on the rewrite is committed and state is moved out of the original join.
	auto knn_join = make_uniq<LogicalSpatialKNNJoin>(effective_join_type);
	knn_join->spatial_predicate = std::move(knn_pred_expr);
	knn_join->k = k_value;
	knn_join->children = std::move(any_join.children);
	knn_join->expressions = std::move(any_join.expressions);

	if (needs_child_swap) {
		// The projection maps follow their children.
		std::swap(knn_join->children[0], knn_join->children[1]);
		knn_join->left_projection_map = std::move(any_join.right_projection_map);
		knn_join->right_projection_map = std::move(any_join.left_projection_map);
	} else {
		knn_join->left_projection_map = std::move(any_join.left_projection_map);
		knn_join->right_projection_map = std::move(any_join.right_projection_map);
	}

	// build_child_idx is always 1 after potential swap
	knn_join->build_child_idx = 1;

	knn_join->join_stats = std::move(any_join.join_stats);
	knn_join->has_estimated_cardinality = any_join.has_estimated_cardinality;
	knn_join->estimated_cardinality = any_join.estimated_cardinality;

	if (!extra_predicates.empty()) {
		// Wrap KNN join in a filter for extra AND predicates (INNER only).
		// The filter runs on the KNN join's output, i.e. after the k neighbors were selected.
		auto filter = make_uniq<LogicalFilter>();
		filter->expressions = std::move(extra_predicates);
		filter->children.push_back(std::move(knn_join));
		plan = std::move(filter);
	} else {
		plan = std::move(knn_join);
	}
	return true;
}
477+
478+
//======================================================================================================================
479+
322480
static void InsertSpatialJoin(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
481+
// Try KNN first (more specific)
482+
if (TrySwapKNNAnyJoin(input, plan)) {
483+
return;
484+
}
485+
323486
if (TrySwapComparisonJoin(input, plan)) {
324487
return;
325488
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "spatial/operators/spatial_knn_join_logical.hpp"
2+
#include "spatial/operators/spatial_knn_join_physical.hpp"
3+
4+
#include "duckdb/planner/expression/bound_function_expression.hpp"
5+
#include "duckdb/execution/column_binding_resolver.hpp"
6+
#include "duckdb/common/serializer/serializer.hpp"
7+
#include "duckdb/common/serializer/deserializer.hpp"
8+
9+
namespace duckdb {
10+
11+
//! Creates a KNN join operator of the given join type; all other members are filled in by the caller.
LogicalSpatialKNNJoin::LogicalSpatialKNNJoin(JoinType join_type_p) : join_type(join_type_p) {
}
13+
14+
//! Output bindings: the (projected) bindings of children[0] followed by those of children[1].
vector<ColumnBinding> LogicalSpatialKNNJoin::GetColumnBindings() {
	auto combined = MapBindings(children[0]->GetColumnBindings(), left_projection_map);
	const auto second_child_bindings = MapBindings(children[1]->GetColumnBindings(), right_projection_map);
	for (const auto &binding : second_child_bindings) {
		combined.push_back(binding);
	}
	return combined;
}
20+
21+
//! Resolves the two arguments of the spatial predicate, each right after its own child operator
//! has been visited, then publishes this operator's output bindings.
void LogicalSpatialKNNJoin::ResolveColumnBindings(ColumnBindingResolver &res, vector<ColumnBinding> &bindings) {
	auto &knn_call = spatial_predicate->Cast<BoundFunctionExpression>();
	auto &probe_key = knn_call.children[0];
	auto &build_key = knn_call.children[1];

	// The optimizer arranges the children so that children[0] is the probe input (first argument)
	// and children[1] is the build input (second argument).
	res.VisitOperator(*children[0]);
	res.VisitExpression(&probe_key);

	res.VisitOperator(*children[1]);
	res.VisitExpression(&build_key);

	bindings = GetColumnBindings();
}
33+
34+
void LogicalSpatialKNNJoin::ResolveTypes() {
35+
types = MapTypes(children[0]->types, left_projection_map);
36+
auto right_types = MapTypes(children[1]->types, right_projection_map);
37+
types.insert(types.end(), right_types.begin(), right_types.end());
38+
}
39+
40+
//! Plans both children and creates the physical KNN join on top of them.
//! Note: the spatial predicate is moved out of this logical operator in the process.
PhysicalOperator &LogicalSpatialKNNJoin::CreatePlan(ClientContext &context, PhysicalPlanGenerator &generator) {
	// children[0] is planned before children[1].
	auto &first_child_plan = generator.CreatePlan(*children[0]);
	auto &second_child_plan = generator.CreatePlan(*children[1]);

	auto &knn_join_plan =
	    generator.Make<PhysicalSpatialKNNJoin>(*this, first_child_plan, second_child_plan,
	                                           std::move(spatial_predicate), join_type, estimated_cardinality, k);
	return knn_join_plan;
}
47+
48+
//! Writes the base extension-operator state, the operator type name and all KNN join properties.
//! The field ids and names must stay in sync with Deserialize.
void LogicalSpatialKNNJoin::Serialize(Serializer &writer) const {
	LogicalExtensionOperator::Serialize(writer);

	// Operator type name (field 300).
	const string operator_name(OPERATOR_TYPE_NAME);
	writer.WritePropertyWithDefault(300, "operator_type", operator_name);

	// Join configuration.
	writer.WritePropertyWithDefault<JoinType>(400, "join_type", join_type, JoinType::INNER);
	writer.WritePropertyWithDefault<vector<idx_t>>(402, "left_projection_map", left_projection_map);
	writer.WritePropertyWithDefault<vector<idx_t>>(403, "right_projection_map", right_projection_map);

	// KNN-specific state.
	writer.WritePropertyWithDefault<unique_ptr<Expression>>(404, "spatial_predicate", spatial_predicate);
	writer.WritePropertyWithDefault<int32_t>(405, "k", k, 1);
	writer.WritePropertyWithDefault<idx_t>(406, "build_child_idx", build_child_idx, static_cast<idx_t>(1));
}
58+
59+
//! Rebuilds a LogicalSpatialKNNJoin from the properties written by Serialize (fields 400 and 402-406).
//! Properties are read in ascending field-id order; absent ones fall back to their defaults.
unique_ptr<LogicalExtensionOperator> LogicalSpatialKNNJoin::Deserialize(Deserializer &reader) {
	const auto stored_join_type = reader.ReadPropertyWithExplicitDefault<JoinType>(400, "join_type", JoinType::INNER);
	auto knn_join = make_uniq<LogicalSpatialKNNJoin>(stored_join_type);

	knn_join->left_projection_map = reader.ReadPropertyWithDefault<vector<idx_t>>(402, "left_projection_map");
	knn_join->right_projection_map = reader.ReadPropertyWithDefault<vector<idx_t>>(403, "right_projection_map");
	knn_join->spatial_predicate = reader.ReadPropertyWithDefault<unique_ptr<Expression>>(404, "spatial_predicate");
	knn_join->k = reader.ReadPropertyWithExplicitDefault<int32_t>(405, "k", 1);
	knn_join->build_child_idx =
	    reader.ReadPropertyWithExplicitDefault<idx_t>(406, "build_child_idx", static_cast<idx_t>(1));

	// Explicit move: converts unique_ptr<LogicalSpatialKNNJoin> into the base-class pointer return type.
	return std::move(knn_join);
}
76+
77+
} // namespace duckdb

0 commit comments

Comments
 (0)