Skip to content

Commit a8ef525

Browse files
committed
Merge branch 'update' of https://github.com/brainpy/BrainPy into update
# Conflicts: # brainpy/_src/math/object_transform/controls.py
2 parents b287b73 + 23248e6 commit a8ef525

6 files changed

Lines changed: 30 additions & 18 deletions

File tree

brainpy/_src/dyn/projections/tests/test_aligns.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def update(self):
7474
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
7575

7676
plt.close()
77-
# bm.clear_buffer_memory()
77+
bm.clear_buffer_memory()
7878

7979

8080
def test_ProjAlignPostMg2():
@@ -150,7 +150,7 @@ def update(self):
150150
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
151151

152152
plt.close()
153-
# bm.clear_buffer_memory()
153+
bm.clear_buffer_memory()
154154

155155

156156
def test_ProjAlignPost1():
@@ -185,7 +185,7 @@ def update(self, input):
185185
indices = bm.arange(400)
186186
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
187187
bp.visualize.raster_plot(indices, spks, show=True)
188-
# bm.clear_buffer_memory()
188+
bm.clear_buffer_memory()
189189
plt.close()
190190

191191

@@ -244,7 +244,7 @@ def update(self, inp):
244244
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
245245
bp.visualize.raster_plot(indices, spks, show=True)
246246

247-
# bm.clear_buffer_memory()
247+
bm.clear_buffer_memory()
248248
plt.close()
249249

250250

@@ -280,7 +280,7 @@ def update(self, input):
280280
indices = bm.arange(400)
281281
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
282282
bp.visualize.raster_plot(indices, spks, show=True)
283-
# bm.clear_buffer_memory()
283+
bm.clear_buffer_memory()
284284
plt.close()
285285

286286

@@ -338,7 +338,7 @@ def update(self, inp):
338338
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
339339
bp.visualize.raster_plot(indices, spks, show=True)
340340

341-
# bm.clear_buffer_memory()
341+
bm.clear_buffer_memory()
342342
plt.close()
343343

344344

@@ -396,7 +396,7 @@ def update(self, inp):
396396
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
397397
bp.visualize.raster_plot(indices, spks, show=True)
398398

399-
# bm.clear_buffer_memory()
399+
bm.clear_buffer_memory()
400400
plt.close()
401401

402402

@@ -437,4 +437,4 @@ def update(self, input):
437437
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True)
438438
bp.visualize.raster_plot(indices, spks, show=True)
439439
plt.close()
440-
# bm.clear_buffer_memory()
440+
bm.clear_buffer_memory()

brainpy/_src/math/delayvars.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .compat_numpy import broadcast_to, expand_dims, concatenate
1515
from .environment import get_dt, get_float
1616
from .interoperability import as_jax
17-
from .ndarray import ndarray, Array
17+
from .ndarray import ndarray, BaseArray
1818
from .object_transform.base import BrainPyObject
1919
from .object_transform.controls import cond
2020
from .object_transform.variables import Variable
@@ -29,7 +29,7 @@
2929

3030

3131
def _as_jax_array(arr):
32-
return arr.value if isinstance(arr, Array) else arr
32+
return arr.value if isinstance(arr, BaseArray) else arr
3333

3434

3535
class AbstractDelay(BrainPyObject):
@@ -129,8 +129,8 @@ def __init__(
129129
super(TimeDelay, self).__init__(name=name)
130130

131131
# shape
132-
if not isinstance(delay_target, (jnp.ndarray, Array)):
133-
raise ValueError(f'Must be an instance of Array or jax.numpy.ndarray. But we got {type(delay_target)}')
132+
if not isinstance(delay_target, (jnp.ndarray, BaseArray)):
133+
raise ValueError(f'Must be an instance of BaseArray or jax.numpy.ndarray. But we got {type(delay_target)}')
134134

135135
# delay_len
136136
self.t0 = t0
@@ -453,7 +453,7 @@ def retrieve(self, delay_len, *indices):
453453
# the delay data
454454
return self.data[indices]
455455

456-
def update(self, value: Union[numbers.Number, Array, jax.Array] = None):
456+
def update(self, value: Union[numbers.Number, BaseArray, jax.Array] = None):
457457
"""Update delay variable with the new data.
458458
459459
Parameters

brainpy/_src/math/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _get_dtype(v):
6363
dtype = canonicalize_dtype(type(v))
6464
return dtype
6565

66-
class Base:
66+
class BaseArray:
6767
@property
6868
def sharding(self):
6969
return self._value.sharding
@@ -1488,7 +1488,7 @@ def double(self): return jnp.asarray(self.value, dtype=jnp.float64)
14881488

14891489

14901490
@register_pytree_node_class
1491-
class Array(Base):
1491+
class Array(BaseArray):
14921492
"""Multiple-dimensional array in BrainPy.
14931493
14941494
Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages:

brainpy/_src/math/object_transform/controls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,4 +381,4 @@ def while_loop(
381381
"""
382382
if not isinstance(operands, (tuple, list)):
383383
operands = (operands,)
384-
return brainstate.transform.while_loop(body_fun, cond_fun, *operands)
384+
return brainstate.transform.while_loop(lambda x: body_fun(*x), lambda x: cond_fun(*x), operands)

brainpy/_src/math/object_transform/variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax.dtypes import canonicalize_dtype
77
from jax.tree_util import register_pytree_node_class
88

9-
from brainpy._src.math.ndarray import Array, Base
9+
from brainpy._src.math.ndarray import Array, BaseArray
1010
from brainpy._src.math.sharding import BATCH_AXIS
1111
from brainpy.errors import MathError
1212

@@ -224,7 +224,7 @@ def __add__(self, other: dict):
224224

225225

226226
@register_pytree_node_class
227-
class Variable(brainstate.State, Base):
227+
class Variable(brainstate.State, BaseArray):
228228
"""The pointer to specify the dynamical variable.
229229
230230
Initializing an instance of ``Variable`` by two ways:

brainpy/_src/math/random.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,8 +546,20 @@ def seed(self, seed_or_key=None, seed=None):
546546
key = seed_or_key
547547
self._value = key
548548

549+
550+
def _ensure_value_exists(self):
551+
"""Ensure that the random state has a valid value, regenerate if needed."""
552+
if not isinstance(self._value, np.ndarray):
553+
with jax.ensure_compile_time_eval():
554+
if not isinstance(self._value, jax.core.Tracer):
555+
if self._value.is_deleted():
556+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
557+
self._value = seed_or_key
558+
559+
549560
@property
550561
def value(self):
562+
self._ensure_value_exists()
551563
record_state_value_read(self)
552564
return self._read_value()
553565

0 commit comments

Comments
 (0)