1+ from __future__ import annotations
2+
13import _thread
24import array
35import ast
1517import weakref
1618import xml .etree .ElementTree as ET
1719from collections import ChainMap , OrderedDict , deque
20+ from functools import cache
21+ from importlib import import_module
1822from importlib .util import find_spec
1923from typing import Any , Optional
2024
3539HAS_NUMBA = find_spec ("numba" ) is not None
3640HAS_PYARROW = find_spec ("pyarrow" ) is not None
3741
38- if HAS_NUMPY :
39- import numpy as np
40- if HAS_SCIPY :
41- import scipy # type: ignore # noqa: PGH003
42- if HAS_JAX :
43- import jax # type: ignore # noqa: PGH003
44- import jax .numpy as jnp # type: ignore # noqa: PGH003
45- if HAS_XARRAY :
46- import xarray # type: ignore # noqa: PGH003
47- if HAS_TENSORFLOW :
48- import tensorflow as tf # type: ignore # noqa: PGH003
49- if HAS_SQLALCHEMY :
50- import sqlalchemy # type: ignore # noqa: PGH003
51- if HAS_PYARROW :
52- import pyarrow as pa # type: ignore # noqa: PGH003
53- if HAS_PANDAS :
54- import pandas # noqa: ICN001
55- if HAS_TORCH :
56- import torch # type: ignore # noqa: PGH003
57- if HAS_NUMBA :
58- import numba # type: ignore # noqa: PGH003
59- from numba .core .dispatcher import Dispatcher # type: ignore # noqa: PGH003
60- from numba .typed import Dict as NumbaDict # type: ignore # noqa: PGH003
61- from numba .typed import List as NumbaList # type: ignore # noqa: PGH003
62- if HAS_PYRSISTENT :
63- import pyrsistent # type: ignore # noqa: PGH003
64-
6542# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
6643# These paths vary between test runs but are logically equivalent
6744PYTEST_TEMP_PATH_PATTERN = re .compile (r"/tmp/pytest-of-[^/]+/pytest-\d+/" ) # noqa: S108
11794)
11895
11996
97+ @cache
98+ def _optional_module (module_name : str ) -> Optional [Any ]:
99+ if find_spec (module_name ) is None :
100+ return None
101+ return import_module (module_name )
102+
103+
104+ def _object_module_matches (obj : Any , module_prefix : str ) -> bool :
105+ module_name = type (obj ).__module__
106+ return module_name == module_prefix or module_name .startswith (f"{ module_prefix } ." )
107+
108+
120109def _normalize_temp_path (path : str ) -> str :
121110 """Normalize temporary file paths by replacing session-specific components.
122111
@@ -133,7 +122,7 @@ def _is_temp_path(s: str) -> bool:
133122 return PYTEST_TEMP_PATH_PATTERN .search (s ) is not None or PYTHON_TEMPFILE_PATTERN .search (s ) is not None
134123
135124
136- def _extract_exception_from_message (msg : str ) -> Optional [BaseException ]: # noqa: FA100
125+ def _extract_exception_from_message (msg : str ) -> Optional [BaseException ]:
137126 """Try to extract a wrapped exception type from an error message.
138127
139128 Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception.
@@ -153,7 +142,7 @@ def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noq
153142 return None
154143
155144
156- def _get_wrapped_exception (exc : BaseException ) -> Optional [BaseException ]: # noqa: FA100
145+ def _get_wrapped_exception (exc : BaseException ) -> Optional [BaseException ]:
157146 """Get the wrapped exception if this is a simple wrapper.
158147
159148 Returns the inner exception if:
@@ -272,22 +261,26 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
272261 return False
273262 return comparator (orig_referent , new_referent , superset_obj )
274263
275- if HAS_JAX :
264+ if HAS_JAX and (_object_module_matches (orig , "jax" ) or _object_module_matches (orig , "jaxlib" )):
265+ jax = _optional_module ("jax" )
266+ jnp = _optional_module ("jax.numpy" )
276267 # Handle JAX arrays first to avoid boolean context errors in other conditions
277- if isinstance (orig , jax .Array ):
268+ if jax is not None and jnp is not None and isinstance (orig , jax .Array ):
278269 if orig .dtype != new .dtype :
279270 return False
280271 if orig .shape != new .shape :
281272 return False
282273 return bool (jnp .allclose (orig , new , equal_nan = True ))
283274
284275 # Handle xarray objects before numpy to avoid boolean context errors
285- if HAS_XARRAY :
286- if isinstance (orig , (xarray .Dataset , xarray .DataArray )):
276+ if HAS_XARRAY and _object_module_matches (orig , "xarray" ):
277+ xarray = _optional_module ("xarray" )
278+ if xarray is not None and isinstance (orig , (xarray .Dataset , xarray .DataArray )):
287279 return orig .identical (new )
288280
289281 # Handle TensorFlow objects early to avoid boolean context errors
290- if HAS_TENSORFLOW :
282+ if HAS_TENSORFLOW and _object_module_matches (orig , "tensorflow" ):
283+ tf = _optional_module ("tensorflow" )
291284 if isinstance (orig , tf .Tensor ):
292285 if orig .dtype != new .dtype :
293286 return False
@@ -313,8 +306,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
313306 if not comparator (orig .dense_shape .numpy (), new .dense_shape .numpy (), superset_obj ):
314307 return False
315308 return comparator (orig .indices .numpy (), new .indices .numpy (), superset_obj ) and comparator (
316- orig .values .numpy (), # noqa: PD011
317- new .values .numpy (), # noqa: PD011
309+ orig .values .numpy (),
310+ new .values .numpy (),
318311 superset_obj ,
319312 )
320313
@@ -325,7 +318,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
325318 return False
326319 return comparator (orig .to_list (), new .to_list (), superset_obj )
327320
328- if HAS_SQLALCHEMY :
321+ if HAS_SQLALCHEMY and (hasattr (orig , "_sa_instance_state" ) or _object_module_matches (orig , "sqlalchemy" )):
322+ sqlalchemy = _optional_module ("sqlalchemy" )
329323 try :
330324 insp = sqlalchemy .inspection .inspect (orig )
331325 insp = sqlalchemy .inspection .inspect (new )
@@ -342,7 +336,10 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
342336 pass
343337
344338 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
345- if isinstance (orig , dict ) and not (HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix )):
339+ scipy_sparse = (
340+ _optional_module ("scipy.sparse" ) if HAS_SCIPY and _object_module_matches (orig , "scipy" ) else None
341+ )
342+ if isinstance (orig , dict ) and not (scipy_sparse is not None and isinstance (orig , scipy_sparse .spmatrix )):
346343 if superset_obj :
347344 return all (k in new and comparator (v , new [k ], superset_obj ) for k , v in orig .items ())
348345 if len (orig ) != len (new ):
@@ -366,7 +363,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
366363 if isinstance (orig , _DICT_ITEMS_TYPE ):
367364 return comparator (dict (orig ), dict (new ), superset_obj )
368365
369- if HAS_NUMPY :
366+ np = _optional_module ("numpy" ) if HAS_NUMPY and _object_module_matches (orig , "numpy" ) else None
367+ if np is not None :
370368 if isinstance (orig , (np .datetime64 , np .timedelta64 )):
371369 # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
372370 if np .isnat (orig ) and np .isnat (new ):
@@ -420,14 +418,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
420418 new_state = new .get_state (legacy = False )
421419 return comparator (orig_state , new_state , superset_obj )
422420
423- if HAS_SCIPY and isinstance (orig , scipy . sparse .spmatrix ):
421+ if scipy_sparse is not None and isinstance (orig , scipy_sparse .spmatrix ):
424422 if orig .dtype != new .dtype :
425423 return False
426424 if orig .get_shape () != new .get_shape ():
427425 return False
428426 return (orig != new ).nnz == 0
429427
430- if HAS_PYARROW :
428+ if HAS_PYARROW and _object_module_matches (orig , "pyarrow" ):
429+ pa = _optional_module ("pyarrow" )
431430 if isinstance (orig , pa .Table ):
432431 if orig .schema != new .schema :
433432 return False
@@ -469,7 +468,8 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
469468 if isinstance (orig , (pa .Schema , pa .Field , pa .DataType )):
470469 return bool (orig .equals (new ))
471470
472- if HAS_PANDAS :
471+ if HAS_PANDAS and _object_module_matches (orig , "pandas" ):
472+ pandas = _optional_module ("pandas" )
473473 if isinstance (
474474 orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
475475 ):
@@ -489,17 +489,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
489489
490490 # This should be at the end of all numpy checking
491491 try :
492- if HAS_NUMPY and np .isnan (orig ):
492+ if np is not None and np .isnan (orig ):
493493 return np .isnan (new )
494494 except Exception :
495495 pass
496496 try :
497- if HAS_NUMPY and np .isinf (orig ):
497+ if np is not None and np .isinf (orig ):
498498 return np .isinf (new )
499499 except Exception :
500500 pass
501501
502- if HAS_TORCH :
502+ if HAS_TORCH and _object_module_matches (orig , "torch" ):
503+ torch = _optional_module ("torch" )
503504 if isinstance (orig , torch .Tensor ):
504505 if orig .dtype != new .dtype :
505506 return False
@@ -517,15 +518,23 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
517518 if isinstance (orig , torch .device ):
518519 return orig == new
519520
520- if HAS_NUMBA :
521+ if HAS_NUMBA and _object_module_matches (orig , "numba" ):
522+ numba = _optional_module ("numba" )
523+ numba_dispatcher = _optional_module ("numba.core.dispatcher" )
524+ numba_typed = _optional_module ("numba.typed" )
525+ if numba is None or numba_dispatcher is None or numba_typed is None :
526+ return False
527+ dispatcher = numba_dispatcher .Dispatcher
528+ numba_dict = numba_typed .Dict
529+ numba_list = numba_typed .List
521530 # Handle numba typed List
522- if isinstance (orig , NumbaList ):
531+ if isinstance (orig , numba_list ):
523532 if len (orig ) != len (new ):
524533 return False
525534 return all (comparator (elem1 , elem2 , superset_obj ) for elem1 , elem2 in zip (orig , new ))
526535
527536 # Handle numba typed Dict
528- if isinstance (orig , NumbaDict ):
537+ if isinstance (orig , numba_dict ):
529538 if superset_obj :
530539 # Allow new dict to have more keys, but all orig keys must exist with equal values
531540 return all (key in new and comparator (orig [key ], new [key ], superset_obj ) for key in orig )
@@ -543,12 +552,15 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
543552 return orig == new
544553
545554 # Handle numba JIT-compiled functions (CPUDispatcher, etc.)
546- if isinstance (orig , Dispatcher ):
555+ if isinstance (orig , dispatcher ):
547556 # Compare by identity of the underlying Python function
548557 # Two JIT functions are equal if they wrap the same Python function
549558 return orig .py_func is new .py_func
550559
551- if HAS_PYRSISTENT :
560+ if HAS_PYRSISTENT and _object_module_matches (orig , "pyrsistent" ):
561+ pyrsistent = _optional_module ("pyrsistent" )
562+ if pyrsistent is None :
563+ return False
552564 if isinstance (
553565 orig ,
554566 (
0 commit comments