1616import numbers
1717from typing import Union , Sequence , Any , Dict , Callable , Optional
1818
19+ import jax
1920import jax .numpy as jnp
2021
2122import brainstate
@@ -205,10 +206,8 @@ def for_loop(
205206 operands : Any ,
206207 reverse : bool = False ,
207208 unroll : int = 1 ,
208- remat : bool = False ,
209209 jit : Optional [bool ] = None ,
210210 progress_bar : bool = False ,
211- unroll_kwargs : Optional [Dict ] = None ,
212211):
213212 """``for-loop`` control flow with :py:class:`~.Variable`.
214213
@@ -266,10 +265,6 @@ def for_loop(
266265 If body function `body_func` receives multiple arguments,
267266 `operands` should be a tuple/list whose length is equal to the
268267 number of arguments.
269- remat: bool
270- Make ``fun`` recompute internal linearization points when differentiated.
271- jit: bool
272- Whether to just-in-time compile the function.
273268 reverse: bool
274269 Optional boolean specifying whether to run the scan iteration
275270 forward (the default) or in reverse, equivalent to reversing the leading
@@ -278,6 +273,8 @@ def for_loop(
278273 Optional positive int specifying, in the underlying operation of the
279274 scan primitive, how many scan iterations to unroll within a single
280275 iteration of a loop.
276+ jit: bool
277+ Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation.
281278 progress_bar: bool
282279 Whether we use the progress bar to report the running progress.
283280
@@ -296,8 +293,6 @@ def for_loop(
296293 .. deprecated:: 2.4.0
297294 No longer need to provide ``child_objs``. This function is capable of automatically
298295 collecting the children objects used in the target ``func``.
299- unroll_kwargs: dict
300- The keyword arguments without unrolling.
301296
302297 Returns::
303298
@@ -306,11 +301,21 @@ def for_loop(
306301 """
307302 if not isinstance (operands , (tuple , list )):
308303 operands = (operands ,)
309- return brainstate .transform .for_loop (
310- warp_to_no_state_input_output (body_fun ),
311- * operands , reverse = reverse , unroll = unroll ,
312- pbar = brainstate .transform .ProgressBar () if progress_bar else None ,
313- )
304+
305+ # Handle jit parameter
306+ if jit is False :
307+ with jax .disable_jit ():
308+ return brainstate .transform .for_loop (
309+ warp_to_no_state_input_output (body_fun ),
310+ * operands , reverse = reverse , unroll = unroll ,
311+ pbar = brainstate .transform .ProgressBar () if progress_bar else None ,
312+ )
313+ else :
314+ return brainstate .transform .for_loop (
315+ warp_to_no_state_input_output (body_fun ),
316+ * operands , reverse = reverse , unroll = unroll ,
317+ pbar = brainstate .transform .ProgressBar () if progress_bar else None ,
318+ )
314319
315320
316321def scan (
0 commit comments