Skip to content

Commit fa2f67c

Browse files
CopilotRouthleck
andcommitted
Fix bm.for_loop jit parameter handling and remove unused parameters
Co-authored-by: Routhleck <88108241+Routhleck@users.noreply.github.com>
1 parent 4a72e16 commit fa2f67c

2 files changed

Lines changed: 50 additions & 13 deletions

File tree

brainpy/math/object_transform/controls.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numbers
1717
from typing import Union, Sequence, Any, Dict, Callable, Optional
1818

19+
import jax
1920
import jax.numpy as jnp
2021

2122
import brainstate
@@ -205,10 +206,8 @@ def for_loop(
205206
operands: Any,
206207
reverse: bool = False,
207208
unroll: int = 1,
208-
remat: bool = False,
209209
jit: Optional[bool] = None,
210210
progress_bar: bool = False,
211-
unroll_kwargs: Optional[Dict] = None,
212211
):
213212
"""``for-loop`` control flow with :py:class:`~.Variable`.
214213
@@ -266,10 +265,6 @@ def for_loop(
266265
If body function `body_func` receives multiple arguments,
267266
`operands` should be a tuple/list whose length is equal to the
268267
number of arguments.
269-
remat: bool
270-
Make ``fun`` recompute internal linearization points when differentiated.
271-
jit: bool
272-
Whether to just-in-time compile the function.
273268
reverse: bool
274269
Optional boolean specifying whether to run the scan iteration
275270
forward (the default) or in reverse, equivalent to reversing the leading
@@ -278,6 +273,8 @@ def for_loop(
278273
Optional positive int specifying, in the underlying operation of the
279274
scan primitive, how many scan iterations to unroll within a single
280275
iteration of a loop.
276+
jit: bool
277+
Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation.
281278
progress_bar: bool
282279
Whether we use the progress bar to report the running progress.
283280
@@ -296,8 +293,6 @@ def for_loop(
296293
.. deprecated:: 2.4.0
297294
No longer need to provide ``child_objs``. This function is capable of automatically
298295
collecting the children objects used in the target ``func``.
299-
unroll_kwargs: dict
300-
The keyword arguments without unrolling.
301296
302297
Returns::
303298
@@ -306,11 +301,21 @@ def for_loop(
306301
"""
307302
if not isinstance(operands, (tuple, list)):
308303
operands = (operands,)
309-
return brainstate.transform.for_loop(
310-
warp_to_no_state_input_output(body_fun),
311-
*operands, reverse=reverse, unroll=unroll,
312-
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
313-
)
304+
305+
# Handle jit parameter
306+
if jit is False:
307+
with jax.disable_jit():
308+
return brainstate.transform.for_loop(
309+
warp_to_no_state_input_output(body_fun),
310+
*operands, reverse=reverse, unroll=unroll,
311+
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
312+
)
313+
else:
314+
return brainstate.transform.for_loop(
315+
warp_to_no_state_input_output(body_fun),
316+
*operands, reverse=reverse, unroll=unroll,
317+
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
318+
)
314319

315320

316321
def scan(

brainpy/math/object_transform/tests/test_controls.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,38 @@ def update(self):
7777
bm.for_loop(cls.step_run, indices)
7878
self.assertTrue(bm.allclose(cls.a, 10.))
7979

80+
def test_for_loop_jit_false(self):
81+
"""Test that jit=False disables JIT compilation"""
82+
a = bm.Variable(bm.zeros(1))
83+
call_count = {'count': 0}
84+
85+
def body(x):
86+
# This side effect should be visible when jit=False
87+
call_count['count'] += 1
88+
a.value += x
89+
return a.value
90+
91+
# Test with jit=False - should execute eagerly
92+
a.value = bm.zeros(1)
93+
call_count['count'] = 0
94+
result = bm.for_loop(body, operands=bm.arange(3), jit=False)
95+
# With jit=False, the function should be called 3 times
96+
self.assertEqual(call_count['count'], 3)
97+
self.assertTrue(bm.allclose(a.value, 3.))
98+
99+
def test_for_loop_jit_default(self):
100+
"""Test that default behavior (jit=None) allows JIT compilation"""
101+
a = bm.Variable(bm.zeros(1))
102+
103+
def body(x):
104+
a.value += x
105+
return a.value
106+
107+
# Test with default jit (None) - should work normally
108+
result = bm.for_loop(body, operands=bm.arange(3))
109+
self.assertTrue(bm.allclose(a.value, 3.))
110+
self.assertTrue(bm.allclose(result, bm.array([[0.], [1.], [3.]])))
111+
80112

81113
class TestScan(unittest.TestCase):
82114
def test1(self):

0 commit comments

Comments
 (0)