From 8ea773163a7eb8d8c4573106a5b2ad8b7716f964 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 3 Mar 2026 11:06:50 -0500 Subject: [PATCH 1/3] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 2a2aaed..b18841e 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,10 @@ print(vmul(jnp.arange(4.0))) # Array([ 1., 4., 7., 10.], dtype=float32) ### `bounded_while_loop` +If you know a loop will terminate before `n` steps then `bounded_while_loop` is better than a normal while loop. +`bounded_while_loop` uses `jax.lax.scan` under the hood, enabling both forward-mode and backward-mode differentiation. +Speed and efficiency are maintained by evaluating the termination condition and, when satisfied, switching to a no-op function that jax compiles away under jit for all remaining steps. + Simple loop over a scalar: ```python From 76a52463b2ea252f06f5c56202c5649992b21838 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 3 Mar 2026 14:55:59 -0500 Subject: [PATCH 2/3] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Nathaniel Starkman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b18841e..754e045 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ print(vmul(jnp.arange(4.0))) # Array([ 1., 4., 7., 10.], dtype=float32) ### `bounded_while_loop` -If you know a loop will terminate before `n` steps then `bounded_while_loop` is better than a normal while loop. +If you know a loop will terminate before `max_steps` steps then `bounded_while_loop` is better than a normal while loop. `bounded_while_loop` uses `jax.lax.scan` under the hood, enabling both forward-mode and backward-mode differentiation. Speed and efficiency are maintained by evaluating the termination condition and, when satisfied, switching to a no-op function that jax compiles away under jit for all remaining steps. From 794817bf191894e75d5ca48c7ee4b3fa24d5de10 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 3 Mar 2026 14:56:26 -0500 Subject: [PATCH 3/3] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Nathaniel Starkman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 754e045..f32fcff 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ print(vmul(jnp.arange(4.0))) # Array([ 1., 4., 7., 10.], dtype=float32) If you know a loop will terminate before `max_steps` steps then `bounded_while_loop` is better than a normal while loop. `bounded_while_loop` uses `jax.lax.scan` under the hood, enabling both forward-mode and backward-mode differentiation. -Speed and efficiency are maintained by evaluating the termination condition and, when satisfied, switching to a no-op function that jax compiles away under jit for all remaining steps. +Speed and efficiency are maintained by evaluating the termination condition and, when satisfied, switching to a no-op function so that, under `jax.jit`, the remaining scan steps execute a cheap no-op instead of re-running `cond_fn`/`body_fn`, even though all steps are still part of the compiled program. Simple loop over a scalar: