33from ngclearn .components .synapses import DenseSynapse
44from ngclearn .utils import tensorstats
55
6- class STDPSynapse (DenseSynapse ): # power-law / trace-based STDP
6+ class STDPSynapse (DenseSynapse ): # classical STDP
77 """
88 A synaptic cable that adjusts its efficacies via raw
99 spike-timing-dependent plasticity (STDP).
@@ -92,17 +92,20 @@ def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10.,
9292 def _compute_update (Aplus , Aminus , tau_plus , tau_minus , preSpike , postSpike ,
9393 pre_tols , post_tols , weights ):
9494 ## calculate time deltas matrix block --> (t_post - t_pre)
95- post_m = (post_tols > 0. ) #* 1.
96- pre_m = (pre_tols > 0. ).T # * 1.
95+ post_m = (post_tols > 0. ) ## zero post-tols mask
96+ pre_m = (pre_tols > 0. ).T ## zero pre-tols mask
9797 t_delta = ((weights * 0 + 1. ) * post_tols ) - pre_tols .T ## t_delta.shape = weights.shape
98- t_delta = t_delta * post_m * pre_m ## mask out zero tols
98+ t_delta = t_delta * post_m * pre_m ## mask out zero tols and same-time spikes
99+ pos_t_delta_m = (t_delta > 0. ) ## positive t-delta mask
100+ neg_t_delta_m = (t_delta < 0. ) ## negative t-delta mask
101+ #t_delta = t_delta * pos_t_delta_m + t_delta * neg_t_delta_m ## mask out same time spikes
99102 ## calculate post-synaptic term
100- postTerm = jnp .exp (- t_delta / tau_plus ) * ( t_delta > 0. ) #* post_m * pre_m
103+ postTerm = jnp .exp (- t_delta / tau_plus ) * pos_t_delta_m
101104 dWpost = postTerm * (postSpike * Aplus )
102105 dWpre = 0.
103106 if Aminus > 0. :
104107 ## calculate pre-synaptic term
105- preTerm = jnp .exp (- t_delta / tau_minus ) * ( t_delta < 0. ) #* post_m * pre_m
108+ preTerm = jnp .exp (- t_delta / tau_minus ) * neg_t_delta_m
106109 dWpre = - preTerm * (preSpike .T * Aminus )
107110 ## calc final weighted adjustment
108111 dW = (dWpost + dWpre )
@@ -121,7 +124,7 @@ def _evolve(dt, w_bound, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
121124 else : ## raw simple ascent-style update
122125 weights = weights + dWeights * eta
123126 ## enforce non-negativity
124- eps = 0.001 # 0. 01
127+ eps = 0.01
125128 weights = jnp .clip (weights , eps , w_bound - eps ) # jnp.abs(w_bound))
126129 return weights , dWeights
127130
0 commit comments