diff --git a/README.md b/README.md index 2a2aaed..f32fcff 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,10 @@ print(vmul(jnp.arange(4.0))) # Array([ 1., 4., 7., 10.], dtype=float32) ### `bounded_while_loop` +If you know a loop will terminate before `max_steps` steps then `bounded_while_loop` is better than a normal while loop. +`bounded_while_loop` uses `jax.lax.scan` under the hood, enabling both forward-mode and backward-mode differentiation. +Speed and efficiency are maintained by evaluating the termination condition and, when satisfied, switching to a no-op function so that, under `jax.jit`, the remaining scan steps execute a cheap no-op instead of re-running `cond_fn`/`body_fn`, even though all steps are still part of the compiled program. + Simple loop over a scalar: ```python