Skip to content

Commit 6b410aa

Browse files
committed
Fix zero-length scan error when using jit=False
When jit=False is used with zero-length input arrays, JAX's disable_jit() mode cannot handle the scan operation because it cannot infer the output type. Changes: - Added check for zero-length inputs when jit=False - Automatically falls back to JIT mode for zero-length inputs - Issues a UserWarning to inform users of the fallback - Added test case to verify zero-length input handling - All 38 tests in test_controls.py pass This fix resolves: ValueError: zero-length scan is not supported in disable_jit() mode because the output type is unknown.
1 parent dde8f99 commit 6b410aa

3 files changed

Lines changed: 2007 additions & 2 deletions

File tree

brainpy/math/object_transform/controls.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,28 @@ def for_loop(
367367
pbar = _convert_progress_bar_to_pbar(progress_bar)
368368

369369
# Handle jit parameter
370+
# Note: JAX's scan doesn't support zero-length inputs in disable_jit mode.
371+
# For zero-length inputs, we need to use JIT mode even when jit=False.
372+
should_disable_jit = False
370373
if jit is False:
374+
# Check if any operand has zero length
375+
first_operand = operands[0]
376+
is_zero_length = False
377+
if hasattr(first_operand, 'shape') and len(first_operand.shape) > 0:
378+
is_zero_length = (first_operand.shape[0] == 0)
379+
380+
if is_zero_length:
381+
# Use JIT mode for zero-length inputs to avoid JAX limitation
382+
import warnings
383+
warnings.warn(
384+
"for_loop with jit=False and zero-length input detected. "
385+
"Using JIT mode to avoid JAX's disable_jit limitation with zero-length scans.",
386+
UserWarning
387+
)
388+
else:
389+
should_disable_jit = True
390+
391+
if should_disable_jit:
371392
with jax.disable_jit():
372393
return brainstate.transform.for_loop(
373394
warp_to_no_state_input_output(body_fun),

brainpy/math/object_transform/tests/test_controls.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,39 @@ def body(x):
165165
def test_for_loop_jit_default(self):
166166
"""Test that default behavior (jit=None) allows JIT compilation"""
167167
a = bm.Variable(bm.zeros(1))
168-
168+
169169
def body(x):
170170
a.value += x
171171
return a.value
172-
172+
173173
# Test with default jit (None) - should work normally
174174
result = bm.for_loop(body, operands=bm.arange(3))
175175
self.assertTrue(bm.allclose(a.value, 3.))
176176
self.assertTrue(bm.allclose(result, bm.array([[0.], [1.], [3.]])))
177177

178+
def test_for_loop_jit_false_zero_length(self):
179+
"""Test that jit=False handles zero-length inputs gracefully"""
180+
a = bm.Variable(bm.zeros(1))
181+
182+
def body(x):
183+
a.value += x
184+
return a.value
185+
186+
# Test with zero-length input and jit=False
187+
# Should automatically fall back to JIT mode and issue a warning
188+
import warnings
189+
with warnings.catch_warnings(record=True) as w:
190+
warnings.simplefilter("always")
191+
result = bm.for_loop(body, operands=bm.arange(0), jit=False)
192+
# Check that our specific warning was issued
193+
zero_length_warnings = [warning for warning in w
194+
if "zero-length input" in str(warning.message)]
195+
self.assertGreaterEqual(len(zero_length_warnings), 1,
196+
"Expected at least one zero-length input warning")
197+
198+
# Variable should not have changed
199+
self.assertTrue(bm.allclose(a.value, 0.))
200+
178201

179202
class TestScan(unittest.TestCase):
180203
def test1(self):

0 commit comments

Comments
 (0)