Skip to content

Commit c6cc85c

Browse files
Fixed spurious failure of test/test_adjoint.py::test_implicit
1 parent 0a59c9d commit c6cc85c

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

diffrax/_adjoint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,14 @@ def _solve(inputs):
429429
)
430430

431431

432+
# Unwrap jaxtyping decorator during tests, so that these are global functions.
433+
# This is needed to ensure `optx.implicit_jvp` is happy.
434+
if _vf.__globals__["__name__"].startswith("jaxtyping"):
435+
_vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
436+
if _solve.__globals__["__name__"].startswith("jaxtyping"):
437+
_solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
438+
439+
432440
def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
433441
try:
434442
iter_x = iter(x) # pyright: ignore

0 commit comments

Comments
 (0)