Skip to content

Commit d799f50

Browse files
authored
[GH-2004] Geopandas.GeoSeries: Implement Test Framework (#2005)
* Fix small nit in series __repr__()
* Add test_non_geom_fails()
* test_constructor on all different geometry types
* Change Series.area return type to pd.Series to match gpd behavior and add area tests
* Fix GeoSeries.to_pandas() and fix refactor tests
* pre-commit
* Test if sgpd_res equals sedona result and gpd result
* Remove run_sedona_sql test
* Rename test_geoseries.py to test_match_geopandas_series.py
* Make area() return ps.Series instead of pd.Series
* Add new test_geoseries to mimic the scala tests
* Use smaller tests for test_geoseries and hard-code expected results
* Remove check_less_precise for version compatibility
1 parent 0cc1521 commit d799f50

3 files changed

Lines changed: 292 additions & 104 deletions

File tree

python/sedona/geopandas/geoseries.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def __repr__(self) -> str:
5050
Return a string representation of the GeoSeries in WKT format.
5151
"""
5252
try:
53-
pandas_series = self.to_geopandas()
54-
return gpd.GeoSeries(pandas_series).__repr__()
53+
gpd_series = self.to_geopandas()
54+
return gpd_series.__repr__()
5555

5656
except Exception as e:
5757
# Fallback to parent's representation if conversion fails
@@ -176,7 +176,7 @@ def _process_geometry_column(
176176
A GeoSeries with the operation applied to the geometry column.
177177
"""
178178
# Find the first column with BinaryType or GeometryType
179-
first_col = self.get_first_geometry_column()
179+
first_col = self.get_first_geometry_column() # TODO: fixme
180180

181181
if first_col:
182182
data_type = self._internal.spark_frame.schema[first_col].dataType
@@ -230,9 +230,16 @@ def to_geopandas(self) -> gpd.GeoSeries:
230230
return self._to_geopandas()
231231

232232
def _to_geopandas(self) -> gpd.GeoSeries:
233-
return gpd.GeoSeries(
234-
self._to_internal_pandas().map(lambda wkb: shapely.wkb.loads(bytes(wkb)))
235-
)
233+
pd_series = self._to_internal_pandas()
234+
try:
235+
return gpd.GeoSeries(
236+
pd_series.map(lambda wkb: shapely.wkb.loads(bytes(wkb)))
237+
)
238+
except Exception as e:
239+
return gpd.GeoSeries(pd_series)
240+
241+
def to_spark_pandas(self) -> pspd.Series:
242+
return pspd.Series(self._to_internal_pandas())
236243

237244
@property
238245
def geometry(self) -> "GeoSeries":
@@ -274,7 +281,7 @@ def copy(self, deep=False):
274281
return self
275282

276283
@property
277-
def area(self) -> "GeoSeries":
284+
def area(self) -> pspd.Series:
278285
"""
279286
Returns a Series containing the area of each geometry in the GeoSeries expressed in the units of the CRS.
280287
@@ -295,7 +302,7 @@ def area(self) -> "GeoSeries":
295302
1 4.0
296303
dtype: float64
297304
"""
298-
return self._process_geometry_column("ST_Area", rename="area")
305+
return self._process_geometry_column("ST_Area", rename="area").to_spark_pandas()
299306

300307
@property
301308
def crs(self):
@@ -521,7 +528,7 @@ def buffer(
521528
mitre_limit=5.0,
522529
single_sided=False,
523530
**kwargs,
524-
):
531+
) -> "GeoSeries":
525532
"""
526533
Returns a GeoSeries of geometries representing all points within a given distance of each geometric object.
527534

python/tests/geopandas/test_geoseries.py

Lines changed: 41 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -14,109 +14,55 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import os
18-
import shutil
19-
import tempfile
20-
from geopandas.testing import assert_geoseries_equal
2117

22-
from shapely.geometry import (
23-
Point,
24-
Polygon,
25-
)
26-
27-
from sedona.geopandas import GeoSeries
18+
import pandas as pd
19+
import geopandas as gpd
20+
import sedona.geopandas as sgpd
2821
from tests.test_base import TestBase
29-
import pyspark.pandas as ps
22+
from shapely import wkt
23+
from shapely.geometry import Point, LineString, Polygon, GeometryCollection
24+
from pandas.testing import assert_series_equal
3025

3126

32-
class TestSeries(TestBase):
27+
class TestGeoSeries(TestBase):
3328
def setup_method(self):
34-
self.tempdir = tempfile.mkdtemp()
35-
self.t1 = Polygon([(0, 0), (1, 0), (1, 1)])
36-
self.t2 = Polygon([(0, 0), (1, 1), (0, 1)])
37-
self.sq = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
38-
self.g1 = GeoSeries([self.t1, self.t2])
39-
self.g2 = GeoSeries([self.sq, self.t1])
40-
self.g3 = GeoSeries([self.t1, self.t2], crs="epsg:4326")
41-
self.g4 = GeoSeries([self.t2, self.t1])
42-
43-
def teardown_method(self):
44-
shutil.rmtree(self.tempdir)
45-
46-
def test_constructor(self):
47-
s = GeoSeries([Point(x, x) for x in range(3)])
48-
check_geoseries_equal(s, s)
49-
50-
def test_psdf(self):
51-
# this is to make sure the spark session works with pandas on spark api
52-
psdf = ps.DataFrame(
53-
{
54-
"a": [1, 2, 3, 4, 5, 6],
55-
"b": [100, 200, 300, 400, 500, 600],
56-
"c": ["one", "two", "three", "four", "five", "six"],
57-
},
58-
index=[10, 20, 30, 40, 50, 60],
29+
self.geoseries = sgpd.GeoSeries(
30+
[
31+
Point(2.3, -1),
32+
LineString([(0.5, 0), (0, -3)]),
33+
Polygon([(-1, -1), (-0.3, 5), (1, 1.2)]),
34+
GeometryCollection(
35+
[
36+
Point(2.3, -1),
37+
LineString([(0.5, 0), (0, -3)]),
38+
Polygon([(-1, -1), (-0.3, 5), (1, 1.2)]),
39+
]
40+
),
41+
]
5942
)
60-
assert psdf.count().count() == 3
61-
62-
def test_internal_st_function(self):
63-
# this is to make sure the spark session works with internal sedona udfs
64-
baseDf = self.spark.sql(
65-
"SELECT ST_GeomFromWKT('POLYGON ((50 50 1, 50 80 2, 80 80 3, 80 50 2, 50 50 1))') as geom"
66-
)
67-
actual = baseDf.selectExpr("ST_AsText(ST_Expand(geom, 10))").first()[0]
68-
expected = "POLYGON Z((40 40 -9, 40 90 -9, 90 90 13, 90 40 13, 40 40 -9))"
69-
assert expected == actual
70-
71-
def test_type(self):
72-
assert type(self.g1) is GeoSeries
73-
assert type(self.g2) is GeoSeries
74-
assert type(self.g3) is GeoSeries
75-
assert type(self.g4) is GeoSeries
7643

77-
def test_copy(self):
78-
gc = self.g3.copy()
79-
assert type(gc) is GeoSeries
80-
assert self.g3.name == gc.name
44+
def check_sgpd_equals_gpd(self, actual: sgpd.GeoSeries, expected: gpd.GeoSeries):
45+
assert isinstance(actual, sgpd.GeoSeries)
46+
assert isinstance(expected, gpd.GeoSeries)
47+
assert len(actual) == len(expected)
48+
sgpd_result = actual.to_geopandas()
49+
for a, e in zip(sgpd_result, expected):
50+
self.assert_geometry_almost_equal(a, e)
8151

8252
def test_area(self):
83-
area = self.g1.area
84-
assert area is not None
85-
assert type(area) is GeoSeries
86-
assert area.count() == 2
53+
result = self.geoseries.area.to_pandas()
54+
expected = pd.Series([0.0, 0.0, 5.23, 5.23])
55+
assert result.count() > 0
56+
assert_series_equal(result, expected)
8757

8858
def test_buffer(self):
89-
buffer = self.g1.buffer(0.2)
90-
assert buffer is not None
91-
assert type(buffer) is GeoSeries
92-
assert buffer.count() == 2
93-
94-
def test_buffer_then_area(self):
95-
area = self.g1.buffer(0.2).area
96-
assert area is not None
97-
assert type(area) is GeoSeries
98-
assert area.count() == 2
99-
100-
def test_buffer_then_geoparquet(self):
101-
temp_file_path = os.path.join(
102-
self.tempdir, next(tempfile._get_candidate_names()) + ".parquet"
103-
)
104-
self.g1.buffer(0.2).to_parquet(temp_file_path)
105-
assert os.path.exists(temp_file_path)
106-
107-
108-
# -----------------------------------------------------------------------------
109-
# # Utils
110-
# -----------------------------------------------------------------------------
111-
112-
113-
def check_geoseries_equal(s1, s2):
114-
assert isinstance(s1, GeoSeries)
115-
assert isinstance(s1.geometry, GeoSeries)
116-
assert isinstance(s2, GeoSeries)
117-
assert isinstance(s2.geometry, GeoSeries)
118-
if isinstance(s1, GeoSeries):
119-
s1 = s1.to_geopandas()
120-
if isinstance(s2, GeoSeries):
121-
s2 = s2.to_geopandas()
122-
assert_geoseries_equal(s1, s2)
59+
result = self.geoseries.buffer(1)
60+
expected = [
61+
"POLYGON ((3.300000000000000 -1.000000000000000, 3.280785280403230 -1.195090322016128, 3.223879532511287 -1.382683432365090, 3.131469612302545 -1.555570233019602, 3.007106781186547 -1.707106781186547, 2.855570233019602 -1.831469612302545, 2.682683432365089 -1.923879532511287, 2.495090322016128 -1.980785280403230, 2.300000000000000 -2.000000000000000, 2.104909677983872 -1.980785280403230, 1.917316567634910 -1.923879532511287, 1.744429766980398 -1.831469612302545, 1.592893218813452 -1.707106781186547, 1.468530387697454 -1.555570233019602, 1.376120467488713 -1.382683432365090, 1.319214719596769 -1.195090322016129, 1.300000000000000 -1.000000000000000, 1.319214719596769 -0.804909677983872, 1.376120467488713 -0.617316567634910, 1.468530387697454 -0.444429766980398, 1.592893218813452 -0.292893218813453, 1.744429766980398 -0.168530387697455, 1.917316567634910 -0.076120467488713, 2.104909677983871 -0.019214719596770, 2.300000000000000 0.000000000000000, 2.495090322016128 -0.019214719596770, 2.682683432365090 -0.076120467488713, 2.855570233019602 -0.168530387697455, 3.007106781186547 -0.292893218813452, 3.131469612302545 -0.444429766980398, 3.223879532511286 -0.617316567634910, 3.280785280403230 -0.804909677983871, 3.300000000000000 -1.000000000000000))",
62+
"POLYGON ((0.986393923832144 -3.164398987305357, 0.935367989801224 -3.353676015097457, 0.848396388482656 -3.529361471973156, 0.728821389740875 -3.684703864350261, 0.581238193719096 -3.813733471206735, 0.411318339874827 -3.911491757111723, 0.225591752899151 -3.974221925961374, 0.031195801372873 -3.999513292546280, -0.164398987305357 -3.986393923832144, -0.353676015097457 -3.935367989801224, -0.529361471973156 -3.848396388482656, -0.684703864350260 -3.728821389740875, -0.813733471206735 -3.581238193719097, -0.911491757111723 -3.411318339874827, -0.974221925961374 -3.225591752899151, -0.999513292546279 -3.031195801372874, -0.986393923832144 -2.835601012694643, -0.486393923832144 0.164398987305357, -0.435367989801224 0.353676015097458, -0.348396388482656 0.529361471973156, -0.228821389740875 0.684703864350260, -0.081238193719096 0.813733471206735, 0.088681660125173 0.911491757111723, 0.274408247100849 0.974221925961374, 0.468804198627127 0.999513292546279, 0.664398987305357 0.986393923832144, 0.853676015097457 0.935367989801224, 1.029361471973156 0.848396388482656, 1.184703864350260 0.728821389740875, 1.313733471206735 0.581238193719096, 1.411491757111723 0.411318339874827, 1.474221925961374 0.225591752899151, 1.499513292546280 0.031195801372873, 1.486393923832144 -0.164398987305357, 0.986393923832144 -3.164398987305357))",
63+
"POLYGON ((-0.260059926604056 -1.672672793996312, -0.403493516968407 -1.802608257932399, -0.569270104475049 -1.902480890158382, -0.751180291696993 -1.968549819451744, -0.942410374326119 -1.998340340272165, -1.135797558140999 -1.990736606370705, -1.324098251632999 -1.946023426395157, -1.500259385009482 -1.865875595977814, -1.657682592935656 -1.753295165887471, -1.790471365675451 -1.612498995956065, -1.893651911234561 -1.448760806607280, -1.963359455800552 -1.268213644171327, -1.996983004332570 -1.077620158927971, -1.993263139087243 -0.884119300439822, -1.293263139087243 5.115880699560178, -1.252729137381052 5.303820984767603, -1.176977926029782 5.480530662139786, -1.068809614934931 5.639477736894415, -0.932222597700009 5.774786800970082, -0.772265752785876 5.881456214877171, -0.594851813959648 5.955542991081357, -0.406538808715662 5.994308544787506, -0.214287643700274 5.996319924510972, -0.025204797887634 5.961502780493132, 0.153720365261017 5.891144113007211, 0.315873956515097 5.787844698964485, 0.455262040354176 5.655422955350244, 0.566732198133767 5.498773793134933, 0.646163984953356 5.323687679062990, 1.946163984953356 1.523687679062990, 1.993263731568509 1.315875621036525, 1.995265095723606 1.102802318781350, 1.952077207005038 0.894142203137658, 1.865660978573300 0.699369327572194, 1.739940073395944 0.527327206003688, -0.260059926604056 -1.672672793996312))",
64+
"POLYGON ((-0.844303230213814 -1.983056850984667, -0.942410374326119 -1.998340340272165, -1.135797558140999 -1.990736606370705, -1.324098251632999 -1.946023426395157, -1.500259385009482 -1.865875595977814, -1.657682592935656 -1.753295165887471, -1.790471365675451 -1.612498995956065, -1.893651911234561 -1.448760806607280, -1.963359455800552 -1.268213644171327, -1.996983004332570 -1.077620158927971, -1.993263139087243 -0.884119300439822, -1.293263139087243 5.115880699560178, -1.252729137381052 5.303820984767603, -1.176977926029782 5.480530662139786, -1.068809614934931 5.639477736894415, -0.932222597700009 5.774786800970082, -0.772265752785876 5.881456214877171, -0.594851813959648 5.955542991081357, -0.406538808715662 5.994308544787506, -0.214287643700274 5.996319924510972, -0.025204797887634 5.961502780493132, 0.153720365261017 5.891144113007211, 0.315873956515097 5.787844698964485, 0.455262040354176 5.655422955350244, 0.566732198133767 5.498773793134933, 0.646163984953356 5.323687679062990, 1.946163984953356 1.523687679062990, 1.993263731568509 1.315875621036525, 1.995265095723606 1.102802318781350, 1.952077207005038 0.894142203137658, 1.865660978573300 0.699369327572194, 1.739940073395944 0.527327206003688, 1.471895863976614 0.232478575642425, 1.474221925961374 0.225591752899151, 1.499513292546280 0.031195801372873, 1.486393923832144 -0.164398987305357, 1.426669391220515 -0.522746182975131, 1.468530387697454 -0.444429766980398, 1.592893218813452 -0.292893218813453, 1.744429766980398 -0.168530387697455, 1.917316567634910 -0.076120467488713, 2.104909677983871 -0.019214719596770, 2.300000000000000 0.000000000000000, 2.495090322016128 -0.019214719596770, 2.682683432365090 -0.076120467488713, 2.855570233019602 -0.168530387697455, 3.007106781186547 -0.292893218813452, 3.131469612302545 -0.444429766980398, 3.223879532511286 -0.617316567634910, 3.280785280403230 -0.804909677983871, 3.300000000000000 -1.000000000000000, 3.280785280403230 -1.195090322016128, 
3.223879532511287 -1.382683432365090, 3.131469612302545 -1.555570233019602, 3.007106781186547 -1.707106781186547, 2.855570233019602 -1.831469612302545, 2.682683432365089 -1.923879532511287, 2.495090322016128 -1.980785280403230, 2.300000000000000 -2.000000000000000, 2.104909677983872 -1.980785280403230, 1.917316567634910 -1.923879532511287, 1.744429766980398 -1.831469612302545, 1.592893218813452 -1.707106781186547, 1.468530387697454 -1.555570233019602, 1.376120467488713 -1.382683432365090, 1.319214719596769 -1.195090322016129, 1.317505079406277 -1.177732053860557, 0.986393923832144 -3.164398987305357, 0.935367989801224 -3.353676015097457, 0.848396388482656 -3.529361471973156, 0.728821389740875 -3.684703864350261, 0.581238193719096 -3.813733471206735, 0.411318339874827 -3.911491757111723, 0.225591752899151 -3.974221925961374, 0.031195801372873 -3.999513292546280, -0.164398987305357 -3.986393923832144, -0.353676015097457 -3.935367989801224, -0.529361471973156 -3.848396388482656, -0.684703864350260 -3.728821389740875, -0.813733471206735 -3.581238193719097, -0.911491757111723 -3.411318339874827, -0.974221925961374 -3.225591752899151, -0.999513292546279 -3.031195801372874, -0.986393923832144 -2.835601012694643, -0.844303230213814 -1.983056850984667))",
65+
]
66+
expected = gpd.GeoSeries([wkt.loads(wkt_str) for wkt_str in expected])
67+
assert result.count() > 0
68+
self.check_sgpd_equals_gpd(result, expected)

0 commit comments

Comments
 (0)