We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0a59c9d commit c6cc85cCopy full SHA for c6cc85c
1 file changed
diffrax/_adjoint.py
@@ -429,6 +429,14 @@ def _solve(inputs):
429
)
430
431
432
+# Unwrap jaxtyping decorator during tests, so that these are global functions.
433
+# This is needed to ensure `optx.implicit_jvp` is happy.
434
+if _vf.__globals__["__name__"].startswith("jaxtyping"):
435
+ _vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
436
+if _solve.__globals__["__name__"].startswith("jaxtyping"):
437
+ _solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
438
+
439
440
def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
441
try:
442
iter_x = iter(x) # pyright: ignore
0 commit comments