Skip to content

Commit 8aba927

Browse files
committed
Avoid eager optional comparator imports
1 parent d7f95a5 commit 8aba927

2 files changed

Lines changed: 84 additions & 50 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 62 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import _thread
24
import array
35
import ast
@@ -15,6 +17,8 @@
1517
import weakref
1618
import xml.etree.ElementTree as ET
1719
from collections import ChainMap, OrderedDict, deque
20+
from functools import cache
21+
from importlib import import_module
1822
from importlib.util import find_spec
1923
from typing import Any, Optional
2024

@@ -35,33 +39,6 @@
3539
HAS_NUMBA = find_spec("numba") is not None
3640
HAS_PYARROW = find_spec("pyarrow") is not None
3741

38-
if HAS_NUMPY:
39-
import numpy as np
40-
if HAS_SCIPY:
41-
import scipy # type: ignore # noqa: PGH003
42-
if HAS_JAX:
43-
import jax # type: ignore # noqa: PGH003
44-
import jax.numpy as jnp # type: ignore # noqa: PGH003
45-
if HAS_XARRAY:
46-
import xarray # type: ignore # noqa: PGH003
47-
if HAS_TENSORFLOW:
48-
import tensorflow as tf # type: ignore # noqa: PGH003
49-
if HAS_SQLALCHEMY:
50-
import sqlalchemy # type: ignore # noqa: PGH003
51-
if HAS_PYARROW:
52-
import pyarrow as pa # type: ignore # noqa: PGH003
53-
if HAS_PANDAS:
54-
import pandas # noqa: ICN001
55-
if HAS_TORCH:
56-
import torch # type: ignore # noqa: PGH003
57-
if HAS_NUMBA:
58-
import numba # type: ignore # noqa: PGH003
59-
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
60-
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
61-
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
62-
if HAS_PYRSISTENT:
63-
import pyrsistent # type: ignore # noqa: PGH003
64-
6542
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
6643
# These paths vary between test runs but are logically equivalent
6744
PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108
@@ -117,6 +94,18 @@
11794
)
11895

11996

97+
@cache
98+
def _optional_module(module_name: str) -> Optional[Any]:
99+
if find_spec(module_name) is None:
100+
return None
101+
return import_module(module_name)
102+
103+
104+
def _object_module_matches(obj: Any, module_prefix: str) -> bool:
105+
module_name = type(obj).__module__
106+
return module_name == module_prefix or module_name.startswith(f"{module_prefix}.")
107+
108+
120109
def _normalize_temp_path(path: str) -> str:
121110
"""Normalize temporary file paths by replacing session-specific components.
122111
@@ -133,7 +122,7 @@ def _is_temp_path(s: str) -> bool:
133122
return PYTEST_TEMP_PATH_PATTERN.search(s) is not None or PYTHON_TEMPFILE_PATTERN.search(s) is not None
134123

135124

136-
def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100
125+
def _extract_exception_from_message(msg: str) -> Optional[BaseException]:
137126
"""Try to extract a wrapped exception type from an error message.
138127
139128
Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception.
@@ -153,7 +142,7 @@ def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noq
153142
return None
154143

155144

156-
def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
145+
def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]:
157146
"""Get the wrapped exception if this is a simple wrapper.
158147
159148
Returns the inner exception if:
@@ -272,22 +261,26 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
272261
return False
273262
return comparator(orig_referent, new_referent, superset_obj)
274263

275-
if HAS_JAX:
264+
if HAS_JAX and (_object_module_matches(orig, "jax") or _object_module_matches(orig, "jaxlib")):
265+
jax = _optional_module("jax")
266+
jnp = _optional_module("jax.numpy")
276267
# Handle JAX arrays first to avoid boolean context errors in other conditions
277-
if isinstance(orig, jax.Array):
268+
if jax is not None and jnp is not None and isinstance(orig, jax.Array):
278269
if orig.dtype != new.dtype:
279270
return False
280271
if orig.shape != new.shape:
281272
return False
282273
return bool(jnp.allclose(orig, new, equal_nan=True))
283274

284275
# Handle xarray objects before numpy to avoid boolean context errors
285-
if HAS_XARRAY:
286-
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
276+
if HAS_XARRAY and _object_module_matches(orig, "xarray"):
277+
xarray = _optional_module("xarray")
278+
if xarray is not None and isinstance(orig, (xarray.Dataset, xarray.DataArray)):
287279
return orig.identical(new)
288280

289281
# Handle TensorFlow objects early to avoid boolean context errors
290-
if HAS_TENSORFLOW:
282+
if HAS_TENSORFLOW and _object_module_matches(orig, "tensorflow"):
283+
tf = _optional_module("tensorflow")
291284
if isinstance(orig, tf.Tensor):
292285
if orig.dtype != new.dtype:
293286
return False
@@ -313,8 +306,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
313306
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
314307
return False
315308
return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
316-
orig.values.numpy(), # noqa: PD011
317-
new.values.numpy(), # noqa: PD011
309+
orig.values.numpy(),
310+
new.values.numpy(),
318311
superset_obj,
319312
)
320313

@@ -325,7 +318,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
325318
return False
326319
return comparator(orig.to_list(), new.to_list(), superset_obj)
327320

328-
if HAS_SQLALCHEMY:
321+
if HAS_SQLALCHEMY and (hasattr(orig, "_sa_instance_state") or _object_module_matches(orig, "sqlalchemy")):
322+
sqlalchemy = _optional_module("sqlalchemy")
329323
try:
330324
insp = sqlalchemy.inspection.inspect(orig)
331325
insp = sqlalchemy.inspection.inspect(new)
@@ -342,7 +336,10 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
342336
pass
343337

344338
# scipy condition because dok_matrix type is also an instance of dict, but dict comparison doesn't work for it
345-
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
339+
scipy_sparse = (
340+
_optional_module("scipy.sparse") if HAS_SCIPY and _object_module_matches(orig, "scipy") else None
341+
)
342+
if isinstance(orig, dict) and not (scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix)):
346343
if superset_obj:
347344
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
348345
if len(orig) != len(new):
@@ -366,7 +363,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
366363
if isinstance(orig, _DICT_ITEMS_TYPE):
367364
return comparator(dict(orig), dict(new), superset_obj)
368365

369-
if HAS_NUMPY:
366+
np = _optional_module("numpy") if HAS_NUMPY and _object_module_matches(orig, "numpy") else None
367+
if np is not None:
370368
if isinstance(orig, (np.datetime64, np.timedelta64)):
371369
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
372370
if np.isnat(orig) and np.isnat(new):
@@ -420,14 +418,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
420418
new_state = new.get_state(legacy=False)
421419
return comparator(orig_state, new_state, superset_obj)
422420

423-
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
421+
if scipy_sparse is not None and isinstance(orig, scipy_sparse.spmatrix):
424422
if orig.dtype != new.dtype:
425423
return False
426424
if orig.get_shape() != new.get_shape():
427425
return False
428426
return (orig != new).nnz == 0
429427

430-
if HAS_PYARROW:
428+
if HAS_PYARROW and _object_module_matches(orig, "pyarrow"):
429+
pa = _optional_module("pyarrow")
431430
if isinstance(orig, pa.Table):
432431
if orig.schema != new.schema:
433432
return False
@@ -469,7 +468,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
469468
if isinstance(orig, (pa.Schema, pa.Field, pa.DataType)):
470469
return bool(orig.equals(new))
471470

472-
if HAS_PANDAS:
471+
if HAS_PANDAS and _object_module_matches(orig, "pandas"):
472+
pandas = _optional_module("pandas")
473473
if isinstance(
474474
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
475475
):
@@ -489,17 +489,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
489489

490490
# This should be at the end of all numpy checking
491491
try:
492-
if HAS_NUMPY and np.isnan(orig):
492+
if np is not None and np.isnan(orig):
493493
return np.isnan(new)
494494
except Exception:
495495
pass
496496
try:
497-
if HAS_NUMPY and np.isinf(orig):
497+
if np is not None and np.isinf(orig):
498498
return np.isinf(new)
499499
except Exception:
500500
pass
501501

502-
if HAS_TORCH:
502+
if HAS_TORCH and _object_module_matches(orig, "torch"):
503+
torch = _optional_module("torch")
503504
if isinstance(orig, torch.Tensor):
504505
if orig.dtype != new.dtype:
505506
return False
@@ -517,15 +518,23 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
517518
if isinstance(orig, torch.device):
518519
return orig == new
519520

520-
if HAS_NUMBA:
521+
if HAS_NUMBA and _object_module_matches(orig, "numba"):
522+
numba = _optional_module("numba")
523+
numba_dispatcher = _optional_module("numba.core.dispatcher")
524+
numba_typed = _optional_module("numba.typed")
525+
if numba is None or numba_dispatcher is None or numba_typed is None:
526+
return False
527+
dispatcher = numba_dispatcher.Dispatcher
528+
numba_dict = numba_typed.Dict
529+
numba_list = numba_typed.List
521530
# Handle numba typed List
522-
if isinstance(orig, NumbaList):
531+
if isinstance(orig, numba_list):
523532
if len(orig) != len(new):
524533
return False
525534
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
526535

527536
# Handle numba typed Dict
528-
if isinstance(orig, NumbaDict):
537+
if isinstance(orig, numba_dict):
529538
if superset_obj:
530539
# Allow new dict to have more keys, but all orig keys must exist with equal values
531540
return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
@@ -543,12 +552,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
543552
return orig == new
544553

545554
# Handle numba JIT-compiled functions (CPUDispatcher, etc.)
546-
if isinstance(orig, Dispatcher):
555+
if isinstance(orig, dispatcher):
547556
# Compare by identity of the underlying Python function
548557
# Two JIT functions are equal if they wrap the same Python function
549558
return orig.py_func is new.py_func
550559

551-
if HAS_PYRSISTENT:
560+
if HAS_PYRSISTENT and _object_module_matches(orig, "pyrsistent"):
561+
pyrsistent = _optional_module("pyrsistent")
562+
if pyrsistent is None:
563+
return False
552564
if isinstance(
553565
orig,
554566
(

tests/test_comparator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import datetime
66
import decimal
77
import re
8+
import subprocess
89
import sys
910
import uuid
1011
import weakref
@@ -29,6 +30,27 @@
2930
from codeflash.verification.equivalence import compare_test_results
3031

3132

33+
def test_comparator_import_does_not_load_optional_numeric_modules() -> None:
    """Check that importing the comparator leaves the heavy numeric packages unloaded.

    The import is performed in a separate interpreter so that modules already
    loaded by the test session cannot influence ``sys.modules``.
    """
    # The child prints the offending module names and exits non-zero if any were loaded.
    probe_script = (
        "import sys; "
        "import codeflash.verification.comparator; "
        "loaded = {'numpy', 'pandas', 'xarray', 'numexpr'} & set(sys.modules); "
        "print(','.join(sorted(loaded))); "
        "raise SystemExit(bool(loaded))"
    )
    command = [sys.executable, "-c", probe_script]
    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        encoding="utf-8",
        check=False,
    )
    assert completed.returncode == 0, completed.stdout + completed.stderr
52+
53+
3254
def test_basic_python_objects() -> None:
3355
a = 5
3456
b = 5

0 commit comments

Comments
 (0)