@@ -802,6 +802,137 @@ def test_pandas():
802802 assert comparator (filtered1 , filtered2 )
803803
804804
def test_pyarrow():
    """Exercise ``comparator`` on a range of PyArrow objects.

    Every section builds a reference object, an independently constructed
    equal twin, and one or more variants that differ in a value, a length,
    a column/field name, or a data type. The twin must compare as equal and
    each variant as unequal. The test is skipped when ``pyarrow`` cannot be
    imported.
    """
    try:
        import pyarrow as pa
    except ImportError:
        pytest.skip()

    # Tables: same data, changed cell, extra row, renamed column.
    table_ref = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
    table_twin = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
    table_changed_cell = pa.table({"a": [1, 2, 3], "b": [4, 5, 7]})
    table_extra_row = pa.table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]})
    table_renamed_col = pa.table({"a": [1, 2, 3], "c": [4, 5, 6]})

    assert comparator(table_ref, table_twin)
    assert not comparator(table_ref, table_changed_cell)
    assert not comparator(table_ref, table_extra_row)
    assert not comparator(table_ref, table_renamed_col)

    # Record batches: same data, changed cell, extra row.
    batch_ref = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
    batch_twin = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
    batch_changed_cell = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 5.0]})
    batch_extra_row = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [3.0, 4.0, 5.0]})

    assert comparator(batch_ref, batch_twin)
    assert not comparator(batch_ref, batch_changed_cell)
    assert not comparator(batch_ref, batch_extra_row)

    # Plain arrays: same data, changed element, extra element, float instead of int.
    ints_ref = pa.array([1, 2, 3])
    ints_twin = pa.array([1, 2, 3])
    ints_changed = pa.array([1, 2, 4])
    ints_longer = pa.array([1, 2, 3, 4])
    floats_same_values = pa.array([1.0, 2.0, 3.0])

    assert comparator(ints_ref, ints_twin)
    assert not comparator(ints_ref, ints_changed)
    assert not comparator(ints_ref, ints_longer)
    assert not comparator(ints_ref, floats_same_values)

    # Arrays containing a null: a null slot must not match a concrete value.
    nullable_ref = pa.array([1, None, 3])
    nullable_twin = pa.array([1, None, 3])
    no_nulls = pa.array([1, 2, 3])

    assert comparator(nullable_ref, nullable_twin)
    assert not comparator(nullable_ref, no_nulls)

    # Chunked arrays: same chunks, changed element, different chunks and length.
    chunks_ref = pa.chunked_array([[1, 2], [3, 4]])
    chunks_twin = pa.chunked_array([[1, 2], [3, 4]])
    chunks_changed = pa.chunked_array([[1, 2], [3, 5]])
    chunks_longer = pa.chunked_array([[1, 2, 3], [4, 5]])

    assert comparator(chunks_ref, chunks_twin)
    assert not comparator(chunks_ref, chunks_changed)
    assert not comparator(chunks_ref, chunks_longer)

    # Scalars: same value, different value, float instead of int.
    int_scalar_ref = pa.scalar(42)
    int_scalar_twin = pa.scalar(42)
    int_scalar_other = pa.scalar(43)
    float_scalar = pa.scalar(42.0)

    assert comparator(int_scalar_ref, int_scalar_twin)
    assert not comparator(int_scalar_ref, int_scalar_other)
    assert not comparator(int_scalar_ref, float_scalar)

    # Null scalars: equal only when their declared types agree.
    null_int_ref = pa.scalar(None, type=pa.int64())
    null_int_twin = pa.scalar(None, type=pa.int64())
    null_float = pa.scalar(None, type=pa.float64())

    assert comparator(null_int_ref, null_int_twin)
    assert not comparator(null_int_ref, null_float)

    # Schemas: same fields, renamed field, field with a different type.
    schema_ref = pa.schema([("a", pa.int64()), ("b", pa.float64())])
    schema_twin = pa.schema([("a", pa.int64()), ("b", pa.float64())])
    schema_renamed = pa.schema([("a", pa.int64()), ("c", pa.float64())])
    schema_retyped = pa.schema([("a", pa.int32()), ("b", pa.float64())])

    assert comparator(schema_ref, schema_twin)
    assert not comparator(schema_ref, schema_renamed)
    assert not comparator(schema_ref, schema_retyped)

    # Fields: same name and type, different name, different type.
    field_ref = pa.field("name", pa.int64())
    field_twin = pa.field("name", pa.int64())
    field_renamed = pa.field("other", pa.int64())
    field_retyped = pa.field("name", pa.float64())

    assert comparator(field_ref, field_twin)
    assert not comparator(field_ref, field_renamed)
    assert not comparator(field_ref, field_retyped)

    # Data types: int64 against itself, against int32, and against float64.
    dtype_ref = pa.int64()
    dtype_twin = pa.int64()
    dtype_int32 = pa.int32()
    dtype_float64 = pa.float64()

    assert comparator(dtype_ref, dtype_twin)
    assert not comparator(dtype_ref, dtype_int32)
    assert not comparator(dtype_ref, dtype_float64)

    # String arrays: same strings, one string changed.
    words_ref = pa.array(["hello", "world"])
    words_twin = pa.array(["hello", "world"])
    words_changed = pa.array(["hello", "there"])

    assert comparator(words_ref, words_twin)
    assert not comparator(words_ref, words_changed)

    # Struct arrays (nested type): same records, one nested value changed.
    structs_ref = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
    structs_twin = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
    structs_changed = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 5}])

    assert comparator(structs_ref, structs_twin)
    assert not comparator(structs_ref, structs_changed)

    # List arrays: same sublists, one inner element changed.
    lists_ref = pa.array([[1, 2], [3, 4, 5]])
    lists_twin = pa.array([[1, 2], [3, 4, 5]])
    lists_changed = pa.array([[1, 2], [3, 4, 6]])

    assert comparator(lists_ref, lists_twin)
    assert not comparator(lists_ref, lists_changed)
934+
935+
805936def test_pyrsistent ():
806937 try :
807938 from pyrsistent import PBag , PClass , PRecord , field , pdeque , pmap , pset , pvector # type: ignore
@@ -2795,7 +2926,6 @@ def test_torch_runtime_error_wrapping():
27952926 class TorchRuntimeError (Exception ):
27962927 """Mock TorchRuntimeError for testing."""
27972928
2798-
27992929 # Monkey-patch the __module__ to match torch._dynamo.exc
28002930 TorchRuntimeError .__module__ = "torch._dynamo.exc"
28012931
0 commit comments