Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/spatial/geometry/bbox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,40 @@ struct Box {
return (min + max) / 2;
}

// Minimum squared distance from a point to the nearest edge of this box.
// Returns 0 if the point is inside the box.
VALUE_TYPE MinDistanceSquared(const V &point) const {
VALUE_TYPE dx = 0, dy = 0;
if (point.x < min.x) {
dx = min.x - point.x;
} else if (point.x > max.x) {
dx = point.x - max.x;
}
if (point.y < min.y) {
dy = min.y - point.y;
} else if (point.y > max.y) {
dy = point.y - max.y;
}
return dx * dx + dy * dy;
}

// Minimum squared distance between two boxes.
// Returns 0 if the boxes overlap.
VALUE_TYPE MinDistanceSquared(const Box &other) const {
VALUE_TYPE dx = 0, dy = 0;
if (other.max.x < min.x) {
dx = min.x - other.max.x;
} else if (other.min.x > max.x) {
dx = other.min.x - max.x;
}
if (other.max.y < min.y) {
dy = min.y - other.max.y;
} else if (other.min.y > max.y) {
dy = other.min.y - max.y;
}
return dx * dx + dy * dy;
}

bool operator==(const Box &other) const {
return min == other.min && max == other.max;
}
Expand Down
87 changes: 86 additions & 1 deletion src/spatial/modules/main/spatial_functions_scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "duckdb/planner/expression/bound_constant_expression.hpp"

#include "spatial/util/distance_extract.hpp"
#include "spatial/util/knn_extract.hpp"
#include "spatial/spatial_settings.hpp"

// Extra
Expand Down Expand Up @@ -5784,7 +5785,7 @@ struct ST_Distance_Sphere {
unique_ptr<FunctionData> Copy() const override {
auto copy = make_uniq<BindData>();
copy->always_xy = always_xy;
return copy;
return std::move(copy);
}
bool Equals(const FunctionData &other) const override {
auto &other_bind = other.Cast<BindData>();
Expand Down Expand Up @@ -9462,6 +9463,76 @@ constexpr const char *ST_X::NAME;
constexpr const char *ST_Y::NAME;
constexpr const char *ST_Z::NAME;

//======================================================================================================================
// ST_KNN
//======================================================================================================================

struct ST_KNN {

class BindData final : public FunctionData {
public:
int32_t k;
bool is_constant = false;

BindData(int32_t k_p) : k(k_p), is_constant(true) {
}

unique_ptr<FunctionData> Copy() const override {
return make_uniq<BindData>(k);
}

bool Equals(const FunctionData &other) const override {
auto &other_data = other.Cast<BindData>();
return is_constant == other_data.is_constant && k == other_data.k;
}
};

static unique_ptr<FunctionData> Bind3(ClientContext &context, ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {
// ST_KNN(geom1, geom2, k)
if (arguments[2]->IsFoldable()) {
const auto k_expr = ExpressionExecutor::EvaluateScalar(context, *arguments[2]);
const auto k_value = k_expr.GetValue<int32_t>();
if (k_value < 1) {
throw InvalidInputException("ST_KNN: k must be >= 1, got %d", k_value);
}
Function::EraseArgument(bound_function, arguments, 2);
return make_uniq<BindData>(k_value);
}
throw InvalidInputException("ST_KNN: k must be a constant expression");
}

static void Execute(DataChunk &args, ExpressionState &state, Vector &result) {
throw InvalidInputException("ST_KNN cannot be used outside of a JOIN ON clause");
}

static void Register(ExtensionLoader &loader) {
FunctionBuilder::RegisterScalar(loader, "ST_KNN", [](ScalarFunctionBuilder &func) {
// ST_KNN(geom1, geom2, k)
func.AddVariant([](ScalarFunctionVariantBuilder &variant) {
variant.AddParameter("geom1", LogicalType::GEOMETRY());
variant.AddParameter("geom2", LogicalType::GEOMETRY());
variant.AddParameter("k", LogicalType::INTEGER);
variant.SetReturnType(LogicalType::BOOLEAN);
variant.SetBind(Bind3);
variant.SetFunction(Execute);
});
func.SetDescription(R"(
K-nearest neighbor spatial join predicate.
Finds the k nearest geometries from geom2 for each geom1.
Must be used in a JOIN ON clause.
)");
func.SetExample(R"(
SELECT a.id, b.id
FROM table_a a
JOIN table_b b ON ST_KNN(a.geom, b.geom, 5);
)");
func.SetTag("ext", "spatial");
func.SetTag("category", "relation");
});
}
};

} // namespace

// Helper to access the constant distance from the bind data
Expand All @@ -9476,6 +9547,19 @@ bool ST_DWithinHelper::TryGetConstDistance(const unique_ptr<FunctionData> &bind_
return false;
}

// Helper to access the constant k from the bind data
bool ST_KNNHelper::TryGetConstK(const unique_ptr<FunctionData> &bind_data, int32_t &result) {
if (bind_data) {
const auto &data = bind_data->Cast<ST_KNN::BindData>();
if (data.is_constant) {
result = data.k;
return true;
}
}
return false;
}


//######################################################################################################################
// Register
//######################################################################################################################
Expand Down Expand Up @@ -9528,6 +9612,7 @@ void RegisterSpatialScalarFunctions(ExtensionLoader &loader) {
ST_InterpolatePoint::Register(loader);
ST_Intersects::Register(loader);
ST_Intersects_Extent::Register(loader);
ST_KNN::Register(loader);
ST_IsClosed::Register(loader);
ST_IsEmpty::Register(loader);
ST_Length::Register(loader);
Expand Down
2 changes: 2 additions & 0 deletions src/spatial/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_logical.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_physical.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spatial_join_optimizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spatial_knn_join_logical.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spatial_knn_join_physical.cpp
${CMAKE_CURRENT_SOURCE_DIR}/spatial_operator_extension.cpp
PARENT_SCOPE)
Loading