2121 Point ,
2222)
2323
24- from sedona .geopandas import GeoDataFrame
24+ from sedona .geopandas import GeoDataFrame , GeoSeries
2525from tests .test_base import TestBase
2626import pyspark .pandas as ps
27+ import pandas as pd
28+ import geopandas as gpd
29+ import sedona .geopandas as sgpd
30+ import pytest
31+ from pandas .testing import assert_frame_equal
2732
2833
2934class TestDataframe (TestBase ):
@@ -41,10 +46,52 @@ class TestDataframe(TestBase):
4146 #
4247 # def teardown_method(self):
4348 # shutil.rmtree(self.tempdir)
44-
45- def test_constructor (self ):
46- df = GeoDataFrame ([Point (x , x ) for x in range (3 )])
47- check_geodataframe (df )
49+ @pytest .mark .parametrize (
50+ "obj" ,
51+ [
52+ [Point (x , x ) for x in range (3 )],
53+ {"geometry" : [Point (x , x ) for x in range (3 )]},
54+ pd .DataFrame ([Point (x , x ) for x in range (3 )]),
55+ gpd .GeoDataFrame ([Point (x , x ) for x in range (3 )]),
56+ pd .Series ([Point (x , x ) for x in range (3 )]),
57+ gpd .GeoSeries ([Point (x , x ) for x in range (3 )]),
58+ GeoSeries ([Point (x , x ) for x in range (3 )]),
59+ GeoDataFrame ([Point (x , x ) for x in range (3 )]),
60+ ],
61+ )
62+ def test_constructor (self , obj ):
63+ sgpd_df = GeoDataFrame (obj )
64+ check_geodataframe (sgpd_df )
65+
66+ def test_constructor_pandas_on_spark (self ):
67+ for obj in [
68+ ps .DataFrame ([Point (x , x ) for x in range (3 )]),
69+ ps .Series ([Point (x , x ) for x in range (3 )]),
70+ ]:
71+ sgpd_df = GeoDataFrame (obj )
72+ check_geodataframe (sgpd_df )
73+
74+ @pytest .mark .parametrize (
75+ "obj" ,
76+ [
77+ [0 , 1 , 2 ],
78+ ["x" , "y" , "z" ],
79+ {"a" : [0 , 1 , 2 ], 1 : [4 , 5 , 6 ]},
80+ {"a" : ["x" , "y" , "z" ], 1 : ["a" , "b" , "c" ]},
81+ pd .Series ([0 , 1 , 2 ]),
82+ pd .Series (["x" , "y" , "z" ]),
83+ pd .DataFrame ({"x" : ["x" , "y" , "z" ]}),
84+ gpd .GeoDataFrame ({"x" : [0 , 1 , 2 ]}),
85+ ps .DataFrame ({"x" : ["x" , "y" , "z" ]}),
86+ ],
87+ )
88+ def test_non_geometry (self , obj ):
89+ pd_df = pd .DataFrame (obj )
90+ # pd.DataFrame(obj) doesn't work correctly for pandas on spark DataFrame type, so we use to_pandas() method instead.
91+ if isinstance (obj , ps .DataFrame ):
92+ pd_df = obj .to_pandas ()
93+ sgpd_df = sgpd .GeoDataFrame (obj )
94+ assert_frame_equal (pd_df , sgpd_df .to_pandas ())
4895
4996 def test_psdf (self ):
5097 # this is to make sure the spark session works with pandas on spark api
@@ -73,7 +120,10 @@ def test_type_single_geometry_column(self):
73120
74121 # Assert the geometry column has the correct type and is not nullable
75122 geometry_field = schema ["geometry1" ]
76- assert geometry_field .dataType .typeName () == "geometrytype"
123+ assert (
124+ geometry_field .dataType .typeName () == "geometrytype"
125+ or geometry_field .dataType .typeName () == "binary"
126+ )
77127 assert not geometry_field .nullable
78128
79129 # Assert non-geometry columns are present with correct types
@@ -97,16 +147,25 @@ def test_type_multiple_geometry_columns(self):
97147 schema = df ._internal .spark_frame .schema
98148 # Assert both geometry columns have the correct type
99149 geometry_field1 = schema ["geometry1" ]
100- assert geometry_field1 .dataType .typeName () == "geometrytype"
150+ assert (
151+ geometry_field1 .dataType .typeName () == "geometrytype"
152+ or geometry_field1 .dataType .typeName () == "binary"
153+ )
101154 assert not geometry_field1 .nullable
102155
103156 geometry_field2 = schema ["geometry2" ]
104- assert geometry_field2 .dataType .typeName () == "geometrytype"
157+ assert (
158+ geometry_field2 .dataType .typeName () == "geometrytype"
159+ or geometry_field2 .dataType .typeName () == "binary"
160+ )
105161 assert not geometry_field2 .nullable
106162
107163 # Check non-geometry column
108164 attribute_field = schema ["attribute" ]
109- assert attribute_field .dataType .typeName () != "geometrytype"
165+ assert (
166+ attribute_field .dataType .typeName () != "geometrytype"
167+ and attribute_field .dataType .typeName () != "binary"
168+ )
110169
111170 def test_copy (self ):
112171 df = GeoDataFrame ([Point (x , x ) for x in range (3 )], name = "test_df" )
0 commit comments