|
24 | 24 | HAS_JAX = find_spec("jax") is not None |
25 | 25 | HAS_XARRAY = find_spec("xarray") is not None |
26 | 26 | HAS_TENSORFLOW = find_spec("tensorflow") is not None |
27 | | -HAS_MLX = find_spec("mlx") is not None |
28 | 27 |
|
29 | 28 |
|
30 | 29 | def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 |
@@ -139,17 +138,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
139 | 138 | return False |
140 | 139 | return comparator(orig.to_list(), new.to_list(), superset_obj) |
141 | 140 |
|
142 | | - if HAS_MLX: |
143 | | - import mlx.core as mx # type: ignore # noqa: PGH003 |
144 | | - |
145 | | - if isinstance(orig, mx.array): |
146 | | - if orig.dtype != new.dtype: |
147 | | - return False |
148 | | - if orig.shape != new.shape: |
149 | | - return False |
150 | | - # MLX allclose handles NaN comparison via equal_nan parameter |
151 | | - return bool(mx.allclose(orig, new, equal_nan=True)) |
152 | | - |
153 | 141 | if HAS_SQLALCHEMY: |
154 | 142 | import sqlalchemy # type: ignore # noqa: PGH003 |
155 | 143 |
|
@@ -235,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
235 | 223 | return False |
236 | 224 | return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) |
237 | 225 |
|
| 226 | + # Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.) |
| 227 | + if isinstance(orig, np.dtype): |
| 228 | + return orig == new |
| 229 | + |
238 | 230 | if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): |
239 | 231 | if orig.dtype != new.dtype: |
240 | 232 | return False |
|
0 commit comments