Skip to content

Commit 6f30df5

Browse files
committed
Comment out bm.clear_buffer_memory() in test_aligns.py
All calls to bm.clear_buffer_memory() in test_aligns.py have been commented out, likely to prevent buffers from being cleared during test execution or while debugging. Additionally, the unused reduce_axes argument and its documentation were removed from grad() in autograd.py to simplify the code.
1 parent f107719 commit 6f30df5

2 files changed

Lines changed: 8 additions & 19 deletions

File tree

brainpy/_src/dyn/projections/tests/test_aligns.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def update(self):
7474
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
7575

7676
plt.close()
77-
bm.clear_buffer_memory()
77+
# bm.clear_buffer_memory()
7878

7979

8080
def test_ProjAlignPostMg2():
@@ -150,7 +150,7 @@ def update(self):
150150
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)
151151

152152
plt.close()
153-
bm.clear_buffer_memory()
153+
# bm.clear_buffer_memory()
154154

155155

156156
def test_ProjAlignPost1():
@@ -185,7 +185,7 @@ def update(self, input):
185185
indices = bm.arange(400)
186186
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
187187
bp.visualize.raster_plot(indices, spks, show=True)
188-
bm.clear_buffer_memory()
188+
# bm.clear_buffer_memory()
189189
plt.close()
190190

191191

@@ -244,7 +244,7 @@ def update(self, inp):
244244
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
245245
bp.visualize.raster_plot(indices, spks, show=True)
246246

247-
bm.clear_buffer_memory()
247+
# bm.clear_buffer_memory()
248248
plt.close()
249249

250250

@@ -280,7 +280,7 @@ def update(self, input):
280280
indices = bm.arange(400)
281281
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
282282
bp.visualize.raster_plot(indices, spks, show=True)
283-
bm.clear_buffer_memory()
283+
# bm.clear_buffer_memory()
284284
plt.close()
285285

286286

@@ -338,7 +338,7 @@ def update(self, inp):
338338
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
339339
bp.visualize.raster_plot(indices, spks, show=True)
340340

341-
bm.clear_buffer_memory()
341+
# bm.clear_buffer_memory()
342342
plt.close()
343343

344344

@@ -396,7 +396,7 @@ def update(self, inp):
396396
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
397397
bp.visualize.raster_plot(indices, spks, show=True)
398398

399-
bm.clear_buffer_memory()
399+
# bm.clear_buffer_memory()
400400
plt.close()
401401

402402

@@ -437,4 +437,4 @@ def update(self, input):
437437
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True)
438438
bp.visualize.raster_plot(indices, spks, show=True)
439439
plt.close()
440-
bm.clear_buffer_memory()
440+
# bm.clear_buffer_memory()

brainpy/_src/math/object_transform/autograd.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def grad(
2121
argnums: Optional[Union[int, Sequence[int]]] = None,
2222
holomorphic: Optional[bool] = False,
2323
allow_int: Optional[bool] = False,
24-
reduce_axes: Optional[Sequence[str]] = (),
2524
has_aux: Optional[bool] = None,
2625
return_value: Optional[bool] = False,
2726
) -> Union[Callable, Callable[..., Callable]]:
@@ -122,14 +121,6 @@ def grad(
122121
Whether to allow differentiating with
123122
respect to integer valued inputs. The gradient of an integer input will
124123
have a trivial vector-space dtype (float0). Default False.
125-
reduce_axes: optional, tuple of int
126-
tuple of axis names. If an axis is listed here, and
127-
``fun`` implicitly broadcasts a value over that axis, the backward pass
128-
will perform a ``psum`` of the corresponding gradient. Otherwise, the
129-
gradient will be per-example over named axes. For example, if ``'batch'``
130-
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
131-
function that computes the total gradient while ``grad(f)`` will create
132-
one that computes the per-example gradient.
133124
134125
Returns
135126
-------
@@ -147,7 +138,6 @@ def grad(
147138
argnums=argnums,
148139
holomorphic=holomorphic,
149140
allow_int=allow_int,
150-
reduce_axes=reduce_axes,
151141
has_aux=has_aux,
152142
return_value=return_value)
153143
else:
@@ -157,7 +147,6 @@ def grad(
157147
argnums=argnums,
158148
holomorphic=holomorphic,
159149
allow_int=allow_int,
160-
reduce_axes=reduce_axes,
161150
has_aux=has_aux,
162151
return_value=return_value
163152
)

0 commit comments

Comments (0)