From 0e99113d3a5056706b84713edb47349183086c11 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 30 Apr 2026 08:31:23 -0700 Subject: [PATCH 1/6] [GH-2830] Add python wrappers and tests and add SRID-preservation tests for geography functions --- python/sedona/spark/sql/st_constructors.py | 117 +++++++++ python/tests/sql/test_geography.py | 244 ++++++++++++++++-- .../ConstructorsDataFrameAPITest.scala | 38 +++ .../geography/FunctionsDataFrameAPITest.scala | 69 ++++- .../PreserveSRIDGeographySuite.scala | 108 ++++++++ 5 files changed, 549 insertions(+), 27 deletions(-) create mode 100644 spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala diff --git a/python/sedona/spark/sql/st_constructors.py b/python/sedona/spark/sql/st_constructors.py index fc5da936a10..5a9496bf41c 100644 --- a/python/sedona/spark/sql/st_constructors.py +++ b/python/sedona/spark/sql/st_constructors.py @@ -192,6 +192,123 @@ def ST_GeogFromWKT( return _call_constructor_function("ST_GeogFromWKT", args) +@validate_argument_types +def ST_GeogFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: + """Generate a geography column from a Well-Known Text (WKT) string column. + This is an alias of ST_GeogFromWKT. + + :param wkt: WKT string column to generate from. + :type wkt: ColumnOrName + :return: Geography column representing the WKT string. + :rtype: Column + """ + args = (wkt) if srid is None else (wkt, srid) + + return _call_constructor_function("ST_GeogFromText", args) + + +@validate_argument_types +def ST_GeogFromWKB( + wkb: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: + """Generate a geography column from a Well-Known Binary (WKB) binary column. + + :param wkb: WKB binary column to generate from. + :type wkb: ColumnOrName + :return: Geography column representing the WKB binary. + :rtype: Column + """ + args = (wkb) if srid is None else (wkb, srid) + + return _call_constructor_function("ST_GeogFromWKB", args) + + +@validate_argument_types +def ST_GeogFromEWKB(wkb: ColumnOrName) -> Column: + """Generate a geography column from an OGC Extended Well-Known Binary (EWKB) binary column. + + :param wkb: EWKB binary column to generate from. + :type wkb: ColumnOrName + :return: Geography column representing the EWKB binary. + :rtype: Column + """ + return _call_constructor_function("ST_GeogFromEWKB", wkb) + + +@validate_argument_types +def ST_GeogFromEWKT(ewkt: ColumnOrName) -> Column: + """Generate a geography column from an OGC Extended Well-Known Text (EWKT) string column. + + :param ewkt: EWKT string column to generate from. + :type ewkt: ColumnOrName + :return: Geography column representing the EWKT string. + :rtype: Column + """ + return _call_constructor_function("ST_GeogFromEWKT", ewkt) + + +@validate_argument_types +def ST_GeogFromGeoHash( + geohash: ColumnOrName, precision: Optional[Union[ColumnOrName, int]] = None +) -> Column: + """Generate a geography column from a geohash column at a specified precision. + + :param geohash: Geohash string column to generate from. + :type geohash: ColumnOrName + :param precision: Geohash precision to use, either an integer or an integer column. + :type precision: Union[ColumnOrName, int] + :return: Geography column representing the supplied geohash and precision level. + :rtype: Column + """ + args = (geohash) if precision is None else (geohash, precision) + + return _call_constructor_function("ST_GeogFromGeoHash", args) + + +@validate_argument_types +def ST_GeogCollFromText( + wkt: ColumnOrName, srid: Optional[ColumnOrNameOrNumber] = None +) -> Column: + """Generate a GeometryCollection geography from a GeometryCollection WKT representation. + + :param wkt: GeometryCollection WKT string column to generate from. + :type wkt: ColumnOrName + :param srid: SRID for the geography. + :type srid: ColumnOrNameOrNumber + :return: GeometryCollection geography generated from the wkt column. + :rtype: Column + """ + args = (wkt) if srid is None else (wkt, srid) + + return _call_constructor_function("ST_GeogCollFromText", args) + + +@validate_argument_types +def ST_GeogToGeometry(geog: ColumnOrName) -> Column: + """Convert a geography column into a geometry column. + + :param geog: Geography column to convert. + :type geog: ColumnOrName + :return: Geometry column representing the geography. + :rtype: Column + """ + return _call_constructor_function("ST_GeogToGeometry", geog) + + +@validate_argument_types +def ST_GeomToGeography(geom: ColumnOrName) -> Column: + """Convert a geometry column into a geography column. + + :param geom: Geometry column to convert. + :type geom: ColumnOrName + :return: Geography column representing the geometry. + :rtype: Column + """ + return _call_constructor_function("ST_GeomToGeography", geom) + + @validate_argument_types def ST_GeomFromEWKT(ewkt: ColumnOrName) -> Column: """Generate a geometry column from a OGC Extended Well-Known Text (WKT) string column. diff --git a/python/tests/sql/test_geography.py b/python/tests/sql/test_geography.py index 0901c98d50e..6d5f71d9b09 100644 --- a/python/tests/sql/test_geography.py +++ b/python/tests/sql/test_geography.py @@ -15,31 +15,225 @@ # specific language governing permissions and limitations # under the License. -# from pyspark.sql.functions import expr -# from pyspark.sql.types import StructType -# from shapely.wkt import loads as wkt_loads -# from sedona.spark.core.geom.geography import Geography -# from sedona.spark.sql.types import GeographyType +from pyspark.sql.functions import col, lit +from sedona.spark.sql import st_constructors as stc +from sedona.spark.sql import st_functions as stf +from sedona.spark.sql import st_predicates as stp from tests.test_base import TestBase -class TestGeography(TestBase): - - def test_deserialize_geography(self): - """Test serialization and deserialization of Geography objects""" - # geog_df = self.spark.range(0, 10).withColumn( - # "geog", expr("ST_GeogFromWKT(CONCAT('POINT (', id, ' ', id + 1, ')'))") - # ) - # rows = geog_df.collect() - # assert len(rows) == 10 - # for row in rows: - # id = row["id"] - # geog = row["geog"] - # assert geog.geometry.wkt == f"POINT ({id} {id + 1})" - - def test_serialize_geography(self): - wkt = "MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))" - # geog = Geography(wkt_loads(wkt)) - # schema = StructType().add("geog", GeographyType()) - # returned_geog = self.spark.createDataFrame([(geog,)], schema).take(1)[0][0] - # assert geog.geometry.equals(returned_geog.geometry) +class TestGeographyConstructorsDataFrameAPI(TestBase): + """Exercise every ST_Geog* constructor through its typed Python wrapper.""" + + def test_st_geog_from_wkt(self): + df = self.spark.sql("SELECT 'POINT (1 2)' AS wkt").select( + stc.ST_GeogFromWKT(col("wkt"), lit(4326)).alias("g") + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt == "SRID=4326; POINT (1 2)" + + def test_st_geog_from_wkt_no_srid(self): + df = self.spark.sql("SELECT 'POINT (1 2)' AS wkt").select( + stc.ST_GeogFromWKT(col("wkt")).alias("g") + ) + wkt = df.select(stf.ST_AsText(col("g"))).first()[0] + assert wkt == "POINT (1 2)" + + def test_st_geog_from_text(self): + df = self.spark.sql("SELECT 'POINT (3 4)' AS wkt").select( + stc.ST_GeogFromText(col("wkt"), lit(4326)).alias("g") + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt == "SRID=4326; POINT (3 4)" + + def test_st_geog_from_wkb(self): + # WKB for POINT (10 15) in little-endian + wkb_bytes = bytes( + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36, 64, 0, 0, 0, 0, 0, 0, 46, 64] + ) + df = self.spark.createDataFrame([(wkb_bytes,)], ["wkb"]).select( + stc.ST_GeogFromWKB(col("wkb")).alias("g") + ) + wkt = df.select(stf.ST_AsText(col("g"))).first()[0] + assert wkt == "POINT (10 15)" + + def test_st_geog_from_ewkb(self): + # EWKB for SRID=4326; LINESTRING (-2.1 -0.4, -1.5 -0.7) + ewkb_bytes = bytes( + [ + 1, 2, 0, 0, 32, 230, 16, 0, 0, 2, 0, 0, 0, + 0, 0, 0, 0, 132, 214, 0, 192, 0, 0, 0, 0, 128, 181, 214, 191, + 0, 0, 0, 96, 225, 239, 247, 191, 0, 0, 0, 128, 7, 93, 229, 191, + ] + ) + df = self.spark.createDataFrame([(ewkb_bytes,)], ["wkb"]).select( + stc.ST_GeogFromEWKB(col("wkb")).alias("g") + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt.startswith("SRID=4326; LINESTRING") + + def test_st_geog_from_ewkt(self): + df = self.spark.sql("SELECT 'SRID=4269;POINT (5 6)' AS ewkt").select( + stc.ST_GeogFromEWKT(col("ewkt")).alias("g") + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt == "SRID=4269; POINT (5 6)" + + def test_st_geog_from_geohash(self): + df = self.spark.sql("SELECT '9q9j8ue2v71y5zzy0s4q' AS geohash").select( + stc.ST_GeogFromGeoHash(col("geohash"), 4).alias("g") + ) + wkt = df.select(stf.ST_AsText(col("g"))).first()[0] + assert wkt.startswith("POLYGON") + + def test_st_geog_from_geohash_no_precision(self): + df = self.spark.sql("SELECT '9q9' AS geohash").select( + stc.ST_GeogFromGeoHash(col("geohash")).alias("g") + ) + wkt = df.select(stf.ST_AsText(col("g"))).first()[0] + assert wkt.startswith("POLYGON") + + def test_st_geogcoll_from_text(self): + wkt_in = ( + "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (0 0, 1 1))" + ) + df = self.spark.sql(f"SELECT '{wkt_in}' AS wkt").select( + stc.ST_GeogCollFromText(col("wkt"), lit(4326)).alias("g") + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt.startswith("SRID=4326; GEOMETRYCOLLECTION") + + def test_st_geog_to_geometry(self): + df = ( + self.spark.sql("SELECT 'POINT (7 8)' AS wkt") + .select(stc.ST_GeogFromWKT(col("wkt"), lit(4326)).alias("g")) + .select(stc.ST_GeogToGeometry(col("g")).alias("geom")) + ) + wkt = df.select(stf.ST_AsText(col("geom"))).first()[0] + assert wkt == "POINT (7 8)" + + def test_st_geom_to_geography(self): + df = ( + self.spark.sql("SELECT 'POINT (9 10)' AS wkt") + .select(stc.ST_GeomFromWKT(col("wkt"), lit(4326)).alias("geom")) + .select(stc.ST_GeomToGeography(col("geom")).alias("g")) + ) + ewkt = df.select(stf.ST_AsEWKT(col("g"))).first()[0] + assert ewkt == "SRID=4326; POINT (9 10)" + + +class TestGeographyFunctionsDataFrameAPI(TestBase): + """Exercise dual-dispatch ST functions/predicates against Geography columns + via the typed Python DataFrame API.""" + + def _geog(self, wkt, srid=4326): + return stc.ST_GeogFromWKT(lit(wkt), lit(srid)) + + def test_st_distance(self): + df = self.spark.range(1).select( + stf.ST_Distance( + self._geog("POINT (0 0)"), self._geog("POINT (1 1)") + ).alias("d") + ) + d = df.first()[0] + assert 155000 < d < 160000 # ~157km on a sphere + + def test_st_length(self): + df = self.spark.range(1).select( + stf.ST_Length(self._geog("LINESTRING (0 0, 1 0)")).alias("l") + ) + l = df.first()[0] + assert 110000 < l < 112000 + + def test_st_length_of_point(self): + df = self.spark.range(1).select( + stf.ST_Length(self._geog("POINT (1 2)")).alias("l") + ) + assert df.first()[0] == 0.0 + + def test_st_area(self): + df = self.spark.range(1).select( + stf.ST_Area( + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))") + ).alias("a") + ) + a = df.first()[0] + # 1°×1° box near equator on R=6371008m sphere ≈ 1.2364e10 m² + assert 1.23e10 < a < 1.24e10 + + def test_st_centroid(self): + df = self.spark.range(1).select( + stf.ST_Centroid( + self._geog("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))") + ).alias("c") + ) + wkt = df.select(stf.ST_AsText(col("c"))).first()[0] + assert wkt.startswith("POINT") + + def test_st_buffer(self): + df = self.spark.range(1).select( + stf.ST_Buffer(self._geog("POINT (0 0)"), lit(1000.0)).alias("b") + ) + wkt = df.select(stf.ST_AsText(col("b"))).first()[0] + assert wkt.startswith("POLYGON") + + def test_st_envelope(self): + df = self.spark.range(1).select( + stf.ST_Envelope( + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))") + ).alias("e") + ) + wkt = df.select(stf.ST_AsText(col("e"))).first()[0] + assert wkt.startswith("POLYGON") + + def test_st_npoints(self): + df = self.spark.range(1).select( + stf.ST_NPoints(self._geog("LINESTRING (0 0, 1 1, 2 2)")).alias("n") + ) + assert df.first()[0] == 3 + + def test_st_contains(self): + df = self.spark.range(1).select( + stp.ST_Contains( + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"), + self._geog("POINT (0.5 0.5)"), + ).alias("r") + ) + assert df.first()[0] is True + + def test_st_within(self): + df = self.spark.range(1).select( + stp.ST_Within( + self._geog("POINT (0.5 0.5)"), + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"), + ).alias("r") + ) + assert df.first()[0] is True + + def test_st_dwithin(self): + df = self.spark.range(1).select( + stp.ST_DWithin( + self._geog("POINT (0 0)"), + self._geog("POINT (0 1)"), + lit(200000.0), + ).alias("r") + ) + assert df.first()[0] is True + + def test_st_equals(self): + df = self.spark.range(1).select( + stp.ST_Equals( + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"), + self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))"), + ).alias("r") + ) + assert df.first()[0] is True + + def test_st_intersects(self): + df = self.spark.range(1).select( + stp.ST_Intersects( + self._geog("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), + self._geog("POLYGON ((1 1, 3 1, 3 3, 1 3, 1 1))"), + ).alias("r") + ) + assert df.first()[0] is True diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/ConstructorsDataFrameAPITest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/ConstructorsDataFrameAPITest.scala index 421275ba738..330f790e47e 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geography/ConstructorsDataFrameAPITest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/ConstructorsDataFrameAPITest.scala @@ -138,4 +138,42 @@ class ConstructorsDataFrameAPITest extends TestBaseScala { assertEquals(wkt, geog.toString) } + it("passed ST_GeogFromText with srid") { + val df = sparkSession + .sql("SELECT 'POINT (0.0 1.0)' AS wkt") + .select(st_constructors.ST_GeogFromText(col("wkt"), lit(4326)).as("g")) + val geog = df.first().get(0).asInstanceOf[Geography] + assertEquals("POINT (0 1)", geog.toString(new PrecisionModel(PrecisionModel.FIXED))) + assertEquals(4326, geog.getSRID) + } + + it("passed ST_GeogFromText without srid") { + val df = sparkSession + .sql("SELECT 'POINT (3 4)' AS wkt") + .select(st_constructors.ST_GeogFromText(col("wkt")).as("g")) + val geog = df.first().get(0).asInstanceOf[Geography] + assertEquals("POINT (3 4)", geog.toString(new PrecisionModel(PrecisionModel.FIXED))) + assertEquals(0, geog.getSRID) + } + + it("passed ST_GeogCollFromText with srid") { + val wkt = "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (0 0, 1 1))" + val df = sparkSession + .sql(s"SELECT '$wkt' AS wkt") + .select(st_constructors.ST_GeogCollFromText(col("wkt"), lit(4326)).as("g")) + val geog = df.first().get(0).asInstanceOf[Geography] + assertEquals(4326, geog.getSRID) + assertEquals("GEOMETRYCOLLECTION", geog.toString.takeWhile(_ != ' ')) + } + + it("passed ST_GeogCollFromText without srid") { + val wkt = "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (0 0, 1 1))" + val df = sparkSession + .sql(s"SELECT '$wkt' AS wkt") + .select(st_constructors.ST_GeogCollFromText(col("wkt")).as("g")) + val geog = df.first().get(0).asInstanceOf[Geography] + assertEquals(0, geog.getSRID) + assertEquals("GEOMETRYCOLLECTION", geog.toString.takeWhile(_ != ' ')) + } + } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala index 9d25f26bf5c..30420412ece 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala @@ -21,8 +21,8 @@ package org.apache.sedona.sql.geography import org.apache.sedona.common.S2Geography.Geography import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.functions.{col, lit} -import org.apache.spark.sql.sedona_sql.expressions.{ST_Envelope, st_constructors, st_functions} -import org.junit.Assert.assertEquals +import org.apache.spark.sql.sedona_sql.expressions.{ST_Envelope, st_constructors, st_functions, st_predicates} +import org.junit.Assert.{assertEquals, assertTrue} class FunctionsDataFrameAPITest extends TestBaseScala { import sparkSession.implicits._ @@ -70,4 +70,69 @@ class FunctionsDataFrameAPITest extends TestBaseScala { assert(geoStr == wktExpected) } + it("Passed ST_Length via DataFrame API") { + val df = sparkSession + .sql("SELECT 'LINESTRING (0 0, 1 0)' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_Length(col("g")).as("l")) + val len = df.first().getDouble(0) + // 1° along the equator on a sphere of radius 6371008 m ≈ 111195 m + assertEquals(111195.10, len, 1.0) + } + + it("Passed ST_Area via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_Area(col("g")).as("a")) + val area = df.first().getDouble(0) + // 1°×1° box near equator on R=6371008m sphere ≈ 1.2364e10 m² + assertEquals(1.2364e10, area, 1e7) + } + + it("Passed ST_Centroid via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_Centroid(col("g")).as("c")) + val centroid = df.first().get(0).asInstanceOf[Geography] + assertTrue(centroid.toString.startsWith("POINT")) + } + + it("Passed ST_NumGeometries via DataFrame API") { + val df = sparkSession + .sql("SELECT 'MULTIPOINT ((0 0), (1 1), (2 2))' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_NumGeometries(col("g")).as("n")) + assertEquals(3, df.first().getInt(0)) + } + + it("Passed ST_GeometryType via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_GeometryType(col("g")).as("t")) + assertEquals("ST_Polygon", df.first().getString(0)) + } + + it("Passed ST_AsText via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POINT (1 2)' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_AsText(col("g")).as("t")) + val txt = df.first().getString(0) + assertTrue(s"expected POINT prefix; got $txt", txt.startsWith("POINT")) + } + + it("Passed ST_Intersects via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))' AS a, " + + "'POLYGON ((1 1, 3 1, 3 3, 1 3, 1 1))' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_predicates.ST_Intersects(col("a"), col("b")).as("r")) + assertTrue(df.first().getBoolean(0)) + } + } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala new file mode 100644 index 00000000000..bcbd573e8d6 --- /dev/null +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.geography + +import org.apache.sedona.common.S2Geography.Geography +import org.apache.sedona.common.geography.Constructors +import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.sedona_sql.UDT.GeographyUDT +import org.apache.spark.sql.types.{StructField, StructType} +import org.locationtech.jts.geom.Geometry +import org.scalatest.prop.TableDrivenPropertyChecks + +/** + * Geography counterpart of [[org.apache.sedona.sql.PreserveSRIDSuite]]. Verifies that the + * Geography expression chain preserves the input SRID through the InferredExpression boundary. + * Geography→Geography functions (ST_Centroid, ST_Envelope, ST_Buffer, ST_GeomToGeography) are + * tested directly. Scalar/predicate functions on Geography inputs are wrapped in an identity + * passthrough — `IF(, geog1, geog1)` — so the surrounding expression returns a + * Geography whose SRID can be asserted; failure of such a row signals either an evaluation + * failure on the wrapped function or SRID being dropped somewhere in the chain. + */ +class PreserveSRIDGeographySuite extends TestBaseScala with TableDrivenPropertyChecks { + private var testDf: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testDf = prepareTestDataFrame() + } + + describe("Preserve SRID (Geography)") { + val testCases = Table( + "test case", + // Direct Geography→Geography + ("ST_Centroid(geog1)", 1000), + ("ST_Envelope(geog1)", 1000), + ("ST_Envelope(geog1, true)", 1000), + ("ST_Envelope(geog1, false)", 1000), + ("ST_Buffer(geog1, 0)", 1000), + ("ST_Buffer(geog1, 100)", 1000), + ("ST_Buffer(geog1, 100, 'quad_segs=8')", 1000), + // Cross-type boundaries + ("ST_GeomToGeography(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))', 1000))", 1000), + ("ST_GeogToGeometry(geog1)", 1000), + // Predicates wrapped in identity passthrough + ("IF(ST_Intersects(geog1, geog2), geog1, geog1)", 1000), + ("IF(ST_Within(geog1, geog1), geog1, geog1)", 1000), + ("IF(ST_DWithin(geog1, geog2, 1000.0), geog1, geog1)", 1000), + ("IF(ST_Contains(geog1, geog1), geog1, geog1)", 1000), + ("IF(ST_Equals(geog1, geog1), geog1, geog1)", 1000), + // Scalar/string functions wrapped in identity passthrough + ("IF(ST_Length(geog3) >= 0, geog1, geog1)", 1000), + ("IF(ST_Area(geog1) >= 0, geog1, geog1)", 1000), + ("IF(ST_Distance(geog1, geog2) >= 0, geog1, geog1)", 1000), + ("IF(ST_NPoints(geog1) > 0, geog1, geog1)", 1000), + ("IF(ST_NumGeometries(geog1) > 0, geog1, geog1)", 1000), + ("IF(ST_GeometryType(geog1) IS NOT NULL, geog1, geog1)", 1000), + ("IF(ST_AsText(geog1) IS NOT NULL, geog1, geog1)", 1000), + ("IF(ST_AsEWKT(geog1) IS NOT NULL, geog1, geog1)", 1000)) + + forAll(testCases) { case (expression: String, srid: Int) => + it(s"$expression") { + testDf.selectExpr(expression).collect().foreach { row => + val value = row.getAs[AnyRef](0) + value match { + case geog: Geography => assert(geog.getSRID == srid) + case geom: Geometry => assert(geom.getSRID == srid) + case _ => fail(s"Unexpected result: $value") + } + } + } + } + } + + private def prepareTestDataFrame(): DataFrame = { + import scala.collection.JavaConverters._ + + val schema = StructType( + Seq( + StructField("geog1", GeographyUDT()), + StructField("geog2", GeographyUDT()), + StructField("geog3", GeographyUDT()))) + val geog1 = + Constructors.geogFromWKT("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", 1000) + val geog2 = + Constructors.geogFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1))", 1000) + val geog3 = + Constructors.geogFromWKT("LINESTRING (0 0, 0 1, 1 1, 1 0)", 1000) + val rows = Seq(Row(geog1, geog2, geog3)) + sparkSession.createDataFrame(rows.asJava, schema) + } +} From 335b6fc25a99eaacedd454fa7dc59afd347b4f88 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 30 Apr 2026 08:51:35 -0700 Subject: [PATCH 2/6] fix pre-commit lint error --- python/tests/sql/test_geography.py | 74 ++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/python/tests/sql/test_geography.py b/python/tests/sql/test_geography.py index 6d5f71d9b09..d8cdc41d632 100644 --- a/python/tests/sql/test_geography.py +++ b/python/tests/sql/test_geography.py @@ -61,9 +61,51 @@ def test_st_geog_from_ewkb(self): # EWKB for SRID=4326; LINESTRING (-2.1 -0.4, -1.5 -0.7) ewkb_bytes = bytes( [ - 1, 2, 0, 0, 32, 230, 16, 0, 0, 2, 0, 0, 0, - 0, 0, 0, 0, 132, 214, 0, 192, 0, 0, 0, 0, 128, 181, 214, 191, - 0, 0, 0, 96, 225, 239, 247, 191, 0, 0, 0, 128, 7, 93, 229, 191, + 1, + 2, + 0, + 0, + 32, + 230, + 16, + 0, + 0, + 2, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 132, + 214, + 0, + 192, + 0, + 0, + 0, + 0, + 128, + 181, + 214, + 191, + 0, + 0, + 0, + 96, + 225, + 239, + 247, + 191, + 0, + 0, + 0, + 128, + 7, + 93, + 229, + 191, ] ) df = self.spark.createDataFrame([(ewkb_bytes,)], ["wkb"]).select( @@ -94,9 +136,7 @@ def test_st_geog_from_geohash_no_precision(self): assert wkt.startswith("POLYGON") def test_st_geogcoll_from_text(self): - wkt_in = ( - "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (0 0, 1 1))" - ) + wkt_in = "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (0 0, 1 1))" df = self.spark.sql(f"SELECT '{wkt_in}' AS wkt").select( stc.ST_GeogCollFromText(col("wkt"), lit(4326)).alias("g") ) @@ -131,9 +171,9 @@ def _geog(self, wkt, srid=4326): def test_st_distance(self): df = self.spark.range(1).select( - stf.ST_Distance( - self._geog("POINT (0 0)"), self._geog("POINT (1 1)") - ).alias("d") + stf.ST_Distance(self._geog("POINT (0 0)"), self._geog("POINT (1 1)")).alias( + "d" + ) ) d = df.first()[0] assert 155000 < d < 160000 # ~157km on a sphere @@ -153,9 +193,7 @@ def test_st_length_of_point(self): def test_st_area(self): df = self.spark.range(1).select( - stf.ST_Area( - self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))") - ).alias("a") + stf.ST_Area(self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")).alias("a") ) a = df.first()[0] # 1°×1° box near equator on R=6371008m sphere ≈ 1.2364e10 m² @@ -163,9 +201,9 @@ def test_st_area(self): def test_st_centroid(self): df = self.spark.range(1).select( - stf.ST_Centroid( - self._geog("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))") - ).alias("c") + stf.ST_Centroid(self._geog("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))")).alias( + "c" + ) ) wkt = df.select(stf.ST_AsText(col("c"))).first()[0] assert wkt.startswith("POINT") @@ -179,9 +217,9 @@ def test_st_buffer(self): def test_st_envelope(self): df = self.spark.range(1).select( - stf.ST_Envelope( - self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))") - ).alias("e") + stf.ST_Envelope(self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")).alias( + "e" + ) ) wkt = df.select(stf.ST_AsText(col("e"))).first()[0] assert wkt.startswith("POLYGON") From 7bc7eb4c8f3cce8c874f81f7d9cf2d9c9a9a95a7 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 30 Apr 2026 09:55:13 -0700 Subject: [PATCH 3/6] fix the tests --- .../PreserveSRIDGeographySuite.scala | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala index bcbd573e8d6..1bd029bd7e4 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/PreserveSRIDGeographySuite.scala @@ -45,34 +45,38 @@ class PreserveSRIDGeographySuite extends TestBaseScala with TableDrivenPropertyC } describe("Preserve SRID (Geography)") { + // Fixture SRID is 4326 because Geography ST_Buffer round-trips through + // FunctionsGeoTools.bufferSpheroid, which resolves the SRID via the EPSG authority + // — only real EPSG codes are valid there. val testCases = Table( "test case", - // Direct Geography→Geography - ("ST_Centroid(geog1)", 1000), - ("ST_Envelope(geog1)", 1000), - ("ST_Envelope(geog1, true)", 1000), - ("ST_Envelope(geog1, false)", 1000), - ("ST_Buffer(geog1, 0)", 1000), - ("ST_Buffer(geog1, 100)", 1000), - ("ST_Buffer(geog1, 100, 'quad_segs=8')", 1000), - // Cross-type boundaries + // Direct Geography→Geography. Note: 1-arg ST_Envelope is geometry-only; the + // splitAtAntiMeridian (2-arg) form is the Geography path. + ("ST_Centroid(geog1)", 4326), + ("ST_Envelope(geog1, true)", 4326), + ("ST_Envelope(geog1, false)", 4326), + ("ST_Buffer(geog1, 0)", 4326), + ("ST_Buffer(geog1, 100)", 4326), + ("ST_Buffer(geog1, 100, 'quad_segs=8')", 4326), + // Cross-type boundaries. The literal SRID here exercises that any int survives the + // Geometry↔Geography boundary (no CRS resolution is involved on this code path). ("ST_GeomToGeography(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))', 1000))", 1000), - ("ST_GeogToGeometry(geog1)", 1000), + ("ST_GeogToGeometry(geog1)", 4326), // Predicates wrapped in identity passthrough - ("IF(ST_Intersects(geog1, geog2), geog1, geog1)", 1000), - ("IF(ST_Within(geog1, geog1), geog1, geog1)", 1000), - ("IF(ST_DWithin(geog1, geog2, 1000.0), geog1, geog1)", 1000), - ("IF(ST_Contains(geog1, geog1), geog1, geog1)", 1000), - ("IF(ST_Equals(geog1, geog1), geog1, geog1)", 1000), + ("IF(ST_Intersects(geog1, geog2), geog1, geog1)", 4326), + ("IF(ST_Within(geog1, geog1), geog1, geog1)", 4326), + ("IF(ST_DWithin(geog1, geog2, 1000.0), geog1, geog1)", 4326), + ("IF(ST_Contains(geog1, geog1), geog1, geog1)", 4326), + ("IF(ST_Equals(geog1, geog1), geog1, geog1)", 4326), // Scalar/string functions wrapped in identity passthrough - ("IF(ST_Length(geog3) >= 0, geog1, geog1)", 1000), - ("IF(ST_Area(geog1) >= 0, geog1, geog1)", 1000), - ("IF(ST_Distance(geog1, geog2) >= 0, geog1, geog1)", 1000), - ("IF(ST_NPoints(geog1) > 0, geog1, geog1)", 1000), - ("IF(ST_NumGeometries(geog1) > 0, geog1, geog1)", 1000), - ("IF(ST_GeometryType(geog1) IS NOT NULL, geog1, geog1)", 1000), - ("IF(ST_AsText(geog1) IS NOT NULL, geog1, geog1)", 1000), - ("IF(ST_AsEWKT(geog1) IS NOT NULL, geog1, geog1)", 1000)) + ("IF(ST_Length(geog3) >= 0, geog1, geog1)", 4326), + ("IF(ST_Area(geog1) >= 0, geog1, geog1)", 4326), + ("IF(ST_Distance(geog1, geog2) >= 0, geog1, geog1)", 4326), + ("IF(ST_NPoints(geog1) > 0, geog1, geog1)", 4326), + ("IF(ST_NumGeometries(geog1) > 0, geog1, geog1)", 4326), + ("IF(ST_GeometryType(geog1) IS NOT NULL, geog1, geog1)", 4326), + ("IF(ST_AsText(geog1) IS NOT NULL, geog1, geog1)", 4326), + ("IF(ST_AsEWKT(geog1) IS NOT NULL, geog1, geog1)", 4326)) forAll(testCases) { case (expression: String, srid: Int) => it(s"$expression") { @@ -97,11 +101,11 @@ class PreserveSRIDGeographySuite extends TestBaseScala with TableDrivenPropertyC StructField("geog2", GeographyUDT()), StructField("geog3", GeographyUDT()))) val geog1 = - Constructors.geogFromWKT("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", 1000) + Constructors.geogFromWKT("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", 4326) val geog2 = - Constructors.geogFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1))", 1000) + Constructors.geogFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1))", 4326) val geog3 = - Constructors.geogFromWKT("LINESTRING (0 0, 0 1, 1 1, 1 0)", 1000) + Constructors.geogFromWKT("LINESTRING (0 0, 0 1, 1 1, 1 0)", 4326) val rows = Seq(Row(geog1, geog2, geog3)) sparkSession.createDataFrame(rows.asJava, schema) } From 2724a963b903612ddfd4ac90454432deee384290 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 30 Apr 2026 09:59:35 -0700 Subject: [PATCH 4/6] fix the python tests --- python/tests/sql/test_geography.py | 42 ++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/python/tests/sql/test_geography.py b/python/tests/sql/test_geography.py index d8cdc41d632..883f1583e42 100644 --- a/python/tests/sql/test_geography.py +++ b/python/tests/sql/test_geography.py @@ -15,13 +15,22 @@ # specific language governing permissions and limitations # under the License. -from pyspark.sql.functions import col, lit +import re + +from pyspark.sql.functions import col, expr, lit from sedona.spark.sql import st_constructors as stc from sedona.spark.sql import st_functions as stf from sedona.spark.sql import st_predicates as stp from tests.test_base import TestBase +def _parse_point_xy(wkt): + """Extract (x, y) from a 'POINT (x y)' string.""" + m = re.match(r"\s*POINT\s*\(\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s+(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s*\)\s*$", wkt) + assert m is not None, f"unparseable POINT WKT: {wkt!r}" + return float(m.group(1)), float(m.group(2)) + + class TestGeographyConstructorsDataFrameAPI(TestBase): """Exercise every ST_Geog* constructor through its typed Python wrapper.""" @@ -37,7 +46,11 @@ def test_st_geog_from_wkt_no_srid(self): stc.ST_GeogFromWKT(col("wkt")).alias("g") ) wkt = df.select(stf.ST_AsText(col("g"))).first()[0] - assert wkt == "POINT (1 2)" + # S2 round-trip introduces sub-nanometer floating-point drift; allow a loose + # tolerance instead of comparing the WKT string verbatim. + x, y = _parse_point_xy(wkt) + assert abs(x - 1.0) < 1e-9 + assert abs(y - 2.0) < 1e-9 def test_st_geog_from_text(self): df = self.spark.sql("SELECT 'POINT (3 4)' AS wkt").select( @@ -150,7 +163,11 @@ def test_st_geog_to_geometry(self): .select(stc.ST_GeogToGeometry(col("g")).alias("geom")) ) wkt = df.select(stf.ST_AsText(col("geom"))).first()[0] - assert wkt == "POINT (7 8)" + # S2 round-trip introduces sub-nanometer floating-point drift on the geography + # → geometry conversion path; compare numerically with a loose tolerance. + x, y = _parse_point_xy(wkt) + assert abs(x - 7.0) < 1e-9 + assert abs(y - 8.0) < 1e-9 def test_st_geom_to_geography(self): df = ( @@ -209,17 +226,26 @@ def test_st_centroid(self): assert wkt.startswith("POINT") def test_st_buffer(self): + # The Python `stf.ST_Buffer` wrapper defaults `useSpheroid=False` which dispatches + # to the 3-arg `(geom, buf, useSpheroid)` overload; Geography rejects any boolean + # `useSpheroid` argument because Geography is always spheroidal. Pass + # `useSpheroid=None` so the wrapper falls through to the 2-arg form, which is + # what Geography supports. df = self.spark.range(1).select( - stf.ST_Buffer(self._geog("POINT (0 0)"), lit(1000.0)).alias("b") + stf.ST_Buffer(self._geog("POINT (0 0)"), lit(1000.0), useSpheroid=None).alias( + "b" + ) ) wkt = df.select(stf.ST_AsText(col("b"))).first()[0] assert wkt.startswith("POLYGON") def test_st_envelope(self): - df = self.spark.range(1).select( - stf.ST_Envelope(self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))")).alias( - "e" - ) + # Geography ST_Envelope is the 2-arg `splitAtAntiMeridian` form; the 1-arg form + # is geometry-only. Use a SQL expression to invoke the 2-arg overload. + df = ( + self.spark.range(1) + .select(self._geog("POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))").alias("g")) + .select(expr("ST_Envelope(g, true)").alias("e")) ) wkt = df.select(stf.ST_AsText(col("e"))).first()[0] assert wkt.startswith("POLYGON") From d5f2b08aea3d6c7b34a97ac049ff8d860ad627f1 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 30 Apr 2026 10:10:54 -0700 Subject: [PATCH 5/6] fix pre-commit lint errors --- python/tests/sql/test_geography.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/tests/sql/test_geography.py b/python/tests/sql/test_geography.py index 883f1583e42..f71e938c82b 100644 --- a/python/tests/sql/test_geography.py +++ b/python/tests/sql/test_geography.py @@ -26,8 +26,11 @@ def _parse_point_xy(wkt): """Extract (x, y) from a 'POINT (x y)' string.""" - m = re.match(r"\s*POINT\s*\(\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s+(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s*\)\s*$", wkt) - assert m is not None, f"unparseable POINT WKT: {wkt!r}" + m = re.match( + r"\s*POINT\s*\(\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s+(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)\s*\)\s*$", + wkt, + ) + assert m is not None, f"unparsable POINT WKT: {wkt!r}" return float(m.group(1)), float(m.group(2)) @@ -232,9 +235,9 @@ def test_st_buffer(self): # `useSpheroid=None` so the wrapper falls through to the 2-arg form, which is # what Geography supports. df = self.spark.range(1).select( - stf.ST_Buffer(self._geog("POINT (0 0)"), lit(1000.0), useSpheroid=None).alias( - "b" - ) + stf.ST_Buffer( + self._geog("POINT (0 0)"), lit(1000.0), useSpheroid=None + ).alias("b") ) wkt = df.select(stf.ST_AsText(col("b"))).first()[0] assert wkt.startswith("POLYGON") From 45fdbd6468c171bd82df3f56307e0aa907f4e9c2 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Sat, 2 May 2026 19:56:29 -0700 Subject: [PATCH 6/6] add scala tests --- .../geography/FunctionsDataFrameAPITest.scala | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala index 30420412ece..47fc7719c0f 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/geography/FunctionsDataFrameAPITest.scala @@ -135,4 +135,86 @@ class FunctionsDataFrameAPITest extends TestBaseScala { assertTrue(df.first().getBoolean(0)) } + it("Passed ST_Distance via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POINT (0 0)' AS a, 'POINT (1 1)' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_functions.ST_Distance(col("a"), col("b")).as("d")) + val d = df.first().getDouble(0) + // ~157km on a sphere; matches the Python TestGeographyFunctionsDataFrameAPI bound. + assertTrue(s"expected 155000 < d < 160000; got $d", d > 155000 && d < 160000) + } + + it("Passed ST_Length on POINT via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POINT (1 2)' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_Length(col("g")).as("l")) + assertEquals(0.0, df.first().getDouble(0), 0.0) + } + + it("Passed ST_Buffer via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POINT (0 0)' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_Buffer(col("g"), lit(1000.0)).as("b")) + .select(st_functions.ST_AsText(col("b")).as("t")) + val txt = df.first().getString(0) + assertTrue(s"expected POLYGON prefix; got $txt", txt.startsWith("POLYGON")) + } + + it("Passed ST_NPoints via DataFrame API") { + val df = sparkSession + .sql("SELECT 'LINESTRING (0 0, 1 1, 2 2)' AS wkt") + .select(st_constructors.ST_GeogFromWKT(col("wkt"), lit(4326)).as("g")) + .select(st_functions.ST_NPoints(col("g")).as("n")) + assertEquals(3, df.first().getInt(0)) + } + + it("Passed ST_Contains via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS a, " + + "'POINT (0.5 0.5)' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_predicates.ST_Contains(col("a"), col("b")).as("r")) + assertTrue(df.first().getBoolean(0)) + } + + it("Passed ST_Within via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POINT (0.5 0.5)' AS a, " + + "'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_predicates.ST_Within(col("a"), col("b")).as("r")) + assertTrue(df.first().getBoolean(0)) + } + + it("Passed ST_DWithin via DataFrame API") { + // 1° of latitude ≈ 111 km, well within 200 km. + val df = sparkSession + .sql("SELECT 'POINT (0 0)' AS a, 'POINT (0 1)' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_predicates.ST_DWithin(col("a"), col("b"), lit(200000.0)).as("r")) + assertTrue(df.first().getBoolean(0)) + } + + it("Passed ST_Equals via DataFrame API") { + val df = sparkSession + .sql("SELECT 'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS a, " + + "'POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))' AS b") + .select( + st_constructors.ST_GeogFromWKT(col("a"), lit(4326)).as("a"), + st_constructors.ST_GeogFromWKT(col("b"), lit(4326)).as("b")) + .select(st_predicates.ST_Equals(col("a"), col("b")).as("r")) + assertTrue(df.first().getBoolean(0)) + } + }