Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions python/sedona/geopandas/geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,12 @@ def __init__(
if isinstance(data, (GeoDataFrame, GeoSeries)):
assert dtype is None
assert not copy
self._anchor = data
self._col_label = index
super().__init__(data, index=index, dtype=dtype, copy=copy)
elif isinstance(data, (PandasOnSparkSeries, PandasOnSparkDataFrame)):
assert columns is None
assert dtype is None
assert not copy
if index is None:
internal = InternalFrame(spark_frame=data._internal.spark_frame)
object.__setattr__(self, "_internal_frame", internal)
super().__init__(data, index=index, dtype=dtype)
elif isinstance(data, SparkDataFrame):
assert columns is None
assert dtype is None
Expand All @@ -173,8 +170,13 @@ def __init__(
)
gdf = gpd.GeoDataFrame(df)
# convert each geometry column to wkb type
import shapely

for col in gdf.columns:
if isinstance(gdf[col], gpd.GeoSeries):
# It's possible we get a list, dict, pd.Series, gpd.GeoSeries, etc of shapely.Geometry objects.
if len(gdf[col]) > 0 and isinstance(
gdf[col].iloc[0], shapely.geometry.base.BaseGeometry
):
gdf[col] = gdf[col].apply(lambda geom: geom.wkb)
pdf = pd.DataFrame(gdf)
# initialize the parent class pyspark Dataframe with the pandas Series
Expand Down
77 changes: 68 additions & 9 deletions python/tests/geopandas/test_geodataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@
Point,
)

from sedona.geopandas import GeoDataFrame
from sedona.geopandas import GeoDataFrame, GeoSeries
from tests.test_base import TestBase
import pyspark.pandas as ps
import pandas as pd
import geopandas as gpd
import sedona.geopandas as sgpd
import pytest
from pandas.testing import assert_frame_equal


class TestDataframe(TestBase):
Expand All @@ -41,10 +46,52 @@ class TestDataframe(TestBase):
#
# def teardown_method(self):
# shutil.rmtree(self.tempdir)

def test_constructor(self):
df = GeoDataFrame([Point(x, x) for x in range(3)])
check_geodataframe(df)
@pytest.mark.parametrize(
"obj",
[
[Point(x, x) for x in range(3)],
{"geometry": [Point(x, x) for x in range(3)]},
pd.DataFrame([Point(x, x) for x in range(3)]),
gpd.GeoDataFrame([Point(x, x) for x in range(3)]),
pd.Series([Point(x, x) for x in range(3)]),
gpd.GeoSeries([Point(x, x) for x in range(3)]),
GeoSeries([Point(x, x) for x in range(3)]),
GeoDataFrame([Point(x, x) for x in range(3)]),
],
)
def test_constructor(self, obj):
sgpd_df = GeoDataFrame(obj)
check_geodataframe(sgpd_df)

def test_constructor_pandas_on_spark(self):
for obj in [
ps.DataFrame([Point(x, x) for x in range(3)]),
ps.Series([Point(x, x) for x in range(3)]),
]:
sgpd_df = GeoDataFrame(obj)
check_geodataframe(sgpd_df)

@pytest.mark.parametrize(
"obj",
[
[0, 1, 2],
["x", "y", "z"],
{"a": [0, 1, 2], 1: [4, 5, 6]},
{"a": ["x", "y", "z"], 1: ["a", "b", "c"]},
pd.Series([0, 1, 2]),
pd.Series(["x", "y", "z"]),
pd.DataFrame({"x": ["x", "y", "z"]}),
gpd.GeoDataFrame({"x": [0, 1, 2]}),
ps.DataFrame({"x": ["x", "y", "z"]}),
],
)
def test_non_geometry(self, obj):
pd_df = pd.DataFrame(obj)
# pd.DataFrame(obj) doesn't work correctly for pandas on spark DataFrame type, so we use to_pandas() method instead.
if isinstance(obj, ps.DataFrame):
pd_df = obj.to_pandas()
sgpd_df = sgpd.GeoDataFrame(obj)
assert_frame_equal(pd_df, sgpd_df.to_pandas())

def test_psdf(self):
# this is to make sure the spark session works with pandas on spark api
Expand Down Expand Up @@ -73,7 +120,10 @@ def test_type_single_geometry_column(self):

# Assert the geometry column has the correct type and is not nullable
geometry_field = schema["geometry1"]
assert geometry_field.dataType.typeName() == "geometrytype"
assert (
geometry_field.dataType.typeName() == "geometrytype"
or geometry_field.dataType.typeName() == "binary"
)
assert not geometry_field.nullable

# Assert non-geometry columns are present with correct types
Expand All @@ -97,16 +147,25 @@ def test_type_multiple_geometry_columns(self):
schema = df._internal.spark_frame.schema
# Assert both geometry columns have the correct type
geometry_field1 = schema["geometry1"]
assert geometry_field1.dataType.typeName() == "geometrytype"
assert (
geometry_field1.dataType.typeName() == "geometrytype"
or geometry_field1.dataType.typeName() == "binary"
)
assert not geometry_field1.nullable

geometry_field2 = schema["geometry2"]
assert geometry_field2.dataType.typeName() == "geometrytype"
assert (
geometry_field2.dataType.typeName() == "geometrytype"
or geometry_field2.dataType.typeName() == "binary"
)
assert not geometry_field2.nullable

# Check non-geometry column
attribute_field = schema["attribute"]
assert attribute_field.dataType.typeName() != "geometrytype"
assert (
attribute_field.dataType.typeName() != "geometrytype"
and attribute_field.dataType.typeName() != "binary"
)

def test_copy(self):
df = GeoDataFrame([Point(x, x) for x in range(3)], name="test_df")
Expand Down
Loading