Skip to content

Commit 6990c52

Browse files
committed
Refactor STDP weight update logic and improve error handling, requiring brainevent>=0.0.4
1 parent 6e3874f commit 6990c52

6 files changed

Lines changed: 61 additions & 341 deletions

File tree

brainpy/_src/dnn/linear.py

Lines changed: 12 additions & 315 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import jax.numpy as jnp
99
import numpy as np
1010

11+
from brainevent._plasticity_dense import dense_on_pre, dense_on_post
12+
from brainevent._plasticity_csr import csr_on_pre, csr2csc_on_post
1113
from brainpy import math as bm
1214
from brainpy._src import connect, initialize as init
1315
from brainpy._src.context import share
@@ -207,10 +209,8 @@ def stdp_update(
207209
w_min: numbers.Number = None,
208210
w_max: numbers.Number = None
209211
):
210-
if isinstance(self.W, float):
211-
raise ValueError(f'Cannot update the weight of a constant node.')
212212
if not isinstance(self.W, bm.Variable):
213-
self.tracing_variable('W', self.W, self.W.shape)
213+
raise ValueError(f'When using STDP to update synaptic weights, the weight must be a variable.')
214214
if on_pre is not None:
215215
spike = on_pre['spike']
216216
trace = on_pre['trace']
@@ -235,124 +235,6 @@ def update(self, x):
235235
return x
236236

237237

238-
if False:
239-
240-
# @numba.njit(nogil=True, fastmath=True, parallel=False)
241-
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
242-
# out_w[:] = weight
243-
# for i in numba.prange(spike.shape[0]):
244-
# if spike[i]:
245-
# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
246-
247-
@ti.kernel
248-
def _dense_on_post(
249-
old_w: ti.types.ndarray(ndim=2),
250-
post_spike: ti.types.ndarray(ndim=1),
251-
pre_trace: ti.types.ndarray(ndim=1),
252-
w_min: ti.types.ndarray(ndim=1),
253-
w_max: ti.types.ndarray(ndim=1),
254-
out_w: ti.types.ndarray(ndim=2)
255-
):
256-
w_min0 = w_min[0]
257-
w_max0 = w_max[0]
258-
num_pre, num_post = out_w.shape
259-
260-
for i, j in ti.ndrange(num_pre, num_post):
261-
if post_spike[j]:
262-
new_value = out_w[i, j] + pre_trace[i]
263-
if new_value < w_min0:
264-
out_w[i, j] = w_min0
265-
elif new_value > w_max0:
266-
out_w[i, j] = w_max0
267-
else:
268-
out_w[i, j] = new_value
269-
else:
270-
out_w[i, j] = old_w[i, j]
271-
272-
273-
dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post)
274-
275-
276-
# @numba.njit(nogil=True, fastmath=True, parallel=False)
277-
# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
278-
# out_w[:] = weight
279-
# for i in numba.prange(spike.shape[0]):
280-
# if spike[i]:
281-
# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
282-
283-
@ti.kernel
284-
def _dense_on_pre(
285-
old_w: ti.types.ndarray(ndim=2),
286-
pre_spike: ti.types.ndarray(ndim=1),
287-
post_trace: ti.types.ndarray(ndim=1),
288-
w_min: ti.types.ndarray(ndim=1),
289-
w_max: ti.types.ndarray(ndim=1),
290-
out_w: ti.types.ndarray(ndim=2)
291-
):
292-
w_min0 = w_min[0]
293-
w_max0 = w_max[0]
294-
num_pre, num_post = out_w.shape
295-
296-
for i, j in ti.ndrange(num_pre, num_post):
297-
if pre_spike[i]:
298-
new_value = out_w[i, j] + post_trace[j]
299-
if new_value < w_min0:
300-
out_w[i, j] = w_min0
301-
elif new_value > w_max0:
302-
out_w[i, j] = w_max0
303-
else:
304-
out_w[i, j] = new_value
305-
else:
306-
out_w[i, j] = old_w[i, j]
307-
308-
309-
dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre)
310-
311-
else:
312-
dense_on_pre_prim = None
313-
dense_on_post_prim = None
314-
315-
316-
def dense_on_pre(weight, spike, trace, w_min, w_max):
317-
if dense_on_pre_prim is None:
318-
raise_braintaichi_not_found()
319-
320-
if w_min is None:
321-
w_min = -np.inf
322-
if w_max is None:
323-
w_max = np.inf
324-
w_min = jnp.atleast_1d(w_min)
325-
w_max = jnp.atleast_1d(w_max)
326-
327-
weight = bm.as_jax(weight)
328-
spike = bm.as_jax(spike)
329-
trace = bm.as_jax(trace)
330-
w_min = bm.as_jax(w_min)
331-
w_max = bm.as_jax(w_max)
332-
return dense_on_pre_prim(weight, spike, trace, w_min, w_max,
333-
outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
334-
335-
336-
def dense_on_post(weight, spike, trace, w_min, w_max):
337-
if dense_on_post_prim is None:
338-
raise_braintaichi_not_found()
339-
340-
if w_min is None:
341-
w_min = -np.inf
342-
if w_max is None:
343-
w_max = np.inf
344-
w_min = jnp.atleast_1d(w_min)
345-
w_max = jnp.atleast_1d(w_max)
346-
347-
weight = bm.as_jax(weight)
348-
spike = bm.as_jax(spike)
349-
trace = bm.as_jax(trace)
350-
w_min = bm.as_jax(w_min)
351-
w_max = bm.as_jax(w_max)
352-
return dense_on_post_prim(weight, spike, trace, w_min, w_max,
353-
outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0]
354-
355-
356238
class AllToAll(Layer, SupportSTDP):
357239
"""Synaptic matrix multiplication with All2All connections.
358240
@@ -608,8 +490,10 @@ def stdp_update(
608490
if on_pre is not None: # update on presynaptic spike
609491
spike = on_pre['spike']
610492
trace = on_pre['trace']
611-
self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min,
612-
w_max)
493+
self.weight.value = csr_on_pre(
494+
self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max,
495+
shape=(spike.shape[0], trace.shape[0]),
496+
)
613497
if on_post is not None: # update on postsynaptic spike
614498
if not hasattr(self, '_pre_ids'):
615499
with jax.ensure_compile_time_eval():
@@ -618,8 +502,11 @@ def stdp_update(
618502
)
619503
spike = on_post['spike']
620504
trace = on_post['trace']
621-
self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr,
622-
self.w_indices, spike, trace, w_min, w_max)
505+
self.weight.value = csr2csc_on_post(
506+
self.weight.value, self._pre_ids, self._post_indptr,
507+
self.w_indices, spike, trace, w_min, w_max,
508+
shape=(trace.shape[0], spike.shape[0]),
509+
)
623510

624511

625512
class CSRLinear(_CSRLayer):
@@ -722,196 +609,6 @@ def _batch_csrmv(self, x):
722609
transpose=self.transpose)
723610

724611

725-
if False:
726-
@ti.kernel
727-
def _csr_on_pre_update(
728-
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
729-
indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
730-
indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1)
731-
spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
732-
trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
733-
w_min: ti.types.ndarray(ndim=1), # scalar
734-
w_max: ti.types.ndarray(ndim=1), # scalar
735-
out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
736-
):
737-
w_min0 = w_min[0]
738-
w_max0 = w_max[0]
739-
num_pre = spike.shape[0]
740-
for i_pre in range(num_pre):
741-
if spike[i_pre]:
742-
for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
743-
out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0)
744-
else:
745-
for i_syn in range(indptr[i_pre], indptr[i_pre + 1]):
746-
out_w[i_syn] = old_w[i_syn]
747-
748-
749-
csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update)
750-
751-
752-
@ti.kernel
753-
def _coo_on_pre_update(
754-
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
755-
pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
756-
post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
757-
pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
758-
post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
759-
w_min: ti.types.ndarray(ndim=1), # scalar
760-
w_max: ti.types.ndarray(ndim=1), # scalar
761-
out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
762-
):
763-
w_min0 = w_min[0]
764-
w_max0 = w_max[0]
765-
num_syn = old_w.shape[0]
766-
for i_syn in range(num_syn):
767-
if pre_spike[pre_ids[i_syn]]: # pre spike
768-
out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0)
769-
else:
770-
out_w[i_syn] = old_w[i_syn]
771-
772-
773-
coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update)
774-
775-
776-
@ti.kernel
777-
def _coo_on_post_update(
778-
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
779-
pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
780-
post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
781-
post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
782-
pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
783-
w_min: ti.types.ndarray(ndim=1), # scalar
784-
w_max: ti.types.ndarray(ndim=1), # scalar
785-
out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn)
786-
):
787-
w_min0 = w_min[0]
788-
w_max0 = w_max[0]
789-
num_syn = old_w.shape[0]
790-
for i_syn in range(num_syn):
791-
if post_spike[post_ids[i_syn]]: # pre spike
792-
out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0)
793-
else:
794-
out_w[i_syn] = old_w[i_syn]
795-
796-
797-
coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update)
798-
799-
800-
# @numba.njit(nogil=True, fastmath=True, parallel=False)
801-
# def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
802-
# out_w[:] = w
803-
# w_min = w_min[()]
804-
# w_max = w_max[()]
805-
# for i in numba.prange(spike.shape[0]): # post id
806-
# if spike[i]:
807-
# for k in range(indptr[i], indptr[i + 1]):
808-
# j = post_ids[k] # pre id
809-
# l = w_ids[k] # syn id
810-
# out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
811-
812-
@ti.kernel
813-
def _csc_on_post_update(
814-
old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
815-
indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
816-
indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1)
817-
w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
818-
post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,)
819-
pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,)
820-
w_min: ti.types.ndarray(ndim=1), # scalar
821-
w_max: ti.types.ndarray(ndim=1), # scalar
822-
out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn)
823-
):
824-
w_min0 = w_min[0]
825-
w_max0 = w_max[0]
826-
num_post = post_spike.shape[0]
827-
for i_post in range(num_post):
828-
if post_spike[i_post]:
829-
for k in range(indptr[i_post], indptr[i_post + 1]):
830-
i_syn = w_ids[k] # syn id
831-
out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0)
832-
else:
833-
for k in range(indptr[i_post], indptr[i_post + 1]):
834-
i_syn = w_ids[k] # syn id
835-
out_w[i_syn] = old_w[i_syn]
836-
837-
838-
csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update)
839-
840-
841-
else:
842-
csr_on_pre_update_prim = None
843-
coo_on_pre_update_prim = None
844-
csc_on_post_update_prim = None
845-
846-
847-
def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
848-
if csr_on_pre_update_prim is None:
849-
raise_braintaichi_not_found()
850-
851-
if w_min is None:
852-
w_min = -np.inf
853-
if w_max is None:
854-
w_max = np.inf
855-
w_min = jnp.atleast_1d(w_min)
856-
w_max = jnp.atleast_1d(w_max)
857-
858-
w = bm.as_jax(w)
859-
indices = bm.as_jax(indices)
860-
indptr = bm.as_jax(indptr)
861-
spike = bm.as_jax(spike)
862-
trace = bm.as_jax(trace)
863-
w_min = bm.as_jax(w_min)
864-
w_max = bm.as_jax(w_max)
865-
return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
866-
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
867-
868-
869-
def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None):
870-
if coo_on_pre_update_prim is None:
871-
raise_braintaichi_not_found()
872-
873-
if w_min is None:
874-
w_min = -np.inf
875-
if w_max is None:
876-
w_max = np.inf
877-
w_min = jnp.atleast_1d(w_min)
878-
w_max = jnp.atleast_1d(w_max)
879-
880-
w = bm.as_jax(w)
881-
pre_ids = bm.as_jax(pre_ids)
882-
post_ids = bm.as_jax(post_ids)
883-
spike = bm.as_jax(spike)
884-
trace = bm.as_jax(trace)
885-
w_min = bm.as_jax(w_min)
886-
w_max = bm.as_jax(w_max)
887-
888-
return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max,
889-
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
890-
891-
892-
def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None):
893-
if csc_on_post_update_prim is None:
894-
raise_braintaichi_not_found()
895-
896-
if w_min is None:
897-
w_min = -np.inf
898-
if w_max is None:
899-
w_max = np.inf
900-
w_min = jnp.atleast_1d(w_min)
901-
w_max = jnp.atleast_1d(w_max)
902-
903-
w = bm.as_jax(w)
904-
post_ids = bm.as_jax(post_ids)
905-
indptr = bm.as_jax(indptr)
906-
w_ids = bm.as_jax(w_ids)
907-
post_spike = bm.as_jax(post_spike)
908-
pre_trace = bm.as_jax(pre_trace)
909-
w_min = bm.as_jax(w_min)
910-
w_max = bm.as_jax(w_max)
911-
return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max,
912-
outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
913-
914-
915612
class CSCLinear(Layer):
916613
r"""Synaptic matrix multiplication with CSC sparse computation.
917614

brainpy/_src/dyn/projections/plasticity.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,21 @@ def update(self):
209209
# post spikes
210210
if not hasattr(self.refs['post'], 'spike'):
211211
raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
212-
post_spike = self.refs['post'].spike
212+
post_spike = self.refs['post'].spike.value
213213

214214
# weight updates
215-
Apost = self.refs['post_trace'].g
216-
self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min,
217-
w_max=self.W_max)
218-
Apre = self.refs['pre_trace'].g
219-
self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min,
220-
w_max=self.W_max)
215+
Apost = self.refs['post_trace'].g.value
216+
self.comm.stdp_update(
217+
on_pre={"spike": bm.as_jax(pre_spike), "trace": bm.as_jax(-Apost * self.A2)},
218+
w_min=bm.as_jax(self.W_min),
219+
w_max=bm.as_jax(self.W_max),
220+
)
221+
Apre = self.refs['pre_trace'].g.value
222+
self.comm.stdp_update(
223+
on_post={"spike": bm.as_jax(post_spike), "trace": bm.as_jax(Apre * self.A1)},
224+
w_min=bm.as_jax(self.W_min),
225+
w_max=bm.as_jax(self.W_max),
226+
)
221227

222228
# synaptic currents
223229
current = self.comm(x)

0 commit comments

Comments
 (0)