33from ngclearn .components .synapses import DenseSynapse
44from ngclearn .utils import tensorstats
55
6- class STDPSynapse (DenseSynapse ): # classical STDP
6+ class STDPSynapse (DenseSynapse ): # power-law / trace-based STDP
77 """
88 A synaptic cable that adjusts its efficacies via raw
99 spike-timing-dependent plasticity (STDP).
@@ -62,7 +62,7 @@ class STDPSynapse(DenseSynapse): # classical STDP
6262 """
6363
6464 # Define Functions
65- def __init__ (self , name , shape , A_plus , A_minus , tau_plus = 10. , tau_minus = 10. ,
65+ def __init__ (self , name , shape , A_plus , A_minus , tau_plus = 10. , tau_minus = 10. , w_decay = 0. ,
6666 eta = 1. , tau_w = 0. , weight_init = None , resist_scale = 1. , p_conn = 1. , w_bound = 1. ,
6767 batch_size = 1 , ** kwargs ):
6868 super ().__init__ (name , shape , weight_init , None , resist_scale ,
@@ -77,6 +77,7 @@ def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10.,
7777 self .Rscale = resist_scale ## post-transformation scale factor
7878 self .w_bound = w_bound #1. ## soft weight constraint
7979 self .tau_w = tau_w ## synaptic update time constant
80+ self .w_decay = w_decay
8081
8182 ## Compartment setup
8283 preVals = jnp .zeros ((self .batch_size , shape [0 ]))
@@ -112,7 +113,7 @@ def _compute_update(Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike,
112113 return dW
113114
114115 @staticmethod
115- def _evolve (dt , w_bound , tau_w , Aplus , Aminus , tau_plus , tau_minus , preSpike ,
116+ def _evolve (dt , w_bound , w_decay , tau_w , Aplus , Aminus , tau_plus , tau_minus , preSpike ,
116117 postSpike , pre_tols , post_tols , weights , eta ):
117118 dWeights = STDPSynapse ._compute_update (
118119 Aplus , Aminus , tau_plus , tau_minus , preSpike , postSpike , pre_tols ,
@@ -122,9 +123,9 @@ def _evolve(dt, w_bound, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
122123 if tau_w > 0. : ## triggers Euler-style synaptic update
123124 weights = weights + (- weights * dt / tau_w + dWeights * eta )
124125 else : ## raw simple ascent-style update
125- weights = weights + dWeights * eta
126+ weights = weights + dWeights * eta - weights * w_decay
126127 ## enforce non-negativity
127- eps = 0.01
128+ eps = 0.001 # 0. 01
128129 weights = jnp .clip (weights , eps , w_bound - eps ) # jnp.abs(w_bound))
129130 return weights , dWeights
130131
0 commit comments