Skip to content

Commit ce67f09

Browse files
authored
Merge pull request #1411 from codeflash-ai/comparator-nn-module
Comparator fix for `weakref`
2 parents 97531dc + 88d6e8b commit ce67f09

2 files changed

Lines changed: 1121 additions & 1 deletion

File tree

codeflash/verification/comparator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import math
77
import re
88
import types
9+
import weakref
910
from collections import ChainMap, OrderedDict, deque
1011
from importlib.util import find_spec
1112
from typing import Any, Optional
@@ -93,7 +94,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
9394
return _extract_exception_from_message(str(exc))
9495

9596

96-
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
97+
def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
9798
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
9899
try:
99100
# Handle exceptions specially - before type check to allow wrapper comparison
@@ -171,6 +172,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
171172
return True
172173
return math.isclose(orig, new)
173174

175+
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
176+
if isinstance(orig, weakref.ref):
177+
orig_referent = orig()
178+
new_referent = new()
179+
# Both dead refs are equal, otherwise compare referents
180+
if orig_referent is None and new_referent is None:
181+
return True
182+
if orig_referent is None or new_referent is None:
183+
return False
184+
return comparator(orig_referent, new_referent, superset_obj)
185+
174186
if HAS_JAX:
175187
import jax # type: ignore # noqa: PGH003
176188
import jax.numpy as jnp # type: ignore # noqa: PGH003

0 commit comments

Comments
 (0)