We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 2ffd64d + a9c5c43 commit 8f18b55Copy full SHA for 8f18b55
1 file changed
diffrax/misc/ad.py
@@ -29,7 +29,7 @@ def is_perturbed(x: Any) -> bool:
29
30
31
def nondifferentiable_input(x: PyTree, name: str) -> None:
32
- if any(is_perturbed(xi) for xi in jax.tree_leaves(x)):
+ if any(is_perturbed(xi) for xi in jtu.tree_leaves(x)):
33
raise ValueError(f"Cannot differentiate {name}.")
34
35
0 commit comments