|
6 | 6 | import math |
7 | 7 | import re |
8 | 8 | import types |
| 9 | +import weakref |
9 | 10 | from collections import ChainMap, OrderedDict, deque |
10 | 11 | from importlib.util import find_spec |
11 | 12 | from typing import Any, Optional |
@@ -93,7 +94,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no |
93 | 94 | return _extract_exception_from_message(str(exc)) |
94 | 95 |
|
95 | 96 |
|
96 | | -def comparator(orig: Any, new: Any, superset_obj=False) -> bool: |
| 97 | +def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: |
97 | 98 | """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" |
98 | 99 | try: |
99 | 100 | # Handle exceptions specially - before type check to allow wrapper comparison |
@@ -171,6 +172,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: |
171 | 172 | return True |
172 | 173 | return math.isclose(orig, new) |
173 | 174 |
|
| 175 | + # Handle weak references (e.g., found in torch.nn.LSTM/GRU modules) |
| 176 | + if isinstance(orig, weakref.ref): |
| 177 | + orig_referent = orig() |
| 178 | + new_referent = new() |
| 179 | + # Both dead refs are equal, otherwise compare referents |
| 180 | + if orig_referent is None and new_referent is None: |
| 181 | + return True |
| 182 | + if orig_referent is None or new_referent is None: |
| 183 | + return False |
| 184 | + return comparator(orig_referent, new_referent, superset_obj) |
| 185 | + |
174 | 186 | if HAS_JAX: |
175 | 187 | import jax # type: ignore # noqa: PGH003 |
176 | 188 | import jax.numpy as jnp # type: ignore # noqa: PGH003 |
|
0 commit comments