Skip to content

Commit 08da6a7

Browse files
authored
[Enhancement] Support some spatial functions (#48695)
Support for ST_Intersects, ST_Disjoint, ST_Touches sql functions.
1 parent 2b90e70 commit 08da6a7

File tree

13 files changed

+2533
-25
lines changed

13 files changed

+2533
-25
lines changed

be/src/geo/geo_types.cpp

Lines changed: 600 additions & 0 deletions
Large diffs are not rendered by default.

be/src/geo/geo_types.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ class GeoShape {
6060
virtual std::string as_wkt() const = 0;
6161

6262
virtual bool contains(const GeoShape* rhs) const { return false; }
63+
64+
virtual bool disjoint(const GeoShape* rhs) const { return false; }
65+
66+
virtual bool intersects(const GeoShape* rhs) const { return false; }
67+
68+
virtual bool touches(const GeoShape* rhs) const { return false; }
69+
6370
virtual std::string to_string() const { return ""; }
6471
static std::string as_binary(GeoShape* rhs);
6572

@@ -82,6 +89,10 @@ class GeoPoint : public GeoShape {
8289

8390
GeoCoordinateList to_coords() const;
8491

92+
bool intersects(const GeoShape* rhs) const override;
93+
bool disjoint(const GeoShape* rhs) const override;
94+
bool touches(const GeoShape* rhs) const override;
95+
8596
GeoShapeType type() const override { return GEO_SHAPE_POINT; }
8697

8798
const S2Point* point() const { return _point.get(); }
@@ -119,6 +130,10 @@ class GeoLine : public GeoShape {
119130

120131
GeoCoordinateList to_coords() const;
121132

133+
bool intersects(const GeoShape* rhs) const override;
134+
bool disjoint(const GeoShape* rhs) const override;
135+
bool touches(const GeoShape* rhs) const override;
136+
122137
GeoShapeType type() const override { return GEO_SHAPE_LINE_STRING; }
123138
const S2Polyline* polyline() const { return _polyline.get(); }
124139

@@ -148,7 +163,14 @@ class GeoPolygon : public GeoShape {
148163
GeoShapeType type() const override { return GEO_SHAPE_POLYGON; }
149164
const S2Polygon* polygon() const { return _polygon.get(); }
150165

166+
bool intersects(const GeoShape* rhs) const override;
167+
bool disjoint(const GeoShape* rhs) const override;
168+
bool touches(const GeoShape* rhs) const override;
151169
bool contains(const GeoShape* rhs) const override;
170+
171+
bool polygon_touch_point(const S2Polygon* polygon, const S2Point* point) const;
172+
bool polygon_touch_polygon(const S2Polygon* polygon1, const S2Polygon* polygon2) const;
173+
152174
std::string as_wkt() const override;
153175

154176
int numLoops() const;
@@ -174,6 +196,11 @@ class GeoCircle : public GeoShape {
174196

175197
GeoShapeType type() const override { return GEO_SHAPE_CIRCLE; }
176198

199+
const S2Cap* circle() const { return _cap.get(); }
200+
201+
bool intersects(const GeoShape* rhs) const override;
202+
bool disjoint(const GeoShape* rhs) const override;
203+
bool touches(const GeoShape* rhs) const override;
177204
bool contains(const GeoShape* rhs) const override;
178205
std::string as_wkt() const override;
179206

be/src/vec/functions/functions_geo.cpp

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -626,11 +626,13 @@ struct StCircle {
626626
}
627627
};
628628

629-
struct StContains {
629+
template <typename Func>
630+
struct StRelationFunction {
630631
static constexpr auto NEED_CONTEXT = true;
631-
static constexpr auto NAME = "st_contains";
632+
static constexpr auto NAME = Func::NAME;
632633
static const size_t NUM_ARGS = 2;
633634
using Type = DataTypeUInt8;
635+
634636
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
635637
size_t result) {
636638
DCHECK_EQ(arguments.size(), 2);
@@ -642,8 +644,7 @@ struct StContains {
642644

643645
const auto size = std::max(left_column->size(), right_column->size());
644646

645-
auto res = ColumnUInt8::create();
646-
res->reserve(size);
647+
auto res = ColumnUInt8::create(size, 0);
647648
auto null_map = ColumnUInt8::create(size, 0);
648649
auto& null_map_data = null_map->get_data();
649650

@@ -660,55 +661,50 @@ struct StContains {
660661
}
661662

662663
static void loop_do(StringRef& lhs_value, StringRef& rhs_value,
663-
std::vector<std::shared_ptr<GeoShape>>& shapes, int& i,
664+
std::vector<std::unique_ptr<GeoShape>>& shapes,
664665
ColumnUInt8::MutablePtr& res, NullMap& null_map, int row) {
665666
StringRef* strs[2] = {&lhs_value, &rhs_value};
666-
for (i = 0; i < 2; ++i) {
667-
shapes[i] =
668-
std::shared_ptr<GeoShape>(GeoShape::from_encoded(strs[i]->data, strs[i]->size));
669-
if (shapes[i] == nullptr) {
667+
for (int i = 0; i < 2; ++i) {
668+
std::unique_ptr<GeoShape> shape(GeoShape::from_encoded(strs[i]->data, strs[i]->size));
669+
shapes[i] = std::move(shape);
670+
if (!shapes[i]) {
670671
null_map[row] = 1;
671-
res->insert_default();
672672
break;
673673
}
674674
}
675-
676-
if (i == 2) {
677-
auto contains_value = shapes[0]->contains(shapes[1].get());
678-
res->insert_data(const_cast<const char*>((char*)&contains_value), 0);
675+
if (shapes[0] && shapes[1]) {
676+
auto relation_value = Func::evaluate(shapes[0].get(), shapes[1].get());
677+
res->get_data()[row] = relation_value;
679678
}
680679
}
681680

682681
static void const_vector(const ColumnPtr& left_column, const ColumnPtr& right_column,
683682
ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
684-
int i;
685683
auto lhs_value = left_column->get_data_at(0);
686-
std::vector<std::shared_ptr<GeoShape>> shapes = {nullptr, nullptr};
684+
std::vector<std::unique_ptr<GeoShape>> shapes(2);
687685
for (int row = 0; row < size; ++row) {
688686
auto rhs_value = right_column->get_data_at(row);
689-
loop_do(lhs_value, rhs_value, shapes, i, res, null_map, row);
687+
loop_do(lhs_value, rhs_value, shapes, res, null_map, row);
690688
}
691689
}
692690

693691
static void vector_const(const ColumnPtr& left_column, const ColumnPtr& right_column,
694692
ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
695-
int i;
696693
auto rhs_value = right_column->get_data_at(0);
697-
std::vector<std::shared_ptr<GeoShape>> shapes = {nullptr, nullptr};
694+
std::vector<std::unique_ptr<GeoShape>> shapes(2);
698695
for (int row = 0; row < size; ++row) {
699696
auto lhs_value = left_column->get_data_at(row);
700-
loop_do(lhs_value, rhs_value, shapes, i, res, null_map, row);
697+
loop_do(lhs_value, rhs_value, shapes, res, null_map, row);
701698
}
702699
}
703700

704701
static void vector_vector(const ColumnPtr& left_column, const ColumnPtr& right_column,
705702
ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
706-
int i;
707-
std::vector<std::shared_ptr<GeoShape>> shapes = {nullptr, nullptr};
703+
std::vector<std::unique_ptr<GeoShape>> shapes(2);
708704
for (int row = 0; row < size; ++row) {
709705
auto lhs_value = left_column->get_data_at(row);
710706
auto rhs_value = right_column->get_data_at(row);
711-
loop_do(lhs_value, rhs_value, shapes, i, res, null_map, row);
707+
loop_do(lhs_value, rhs_value, shapes, res, null_map, row);
712708
}
713709
}
714710

@@ -719,7 +715,27 @@ struct StContains {
719715
static Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
720716
return Status::OK();
721717
}
722-
}; // namespace doris::vectorized
718+
};
719+
720+
struct StContainsFunc {
721+
static constexpr auto NAME = "st_contains";
722+
static bool evaluate(GeoShape* shape1, GeoShape* shape2) { return shape1->contains(shape2); }
723+
};
724+
725+
struct StIntersectsFunc {
726+
static constexpr auto NAME = "st_intersects";
727+
static bool evaluate(GeoShape* shape1, GeoShape* shape2) { return shape1->intersects(shape2); }
728+
};
729+
730+
struct StDisjointFunc {
731+
static constexpr auto NAME = "st_disjoint";
732+
static bool evaluate(GeoShape* shape1, GeoShape* shape2) { return shape1->disjoint(shape2); }
733+
};
734+
735+
struct StTouchesFunc {
736+
static constexpr auto NAME = "st_touches";
737+
static bool evaluate(GeoShape* shape1, GeoShape* shape2) { return shape1->touches(shape2); }
738+
};
723739

724740
struct StGeometryFromText {
725741
static constexpr auto NAME = "st_geometryfromtext";
@@ -914,7 +930,10 @@ void register_function_geo(SimpleFunctionFactory& factory) {
914930
factory.register_function<GeoFunction<StAngleSphere>>();
915931
factory.register_function<GeoFunction<StAngle>>();
916932
factory.register_function<GeoFunction<StAzimuth>>();
917-
factory.register_function<GeoFunction<StContains>>();
933+
factory.register_function<GeoFunction<StRelationFunction<StContainsFunc>>>();
934+
factory.register_function<GeoFunction<StRelationFunction<StIntersectsFunc>>>();
935+
factory.register_function<GeoFunction<StRelationFunction<StDisjointFunc>>>();
936+
factory.register_function<GeoFunction<StRelationFunction<StTouchesFunc>>>();
918937
factory.register_function<GeoFunction<StCircle>>();
919938
factory.register_function<GeoFunction<StGeoFromText<StGeometryFromText>>>();
920939
factory.register_function<GeoFunction<StGeoFromText<StGeomFromText>>>();

0 commit comments

Comments
 (0)