Skip to content

Commit 3dd19c6

Browse files
authored
Merge pull request #1387 from codeflash-ai/pyarrow-comparator
feat: add PyArrow support to comparator
2 parents ce67f09 + dd0e83d commit 3dd19c6

4 files changed

Lines changed: 311 additions & 6 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
HAS_XARRAY = find_spec("xarray") is not None
2727
HAS_TENSORFLOW = find_spec("tensorflow") is not None
2828
HAS_NUMBA = find_spec("numba") is not None
29+
HAS_PYARROW = find_spec("pyarrow") is not None
2930

3031
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
3132
# These paths vary between test runs but are logically equivalent
@@ -354,13 +355,57 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
354355
return False
355356
return (orig != new).nnz == 0
356357

358+
if HAS_PYARROW:
359+
import pyarrow as pa # type: ignore # noqa: PGH003
360+
361+
if isinstance(orig, pa.Table):
362+
if orig.schema != new.schema:
363+
return False
364+
if orig.num_rows != new.num_rows:
365+
return False
366+
return bool(orig.equals(new))
367+
368+
if isinstance(orig, pa.RecordBatch):
369+
if orig.schema != new.schema:
370+
return False
371+
if orig.num_rows != new.num_rows:
372+
return False
373+
return bool(orig.equals(new))
374+
375+
if isinstance(orig, pa.ChunkedArray):
376+
if orig.type != new.type:
377+
return False
378+
if len(orig) != len(new):
379+
return False
380+
return bool(orig.equals(new))
381+
382+
if isinstance(orig, pa.Array):
383+
if orig.type != new.type:
384+
return False
385+
if len(orig) != len(new):
386+
return False
387+
return bool(orig.equals(new))
388+
389+
if isinstance(orig, pa.Scalar):
390+
if orig.type != new.type:
391+
return False
392+
# Handle null scalars
393+
if not orig.is_valid and not new.is_valid:
394+
return True
395+
if not orig.is_valid or not new.is_valid:
396+
return False
397+
return bool(orig.equals(new))
398+
399+
if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
400+
return bool(orig.equals(new))
401+
357402
if HAS_PANDAS:
358403
import pandas # noqa: ICN001
359404

360405
if isinstance(
361406
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
362407
):
363-
return orig.equals(new)
408+
return bool(orig.equals(new))
364409

365410
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
366411
return orig == new
@@ -407,10 +452,10 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
407452
return orig == new
408453

409454
if HAS_NUMBA:
410-
import numba # type: ignore # noqa: PGH003
411-
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
412-
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
413-
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
455+
import numba
456+
from numba.core.dispatcher import Dispatcher
457+
from numba.typed import Dict as NumbaDict
458+
from numba.typed import List as NumbaList
414459

415460
# Handle numba typed List
416461
if isinstance(orig, NumbaList):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ tests = [
8787
"jax>=0.4.30",
8888
"numpy>=2.0.2",
8989
"pandas>=2.3.3",
90+
"pyarrow>=15.0.0",
9091
"pyrsistent>=0.20.0",
9192
"scipy>=1.13.1",
9293
"torch>=2.8.0",

tests/test_comparator.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,137 @@ def test_pandas():
802802
assert comparator(filtered1, filtered2)
803803

804804

805+
def test_pyarrow():
806+
try:
807+
import pyarrow as pa
808+
except ImportError:
809+
pytest.skip()
810+
811+
# Test PyArrow Table
812+
table1 = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
813+
table2 = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
814+
table3 = pa.table({"a": [1, 2, 3], "b": [4, 5, 7]})
815+
table4 = pa.table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]})
816+
table5 = pa.table({"a": [1, 2, 3], "c": [4, 5, 6]}) # different column name
817+
818+
assert comparator(table1, table2)
819+
assert not comparator(table1, table3)
820+
assert not comparator(table1, table4)
821+
assert not comparator(table1, table5)
822+
823+
# Test PyArrow RecordBatch
824+
batch1 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
825+
batch2 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
826+
batch3 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 5.0]})
827+
batch4 = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [3.0, 4.0, 5.0]})
828+
829+
assert comparator(batch1, batch2)
830+
assert not comparator(batch1, batch3)
831+
assert not comparator(batch1, batch4)
832+
833+
# Test PyArrow Array
834+
arr1 = pa.array([1, 2, 3])
835+
arr2 = pa.array([1, 2, 3])
836+
arr3 = pa.array([1, 2, 4])
837+
arr4 = pa.array([1, 2, 3, 4])
838+
arr5 = pa.array([1.0, 2.0, 3.0]) # different type
839+
840+
assert comparator(arr1, arr2)
841+
assert not comparator(arr1, arr3)
842+
assert not comparator(arr1, arr4)
843+
assert not comparator(arr1, arr5)
844+
845+
# Test PyArrow Array with nulls
846+
arr_null1 = pa.array([1, None, 3])
847+
arr_null2 = pa.array([1, None, 3])
848+
arr_null3 = pa.array([1, 2, 3])
849+
850+
assert comparator(arr_null1, arr_null2)
851+
assert not comparator(arr_null1, arr_null3)
852+
853+
# Test PyArrow ChunkedArray
854+
chunked1 = pa.chunked_array([[1, 2], [3, 4]])
855+
chunked2 = pa.chunked_array([[1, 2], [3, 4]])
856+
chunked3 = pa.chunked_array([[1, 2], [3, 5]])
857+
chunked4 = pa.chunked_array([[1, 2, 3], [4, 5]])
858+
859+
assert comparator(chunked1, chunked2)
860+
assert not comparator(chunked1, chunked3)
861+
assert not comparator(chunked1, chunked4)
862+
863+
# Test PyArrow Scalar
864+
scalar1 = pa.scalar(42)
865+
scalar2 = pa.scalar(42)
866+
scalar3 = pa.scalar(43)
867+
scalar4 = pa.scalar(42.0) # different type
868+
869+
assert comparator(scalar1, scalar2)
870+
assert not comparator(scalar1, scalar3)
871+
assert not comparator(scalar1, scalar4)
872+
873+
# Test null scalars
874+
null_scalar1 = pa.scalar(None, type=pa.int64())
875+
null_scalar2 = pa.scalar(None, type=pa.int64())
876+
null_scalar3 = pa.scalar(None, type=pa.float64())
877+
878+
assert comparator(null_scalar1, null_scalar2)
879+
assert not comparator(null_scalar1, null_scalar3)
880+
881+
# Test PyArrow Schema
882+
schema1 = pa.schema([("a", pa.int64()), ("b", pa.float64())])
883+
schema2 = pa.schema([("a", pa.int64()), ("b", pa.float64())])
884+
schema3 = pa.schema([("a", pa.int64()), ("c", pa.float64())])
885+
schema4 = pa.schema([("a", pa.int32()), ("b", pa.float64())])
886+
887+
assert comparator(schema1, schema2)
888+
assert not comparator(schema1, schema3)
889+
assert not comparator(schema1, schema4)
890+
891+
# Test PyArrow Field
892+
field1 = pa.field("name", pa.int64())
893+
field2 = pa.field("name", pa.int64())
894+
field3 = pa.field("other", pa.int64())
895+
field4 = pa.field("name", pa.float64())
896+
897+
assert comparator(field1, field2)
898+
assert not comparator(field1, field3)
899+
assert not comparator(field1, field4)
900+
901+
# Test PyArrow DataType
902+
type1 = pa.int64()
903+
type2 = pa.int64()
904+
type3 = pa.int32()
905+
type4 = pa.float64()
906+
907+
assert comparator(type1, type2)
908+
assert not comparator(type1, type3)
909+
assert not comparator(type1, type4)
910+
911+
# Test string arrays
912+
str_arr1 = pa.array(["hello", "world"])
913+
str_arr2 = pa.array(["hello", "world"])
914+
str_arr3 = pa.array(["hello", "there"])
915+
916+
assert comparator(str_arr1, str_arr2)
917+
assert not comparator(str_arr1, str_arr3)
918+
919+
# Test nested types (struct)
920+
struct_arr1 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
921+
struct_arr2 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
922+
struct_arr3 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 5}])
923+
924+
assert comparator(struct_arr1, struct_arr2)
925+
assert not comparator(struct_arr1, struct_arr3)
926+
927+
# Test list arrays
928+
list_arr1 = pa.array([[1, 2], [3, 4, 5]])
929+
list_arr2 = pa.array([[1, 2], [3, 4, 5]])
930+
list_arr3 = pa.array([[1, 2], [3, 4, 6]])
931+
932+
assert comparator(list_arr1, list_arr2)
933+
assert not comparator(list_arr1, list_arr3)
934+
935+
805936
def test_pyrsistent():
806937
try:
807938
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore
@@ -2795,7 +2926,6 @@ def test_torch_runtime_error_wrapping():
27952926
class TorchRuntimeError(Exception):
27962927
"""Mock TorchRuntimeError for testing."""
27972928

2798-
27992929
# Monkey-patch the __module__ to match torch._dynamo.exc
28002930
TorchRuntimeError.__module__ = "torch._dynamo.exc"
28012931

0 commit comments

Comments
 (0)