88import jax .numpy as jnp
99import numpy as np
1010
11+ from brainevent ._plasticity_dense import dense_on_pre , dense_on_post
12+ from brainevent ._plasticity_csr import csr_on_pre , csr2csc_on_post
1113from brainpy import math as bm
1214from brainpy ._src import connect , initialize as init
1315from brainpy ._src .context import share
@@ -207,10 +209,8 @@ def stdp_update(
207209 w_min : numbers .Number = None ,
208210 w_max : numbers .Number = None
209211 ):
210- if isinstance (self .W , float ):
211- raise ValueError (f'Cannot update the weight of a constant node.' )
212212 if not isinstance (self .W , bm .Variable ):
213- self . tracing_variable ( 'W' , self . W , self . W . shape )
213+ raise ValueError ( f'When using STDP to update synaptic weights, the weight must be a variable.' )
214214 if on_pre is not None :
215215 spike = on_pre ['spike' ]
216216 trace = on_pre ['trace' ]
@@ -235,124 +235,6 @@ def update(self, x):
235235 return x
236236
237237
238- if False :
239-
240- # @numba.njit(nogil=True, fastmath=True, parallel=False)
241- # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
242- # out_w[:] = weight
243- # for i in numba.prange(spike.shape[0]):
244- # if spike[i]:
245- # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max)
246-
247- @ti .kernel
248- def _dense_on_post (
249- old_w : ti .types .ndarray (ndim = 2 ),
250- post_spike : ti .types .ndarray (ndim = 1 ),
251- pre_trace : ti .types .ndarray (ndim = 1 ),
252- w_min : ti .types .ndarray (ndim = 1 ),
253- w_max : ti .types .ndarray (ndim = 1 ),
254- out_w : ti .types .ndarray (ndim = 2 )
255- ):
256- w_min0 = w_min [0 ]
257- w_max0 = w_max [0 ]
258- num_pre , num_post = out_w .shape
259-
260- for i , j in ti .ndrange (num_pre , num_post ):
261- if post_spike [j ]:
262- new_value = out_w [i , j ] + pre_trace [i ]
263- if new_value < w_min0 :
264- out_w [i , j ] = w_min0
265- elif new_value > w_max0 :
266- out_w [i , j ] = w_max0
267- else :
268- out_w [i , j ] = new_value
269- else :
270- out_w [i , j ] = old_w [i , j ]
271-
272-
273- dense_on_post_prim = bti .XLACustomOp (cpu_kernel = _dense_on_post , gpu_kernel = _dense_on_post )
274-
275-
276- # @numba.njit(nogil=True, fastmath=True, parallel=False)
277- # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w):
278- # out_w[:] = weight
279- # for i in numba.prange(spike.shape[0]):
280- # if spike[i]:
281- # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max)
282-
283- @ti .kernel
284- def _dense_on_pre (
285- old_w : ti .types .ndarray (ndim = 2 ),
286- pre_spike : ti .types .ndarray (ndim = 1 ),
287- post_trace : ti .types .ndarray (ndim = 1 ),
288- w_min : ti .types .ndarray (ndim = 1 ),
289- w_max : ti .types .ndarray (ndim = 1 ),
290- out_w : ti .types .ndarray (ndim = 2 )
291- ):
292- w_min0 = w_min [0 ]
293- w_max0 = w_max [0 ]
294- num_pre , num_post = out_w .shape
295-
296- for i , j in ti .ndrange (num_pre , num_post ):
297- if pre_spike [i ]:
298- new_value = out_w [i , j ] + post_trace [j ]
299- if new_value < w_min0 :
300- out_w [i , j ] = w_min0
301- elif new_value > w_max0 :
302- out_w [i , j ] = w_max0
303- else :
304- out_w [i , j ] = new_value
305- else :
306- out_w [i , j ] = old_w [i , j ]
307-
308-
309- dense_on_pre_prim = bti .XLACustomOp (cpu_kernel = _dense_on_pre , gpu_kernel = _dense_on_pre )
310-
311- else :
312- dense_on_pre_prim = None
313- dense_on_post_prim = None
314-
315-
316- def dense_on_pre (weight , spike , trace , w_min , w_max ):
317- if dense_on_pre_prim is None :
318- raise_braintaichi_not_found ()
319-
320- if w_min is None :
321- w_min = - np .inf
322- if w_max is None :
323- w_max = np .inf
324- w_min = jnp .atleast_1d (w_min )
325- w_max = jnp .atleast_1d (w_max )
326-
327- weight = bm .as_jax (weight )
328- spike = bm .as_jax (spike )
329- trace = bm .as_jax (trace )
330- w_min = bm .as_jax (w_min )
331- w_max = bm .as_jax (w_max )
332- return dense_on_pre_prim (weight , spike , trace , w_min , w_max ,
333- outs = [jax .ShapeDtypeStruct (weight .shape , weight .dtype )])[0 ]
334-
335-
336- def dense_on_post (weight , spike , trace , w_min , w_max ):
337- if dense_on_post_prim is None :
338- raise_braintaichi_not_found ()
339-
340- if w_min is None :
341- w_min = - np .inf
342- if w_max is None :
343- w_max = np .inf
344- w_min = jnp .atleast_1d (w_min )
345- w_max = jnp .atleast_1d (w_max )
346-
347- weight = bm .as_jax (weight )
348- spike = bm .as_jax (spike )
349- trace = bm .as_jax (trace )
350- w_min = bm .as_jax (w_min )
351- w_max = bm .as_jax (w_max )
352- return dense_on_post_prim (weight , spike , trace , w_min , w_max ,
353- outs = [jax .ShapeDtypeStruct (weight .shape , weight .dtype )])[0 ]
354-
355-
356238class AllToAll (Layer , SupportSTDP ):
357239 """Synaptic matrix multiplication with All2All connections.
358240
@@ -608,8 +490,10 @@ def stdp_update(
608490 if on_pre is not None : # update on presynaptic spike
609491 spike = on_pre ['spike' ]
610492 trace = on_pre ['trace' ]
611- self .weight .value = csr_on_pre_update (self .weight .value , self .indices , self .indptr , spike , trace , w_min ,
612- w_max )
493+ self .weight .value = csr_on_pre (
494+ self .weight .value , self .indices , self .indptr , spike , trace , w_min , w_max ,
495+ shape = (spike .shape [0 ], trace .shape [0 ]),
496+ )
613497 if on_post is not None : # update on postsynaptic spike
614498 if not hasattr (self , '_pre_ids' ):
615499 with jax .ensure_compile_time_eval ():
@@ -618,8 +502,11 @@ def stdp_update(
618502 )
619503 spike = on_post ['spike' ]
620504 trace = on_post ['trace' ]
621- self .weight .value = csc_on_post_update (self .weight .value , self ._pre_ids , self ._post_indptr ,
622- self .w_indices , spike , trace , w_min , w_max )
505+ self .weight .value = csr2csc_on_post (
506+ self .weight .value , self ._pre_ids , self ._post_indptr ,
507+ self .w_indices , spike , trace , w_min , w_max ,
508+ shape = (trace .shape [0 ], spike .shape [0 ]),
509+ )
623510
624511
625512class CSRLinear (_CSRLayer ):
@@ -722,196 +609,6 @@ def _batch_csrmv(self, x):
722609 transpose = self .transpose )
723610
724611
725- if False :
726- @ti .kernel
727- def _csr_on_pre_update (
728- old_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
729- indices : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
730- indptr : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_pre + 1)
731- spike : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_pre,)
732- trace : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_post,)
733- w_min : ti .types .ndarray (ndim = 1 ), # scalar
734- w_max : ti .types .ndarray (ndim = 1 ), # scalar
735- out_w : ti .types .ndarray (ndim = 1 ) # vector with shape of (num_syn)
736- ):
737- w_min0 = w_min [0 ]
738- w_max0 = w_max [0 ]
739- num_pre = spike .shape [0 ]
740- for i_pre in range (num_pre ):
741- if spike [i_pre ]:
742- for i_syn in range (indptr [i_pre ], indptr [i_pre + 1 ]):
743- out_w [i_syn ] = min (max (old_w [i_syn ] + trace [indices [i_syn ]], w_min0 ), w_max0 )
744- else :
745- for i_syn in range (indptr [i_pre ], indptr [i_pre + 1 ]):
746- out_w [i_syn ] = old_w [i_syn ]
747-
748-
749- csr_on_pre_update_prim = bti .XLACustomOp (cpu_kernel = _csr_on_pre_update , gpu_kernel = _csr_on_pre_update )
750-
751-
752- @ti .kernel
753- def _coo_on_pre_update (
754- old_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
755- pre_ids : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
756- post_ids : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
757- pre_spike : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_pre,)
758- post_trace : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_post,)
759- w_min : ti .types .ndarray (ndim = 1 ), # scalar
760- w_max : ti .types .ndarray (ndim = 1 ), # scalar
761- out_w : ti .types .ndarray (ndim = 1 ) # vector with shape of (num_syn)
762- ):
763- w_min0 = w_min [0 ]
764- w_max0 = w_max [0 ]
765- num_syn = old_w .shape [0 ]
766- for i_syn in range (num_syn ):
767- if pre_spike [pre_ids [i_syn ]]: # pre spike
768- out_w [i_syn ] = min (max (old_w [i_syn ] + post_trace [post_ids [i_syn ]], w_min0 ), w_max0 )
769- else :
770- out_w [i_syn ] = old_w [i_syn ]
771-
772-
773- coo_on_pre_update_prim = bti .XLACustomOp (cpu_kernel = _coo_on_pre_update , gpu_kernel = _coo_on_pre_update )
774-
775-
776- @ti .kernel
777- def _coo_on_post_update (
778- old_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
779- pre_ids : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
780- post_ids : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
781- post_spike : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_pre,)
782- pre_trace : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_post,)
783- w_min : ti .types .ndarray (ndim = 1 ), # scalar
784- w_max : ti .types .ndarray (ndim = 1 ), # scalar
785- out_w : ti .types .ndarray (ndim = 1 ) # vector with shape of (num_syn)
786- ):
787- w_min0 = w_min [0 ]
788- w_max0 = w_max [0 ]
789- num_syn = old_w .shape [0 ]
790- for i_syn in range (num_syn ):
791- if post_spike [post_ids [i_syn ]]: # pre spike
792- out_w [i_syn ] = min (max (old_w [i_syn ] + pre_trace [pre_ids [i_syn ]], w_min0 ), w_max0 )
793- else :
794- out_w [i_syn ] = old_w [i_syn ]
795-
796-
797- coo_on_post_update_prim = bti .XLACustomOp (cpu_kernel = _coo_on_post_update , gpu_kernel = _coo_on_post_update )
798-
799-
800- # @numba.njit(nogil=True, fastmath=True, parallel=False)
801- # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
802- # out_w[:] = w
803- # w_min = w_min[()]
804- # w_max = w_max[()]
805- # for i in numba.prange(spike.shape[0]): # post id
806- # if spike[i]:
807- # for k in range(indptr[i], indptr[i + 1]):
808- # j = post_ids[k] # pre id
809- # l = w_ids[k] # syn id
810- # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
811-
812- @ti .kernel
813- def _csc_on_post_update (
814- old_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
815- indices : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
816- indptr : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_post + 1)
817- w_ids : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
818- post_spike : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_post,)
819- pre_trace : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_pre,)
820- w_min : ti .types .ndarray (ndim = 1 ), # scalar
821- w_max : ti .types .ndarray (ndim = 1 ), # scalar
822- out_w : ti .types .ndarray (ndim = 1 ), # vector with shape of (num_syn)
823- ):
824- w_min0 = w_min [0 ]
825- w_max0 = w_max [0 ]
826- num_post = post_spike .shape [0 ]
827- for i_post in range (num_post ):
828- if post_spike [i_post ]:
829- for k in range (indptr [i_post ], indptr [i_post + 1 ]):
830- i_syn = w_ids [k ] # syn id
831- out_w [i_syn ] = min (max (old_w [i_syn ] + pre_trace [indices [k ]], w_min0 ), w_max0 )
832- else :
833- for k in range (indptr [i_post ], indptr [i_post + 1 ]):
834- i_syn = w_ids [k ] # syn id
835- out_w [i_syn ] = old_w [i_syn ]
836-
837-
838- csc_on_post_update_prim = bti .XLACustomOp (cpu_kernel = _csc_on_post_update , gpu_kernel = _csc_on_post_update )
839-
840-
841- else :
842- csr_on_pre_update_prim = None
843- coo_on_pre_update_prim = None
844- csc_on_post_update_prim = None
845-
846-
847- def csr_on_pre_update (w , indices , indptr , spike , trace , w_min = None , w_max = None ):
848- if csr_on_pre_update_prim is None :
849- raise_braintaichi_not_found ()
850-
851- if w_min is None :
852- w_min = - np .inf
853- if w_max is None :
854- w_max = np .inf
855- w_min = jnp .atleast_1d (w_min )
856- w_max = jnp .atleast_1d (w_max )
857-
858- w = bm .as_jax (w )
859- indices = bm .as_jax (indices )
860- indptr = bm .as_jax (indptr )
861- spike = bm .as_jax (spike )
862- trace = bm .as_jax (trace )
863- w_min = bm .as_jax (w_min )
864- w_max = bm .as_jax (w_max )
865- return csr_on_pre_update_prim (w , indices , indptr , spike , trace , w_min , w_max ,
866- outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
867-
868-
869- def coo_on_pre_update (w , pre_ids , post_ids , spike , trace , w_min = None , w_max = None ):
870- if coo_on_pre_update_prim is None :
871- raise_braintaichi_not_found ()
872-
873- if w_min is None :
874- w_min = - np .inf
875- if w_max is None :
876- w_max = np .inf
877- w_min = jnp .atleast_1d (w_min )
878- w_max = jnp .atleast_1d (w_max )
879-
880- w = bm .as_jax (w )
881- pre_ids = bm .as_jax (pre_ids )
882- post_ids = bm .as_jax (post_ids )
883- spike = bm .as_jax (spike )
884- trace = bm .as_jax (trace )
885- w_min = bm .as_jax (w_min )
886- w_max = bm .as_jax (w_max )
887-
888- return coo_on_pre_update_prim (w , pre_ids , post_ids , spike , trace , w_min , w_max ,
889- outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
890-
891-
892- def csc_on_post_update (w , post_ids , indptr , w_ids , post_spike , pre_trace , w_min = None , w_max = None ):
893- if csc_on_post_update_prim is None :
894- raise_braintaichi_not_found ()
895-
896- if w_min is None :
897- w_min = - np .inf
898- if w_max is None :
899- w_max = np .inf
900- w_min = jnp .atleast_1d (w_min )
901- w_max = jnp .atleast_1d (w_max )
902-
903- w = bm .as_jax (w )
904- post_ids = bm .as_jax (post_ids )
905- indptr = bm .as_jax (indptr )
906- w_ids = bm .as_jax (w_ids )
907- post_spike = bm .as_jax (post_spike )
908- pre_trace = bm .as_jax (pre_trace )
909- w_min = bm .as_jax (w_min )
910- w_max = bm .as_jax (w_max )
911- return csc_on_post_update_prim (w , post_ids , indptr , w_ids , post_spike , pre_trace , w_min , w_max ,
912- outs = [jax .ShapeDtypeStruct (w .shape , w .dtype )])[0 ]
913-
914-
915612class CSCLinear (Layer ):
916613 r"""Synaptic matrix multiplication with CSC sparse computation.
917614
0 commit comments