Skip to content

Commit 8dad989

Browse files
committed
Fix random state regeneration logic in RandomState
Updated the _ensure_value_exists method to properly check the type of self._value and only regenerate the random state if it is not a numpy array, not a JAX Tracer, and has been deleted. This prevents unnecessary regeneration and ensures correct handling of different value types.
1 parent 8d924dd commit 8dad989

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

brainpy/_src/math/random.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,12 +546,15 @@ def seed(self, seed_or_key=None, seed=None):
546546
key = seed_or_key
547547
self._value = key
548548

549+
549550
def _ensure_value_exists(self):
550551
"""Ensure that the random state has a valid value, regenerate if needed."""
551-
if not isinstance(self._value, np.ndarray) and self._value.is_deleted():
552+
if not isinstance(self._value, np.ndarray):
552553
with jax.ensure_compile_time_eval():
553-
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
554-
self._value = seed_or_key
554+
if not isinstance(self._value, jax.core.Tracer):
555+
if self._value.is_deleted():
556+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
557+
self._value = seed_or_key
555558

556559

557560
@property

0 commit comments

Comments
 (0)