Skip to content

Commit 12bd117

Browse files
committed
Refactor Base to BaseArray and update usages
Renamed the Base class in ndarray.py to BaseArray and updated all references in related modules, including Array and Variable. Also fixed the while_loop wrapper in controls.py to correctly unpack operands for the body and condition functions. Uncommented bm.clear_buffer_memory() calls in test_aligns.py to ensure buffer memory is cleared after tests.
1 parent 6f30df5 commit 12bd117

4 files changed

Lines changed: 13 additions & 13 deletions

File tree

brainpy/_src/dyn/projections/tests/test_aligns.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def update(self):
7474
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
7575

7676
plt.close()
77-
# bm.clear_buffer_memory()
77+
bm.clear_buffer_memory()
7878

7979

8080
def test_ProjAlignPostMg2():
@@ -150,7 +150,7 @@ def update(self):
150150
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
151151

152152
plt.close()
153-
# bm.clear_buffer_memory()
153+
bm.clear_buffer_memory()
154154

155155

156156
def test_ProjAlignPost1():
@@ -185,7 +185,7 @@ def update(self, input):
185185
indices = bm.arange(400)
186186
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
187187
bp.visualize.raster_plot(indices, spks, show=True)
188-
# bm.clear_buffer_memory()
188+
bm.clear_buffer_memory()
189189
plt.close()
190190

191191

@@ -244,7 +244,7 @@ def update(self, inp):
244244
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
245245
bp.visualize.raster_plot(indices, spks, show=True)
246246

247-
# bm.clear_buffer_memory()
247+
bm.clear_buffer_memory()
248248
plt.close()
249249

250250

@@ -280,7 +280,7 @@ def update(self, input):
280280
indices = bm.arange(400)
281281
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
282282
bp.visualize.raster_plot(indices, spks, show=True)
283-
# bm.clear_buffer_memory()
283+
bm.clear_buffer_memory()
284284
plt.close()
285285

286286

@@ -338,7 +338,7 @@ def update(self, inp):
338338
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
339339
bp.visualize.raster_plot(indices, spks, show=True)
340340

341-
# bm.clear_buffer_memory()
341+
bm.clear_buffer_memory()
342342
plt.close()
343343

344344

@@ -396,7 +396,7 @@ def update(self, inp):
396396
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
397397
bp.visualize.raster_plot(indices, spks, show=True)
398398

399-
# bm.clear_buffer_memory()
399+
bm.clear_buffer_memory()
400400
plt.close()
401401

402402

@@ -437,4 +437,4 @@ def update(self, input):
437437
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True)
438438
bp.visualize.raster_plot(indices, spks, show=True)
439439
plt.close()
440-
# bm.clear_buffer_memory()
440+
bm.clear_buffer_memory()

brainpy/_src/math/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _get_dtype(v):
6363
dtype = canonicalize_dtype(type(v))
6464
return dtype
6565

66-
class Base:
66+
class BaseArray:
6767
@property
6868
def sharding(self):
6969
return self._value.sharding
@@ -1488,7 +1488,7 @@ def double(self): return jnp.asarray(self.value, dtype=jnp.float64)
14881488

14891489

14901490
@register_pytree_node_class
1491-
class Array(Base):
1491+
class Array(BaseArray):
14921492
"""Multiple-dimensional array in BrainPy.
14931493
14941494
Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages:

brainpy/_src/math/object_transform/controls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,4 +381,4 @@ def while_loop(
381381
"""
382382
if not isinstance(operands, (tuple, list)):
383383
operands = (operands,)
384-
return brainstate.transform.while_loop(body_fun, cond_fun, *operands)
384+
return brainstate.transform.while_loop(lambda x: body_fun(*x), cond_fun(*x), operands)

brainpy/_src/math/object_transform/variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax.dtypes import canonicalize_dtype
77
from jax.tree_util import register_pytree_node_class
88

9-
from brainpy._src.math.ndarray import Array, Base
9+
from brainpy._src.math.ndarray import Array, BaseArray
1010
from brainpy._src.math.sharding import BATCH_AXIS
1111
from brainpy.errors import MathError
1212

@@ -224,7 +224,7 @@ def __add__(self, other: dict):
224224

225225

226226
@register_pytree_node_class
227-
class Variable(brainstate.State, Base):
227+
class Variable(brainstate.State, BaseArray):
228228
"""The pointer to specify the dynamical variable.
229229
230230
Initializing an instance of ``Variable`` by two ways:

0 commit comments

Comments
 (0)