@@ -626,11 +626,13 @@ struct StCircle {
626626 }
627627};
628628
629- struct StContains {
629+ template <typename Func>
630+ struct StRelationFunction {
630631 static constexpr auto NEED_CONTEXT = true ;
631- static constexpr auto NAME = " st_contains " ;
632+ static constexpr auto NAME = Func::NAME ;
632633 static const size_t NUM_ARGS = 2 ;
633634 using Type = DataTypeUInt8;
635+
634636 static Status execute (FunctionContext* context, Block& block, const ColumnNumbers& arguments,
635637 size_t result) {
636638 DCHECK_EQ (arguments.size (), 2 );
@@ -642,8 +644,7 @@ struct StContains {
642644
643645 const auto size = std::max (left_column->size (), right_column->size ());
644646
645- auto res = ColumnUInt8::create ();
646- res->reserve (size);
647+ auto res = ColumnUInt8::create (size, 0 );
647648 auto null_map = ColumnUInt8::create (size, 0 );
648649 auto & null_map_data = null_map->get_data ();
649650
@@ -660,55 +661,50 @@ struct StContains {
660661 }
661662
662663 static void loop_do (StringRef& lhs_value, StringRef& rhs_value,
663- std::vector<std::shared_ptr <GeoShape>>& shapes, int & i ,
664+ std::vector<std::unique_ptr <GeoShape>>& shapes,
664665 ColumnUInt8::MutablePtr& res, NullMap& null_map, int row) {
665666 StringRef* strs[2 ] = {&lhs_value, &rhs_value};
666- for (i = 0 ; i < 2 ; ++i) {
667- shapes [i] =
668- std::shared_ptr<GeoShape>( GeoShape::from_encoded (strs[i]-> data , strs[i]-> size ) );
669- if (shapes[i] == nullptr ) {
667+ for (int i = 0 ; i < 2 ; ++i) {
668+ std::unique_ptr<GeoShape> shape ( GeoShape::from_encoded (strs [i]-> data , strs[i]-> size ));
669+ shapes[i] = std::move (shape );
670+ if (! shapes[i]) {
670671 null_map[row] = 1 ;
671- res->insert_default ();
672672 break ;
673673 }
674674 }
675-
676- if (i == 2 ) {
677- auto contains_value = shapes[0 ]->contains (shapes[1 ].get ());
678- res->insert_data (const_cast <const char *>((char *)&contains_value), 0 );
675+ if (shapes[0 ] && shapes[1 ]) {
676+ auto relation_value = Func::evaluate (shapes[0 ].get (), shapes[1 ].get ());
677+ res->get_data ()[row] = relation_value;
679678 }
680679 }
681680
682681 static void const_vector (const ColumnPtr& left_column, const ColumnPtr& right_column,
683682 ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
684- int i;
685683 auto lhs_value = left_column->get_data_at (0 );
686- std::vector<std::shared_ptr <GeoShape>> shapes = { nullptr , nullptr } ;
684+ std::vector<std::unique_ptr <GeoShape>> shapes ( 2 ) ;
687685 for (int row = 0 ; row < size; ++row) {
688686 auto rhs_value = right_column->get_data_at (row);
689- loop_do (lhs_value, rhs_value, shapes, i, res, null_map, row);
687+ loop_do (lhs_value, rhs_value, shapes, res, null_map, row);
690688 }
691689 }
692690
693691 static void vector_const (const ColumnPtr& left_column, const ColumnPtr& right_column,
694692 ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
695- int i;
696693 auto rhs_value = right_column->get_data_at (0 );
697- std::vector<std::shared_ptr <GeoShape>> shapes = { nullptr , nullptr } ;
694+ std::vector<std::unique_ptr <GeoShape>> shapes ( 2 ) ;
698695 for (int row = 0 ; row < size; ++row) {
699696 auto lhs_value = left_column->get_data_at (row);
700- loop_do (lhs_value, rhs_value, shapes, i, res, null_map, row);
697+ loop_do (lhs_value, rhs_value, shapes, res, null_map, row);
701698 }
702699 }
703700
704701 static void vector_vector (const ColumnPtr& left_column, const ColumnPtr& right_column,
705702 ColumnUInt8::MutablePtr& res, NullMap& null_map, const size_t size) {
706- int i;
707- std::vector<std::shared_ptr<GeoShape>> shapes = {nullptr , nullptr };
703+ std::vector<std::unique_ptr<GeoShape>> shapes (2 );
708704 for (int row = 0 ; row < size; ++row) {
709705 auto lhs_value = left_column->get_data_at (row);
710706 auto rhs_value = right_column->get_data_at (row);
711- loop_do (lhs_value, rhs_value, shapes, i, res, null_map, row);
707+ loop_do (lhs_value, rhs_value, shapes, res, null_map, row);
712708 }
713709 }
714710
@@ -719,7 +715,27 @@ struct StContains {
719715 static Status close (FunctionContext* context, FunctionContext::FunctionStateScope scope) {
720716 return Status::OK ();
721717 }
722- }; // namespace doris::vectorized
718+ };
719+
720+ struct StContainsFunc {
721+ static constexpr auto NAME = " st_contains" ;
722+ static bool evaluate (GeoShape* shape1, GeoShape* shape2) { return shape1->contains (shape2); }
723+ };
724+
725+ struct StIntersectsFunc {
726+ static constexpr auto NAME = " st_intersects" ;
727+ static bool evaluate (GeoShape* shape1, GeoShape* shape2) { return shape1->intersects (shape2); }
728+ };
729+
730+ struct StDisjointFunc {
731+ static constexpr auto NAME = " st_disjoint" ;
732+ static bool evaluate (GeoShape* shape1, GeoShape* shape2) { return shape1->disjoint (shape2); }
733+ };
734+
735+ struct StTouchesFunc {
736+ static constexpr auto NAME = " st_touches" ;
737+ static bool evaluate (GeoShape* shape1, GeoShape* shape2) { return shape1->touches (shape2); }
738+ };
723739
724740struct StGeometryFromText {
725741 static constexpr auto NAME = " st_geometryfromtext" ;
@@ -914,7 +930,10 @@ void register_function_geo(SimpleFunctionFactory& factory) {
914930 factory.register_function <GeoFunction<StAngleSphere>>();
915931 factory.register_function <GeoFunction<StAngle>>();
916932 factory.register_function <GeoFunction<StAzimuth>>();
917- factory.register_function <GeoFunction<StContains>>();
933+ factory.register_function <GeoFunction<StRelationFunction<StContainsFunc>>>();
934+ factory.register_function <GeoFunction<StRelationFunction<StIntersectsFunc>>>();
935+ factory.register_function <GeoFunction<StRelationFunction<StDisjointFunc>>>();
936+ factory.register_function <GeoFunction<StRelationFunction<StTouchesFunc>>>();
918937 factory.register_function <GeoFunction<StCircle>>();
919938 factory.register_function <GeoFunction<StGeoFromText<StGeometryFromText>>>();
920939 factory.register_function <GeoFunction<StGeoFromText<StGeomFromText>>>();
0 commit comments