Skip to content

Commit f1e4735

Browse files
author
Codeflash Bot
committed
comparator fix
1 parent 7bf6681 commit f1e4735

1 file changed

Lines changed: 4 additions & 12 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
HAS_JAX = find_spec("jax") is not None
2525
HAS_XARRAY = find_spec("xarray") is not None
2626
HAS_TENSORFLOW = find_spec("tensorflow") is not None
27-
HAS_MLX = find_spec("mlx") is not None
2827

2928

3029
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -139,17 +138,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
139138
return False
140139
return comparator(orig.to_list(), new.to_list(), superset_obj)
141140

142-
if HAS_MLX:
143-
import mlx.core as mx # type: ignore # noqa: PGH003
144-
145-
if isinstance(orig, mx.array):
146-
if orig.dtype != new.dtype:
147-
return False
148-
if orig.shape != new.shape:
149-
return False
150-
# MLX allclose handles NaN comparison via equal_nan parameter
151-
return bool(mx.allclose(orig, new, equal_nan=True))
152-
153141
if HAS_SQLALCHEMY:
154142
import sqlalchemy # type: ignore # noqa: PGH003
155143

@@ -235,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
235223
return False
236224
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
237225

226+
# Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.)
227+
if isinstance(orig, np.dtype):
228+
return orig == new
229+
238230
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
239231
if orig.dtype != new.dtype:
240232
return False

0 commit comments

Comments
 (0)