1919import jax
2020import jax .numpy as jnp
2121import numpy as np
22- from brainevent import csr_on_pre , csr2csc_on_post
23- from brainevent import dense_on_pre , dense_on_post
22+ from brainevent import (
23+ update_csr_on_binary_pre ,
24+ update_csr_on_binary_post ,
25+ update_dense_on_binary_pre ,
26+ update_dense_on_binary_post ,
27+ )
2428
2529from brainpy import connect , initialize as init
2630from brainpy import math as bm
@@ -226,11 +230,11 @@ def stdp_update(
226230 if on_pre is not None :
227231 spike = on_pre ['spike' ]
228232 trace = on_pre ['trace' ]
229- self .W .value = dense_on_pre (self .W .value , spike , trace , w_min , w_max )
233+ self .W .value = update_dense_on_binary_pre (self .W .value , spike , trace , w_min , w_max )
230234 if on_post is not None :
231235 spike = on_post ['spike' ]
232236 trace = on_post ['trace' ]
233- self .W .value = dense_on_post (self .W .value , trace , spike , w_min , w_max )
237+ self .W .value = update_dense_on_binary_post (self .W .value , trace , spike , w_min , w_max )
234238
235239
236240Linear = Dense
@@ -321,11 +325,11 @@ def stdp_update(
321325 if on_pre is not None :
322326 spike = on_pre ['spike' ]
323327 trace = on_pre ['trace' ]
324- self .weight .value = dense_on_pre (self .weight .value , spike , trace , w_min , w_max )
328+ self .weight .value = update_dense_on_binary_pre (self .weight .value , spike , trace , w_min , w_max )
325329 if on_post is not None :
326330 spike = on_post ['spike' ]
327331 trace = on_post ['trace' ]
328- self .weight .value = dense_on_post (self .weight .value , trace , spike , w_min , w_max )
332+ self .weight .value = update_dense_on_binary_post (self .weight .value , trace , spike , w_min , w_max )
329333
330334
331335class OneToOne (Layer , SupportSTDP ):
@@ -449,11 +453,11 @@ def stdp_update(
449453 if on_pre is not None :
450454 spike = on_pre ['spike' ]
451455 trace = on_pre ['trace' ]
452- self .weight .value = dense_on_pre (self .weight .value , spike , trace , w_min , w_max )
456+ self .weight .value = update_dense_on_binary_pre (self .weight .value , spike , trace , w_min , w_max )
453457 if on_post is not None :
454458 spike = on_post ['spike' ]
455459 trace = on_post ['trace' ]
456- self .weight .value = dense_on_post (self .weight .value , trace , spike , w_min , w_max )
460+ self .weight .value = update_dense_on_binary_post (self .weight .value , trace , spike , w_min , w_max )
457461
458462
459463class _CSRLayer (Layer , SupportSTDP ):
@@ -500,7 +504,7 @@ def stdp_update(
500504 if on_pre is not None : # update on presynaptic spike
501505 spike = on_pre ['spike' ]
502506 trace = on_pre ['trace' ]
503- self .weight .value = csr_on_pre (
507+ self .weight .value = update_csr_on_binary_pre (
504508 self .weight .value , self .indices , self .indptr , spike , trace , w_min , w_max ,
505509 shape = (spike .shape [0 ], trace .shape [0 ]),
506510 )
@@ -512,7 +516,7 @@ def stdp_update(
512516 )
513517 spike = on_post ['spike' ]
514518 trace = on_post ['trace' ]
515- self .weight .value = csr2csc_on_post (
519+ self .weight .value = update_csr_on_binary_post (
516520 self .weight .value , self ._pre_ids , self ._post_indptr ,
517521 self .w_indices , trace , spike , w_min , w_max ,
518522 shape = (trace .shape [0 ], spike .shape [0 ]),
0 commit comments