Skip to content

Commit 70967cb

Browse files
authored
[GH-2007] Geopandas.Dataframe: Fix constructor for pandas-on-pyspark and Sedona Geopandas input types (#2008)
* Fix small constructor bug
* Fix condition for converting to WKB
* Fix constructor to not error on Sedona GeoPandas (sgpd) and pandas-on-Spark (pspd) inputs
* Add constructor tests for all input types, including non-geometry
* pre-commit reformat
* Change to BaseGeometry for Shapely compatibility
* pre-commit fmt
* Remove empty list and dict test cases, since different Spark versions handle them differently
1 parent b6c6421 commit 70967cb

2 files changed

Lines changed: 76 additions & 15 deletions

File tree

python/sedona/geopandas/geodataframe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,12 @@ def __init__(
141141
if isinstance(data, (GeoDataFrame, GeoSeries)):
142142
assert dtype is None
143143
assert not copy
144-
self._anchor = data
145-
self._col_label = index
144+
super().__init__(data, index=index, dtype=dtype, copy=copy)
146145
elif isinstance(data, (PandasOnSparkSeries, PandasOnSparkDataFrame)):
147146
assert columns is None
148147
assert dtype is None
149148
assert not copy
150-
if index is None:
151-
internal = InternalFrame(spark_frame=data._internal.spark_frame)
152-
object.__setattr__(self, "_internal_frame", internal)
149+
super().__init__(data, index=index, dtype=dtype)
153150
elif isinstance(data, SparkDataFrame):
154151
assert columns is None
155152
assert dtype is None
@@ -173,8 +170,13 @@ def __init__(
173170
)
174171
gdf = gpd.GeoDataFrame(df)
175172
# convert each geometry column to wkb type
173+
import shapely
174+
176175
for col in gdf.columns:
177-
if isinstance(gdf[col], gpd.GeoSeries):
176+
# The input may be a list, dict, pd.Series, gpd.GeoSeries, etc. of Shapely geometry objects, so check the first element's type rather than the column's type.
177+
if len(gdf[col]) > 0 and isinstance(
178+
gdf[col].iloc[0], shapely.geometry.base.BaseGeometry
179+
):
178180
gdf[col] = gdf[col].apply(lambda geom: geom.wkb)
179181
pdf = pd.DataFrame(gdf)
180182
# initialize the parent class pyspark Dataframe with the pandas Series

python/tests/geopandas/test_geodataframe.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@
2121
Point,
2222
)
2323

24-
from sedona.geopandas import GeoDataFrame
24+
from sedona.geopandas import GeoDataFrame, GeoSeries
2525
from tests.test_base import TestBase
2626
import pyspark.pandas as ps
27+
import pandas as pd
28+
import geopandas as gpd
29+
import sedona.geopandas as sgpd
30+
import pytest
31+
from pandas.testing import assert_frame_equal
2732

2833

2934
class TestDataframe(TestBase):
@@ -41,10 +46,52 @@ class TestDataframe(TestBase):
4146
#
4247
# def teardown_method(self):
4348
# shutil.rmtree(self.tempdir)
44-
45-
def test_constructor(self):
46-
df = GeoDataFrame([Point(x, x) for x in range(3)])
47-
check_geodataframe(df)
49+
@pytest.mark.parametrize(
50+
"obj",
51+
[
52+
[Point(x, x) for x in range(3)],
53+
{"geometry": [Point(x, x) for x in range(3)]},
54+
pd.DataFrame([Point(x, x) for x in range(3)]),
55+
gpd.GeoDataFrame([Point(x, x) for x in range(3)]),
56+
pd.Series([Point(x, x) for x in range(3)]),
57+
gpd.GeoSeries([Point(x, x) for x in range(3)]),
58+
GeoSeries([Point(x, x) for x in range(3)]),
59+
GeoDataFrame([Point(x, x) for x in range(3)]),
60+
],
61+
)
62+
def test_constructor(self, obj):
63+
sgpd_df = GeoDataFrame(obj)
64+
check_geodataframe(sgpd_df)
65+
66+
def test_constructor_pandas_on_spark(self):
67+
for obj in [
68+
ps.DataFrame([Point(x, x) for x in range(3)]),
69+
ps.Series([Point(x, x) for x in range(3)]),
70+
]:
71+
sgpd_df = GeoDataFrame(obj)
72+
check_geodataframe(sgpd_df)
73+
74+
@pytest.mark.parametrize(
75+
"obj",
76+
[
77+
[0, 1, 2],
78+
["x", "y", "z"],
79+
{"a": [0, 1, 2], 1: [4, 5, 6]},
80+
{"a": ["x", "y", "z"], 1: ["a", "b", "c"]},
81+
pd.Series([0, 1, 2]),
82+
pd.Series(["x", "y", "z"]),
83+
pd.DataFrame({"x": ["x", "y", "z"]}),
84+
gpd.GeoDataFrame({"x": [0, 1, 2]}),
85+
ps.DataFrame({"x": ["x", "y", "z"]}),
86+
],
87+
)
88+
def test_non_geometry(self, obj):
89+
pd_df = pd.DataFrame(obj)
90+
# pd.DataFrame(obj) doesn't work correctly for a pandas-on-Spark DataFrame, so we use its to_pandas() method instead.
91+
if isinstance(obj, ps.DataFrame):
92+
pd_df = obj.to_pandas()
93+
sgpd_df = sgpd.GeoDataFrame(obj)
94+
assert_frame_equal(pd_df, sgpd_df.to_pandas())
4895

4996
def test_psdf(self):
5097
# this is to make sure the spark session works with pandas on spark api
@@ -73,7 +120,10 @@ def test_type_single_geometry_column(self):
73120

74121
# Assert the geometry column has the correct type and is not nullable
75122
geometry_field = schema["geometry1"]
76-
assert geometry_field.dataType.typeName() == "geometrytype"
123+
assert (
124+
geometry_field.dataType.typeName() == "geometrytype"
125+
or geometry_field.dataType.typeName() == "binary"
126+
)
77127
assert not geometry_field.nullable
78128

79129
# Assert non-geometry columns are present with correct types
@@ -97,16 +147,25 @@ def test_type_multiple_geometry_columns(self):
97147
schema = df._internal.spark_frame.schema
98148
# Assert both geometry columns have the correct type
99149
geometry_field1 = schema["geometry1"]
100-
assert geometry_field1.dataType.typeName() == "geometrytype"
150+
assert (
151+
geometry_field1.dataType.typeName() == "geometrytype"
152+
or geometry_field1.dataType.typeName() == "binary"
153+
)
101154
assert not geometry_field1.nullable
102155

103156
geometry_field2 = schema["geometry2"]
104-
assert geometry_field2.dataType.typeName() == "geometrytype"
157+
assert (
158+
geometry_field2.dataType.typeName() == "geometrytype"
159+
or geometry_field2.dataType.typeName() == "binary"
160+
)
105161
assert not geometry_field2.nullable
106162

107163
# Check non-geometry column
108164
attribute_field = schema["attribute"]
109-
assert attribute_field.dataType.typeName() != "geometrytype"
165+
assert (
166+
attribute_field.dataType.typeName() != "geometrytype"
167+
and attribute_field.dataType.typeName() != "binary"
168+
)
110169

111170
def test_copy(self):
112171
df = GeoDataFrame([Point(x, x) for x in range(3)], name="test_df")

0 commit comments

Comments
 (0)