Skip to content

Commit 0e254a4

Browse files
author
Alexander Ororbia
committed
cleaned up stdp syn
1 parent db15423 commit 0e254a4

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

ngclearn/components/synapses/hebbian/STDPSynapse.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ngclearn.components.synapses import DenseSynapse
44
from ngclearn.utils import tensorstats
55

6-
class STDPSynapse(DenseSynapse): # power-law / trace-based STDP
6+
class STDPSynapse(DenseSynapse): # classical STDP
77
"""
88
A synaptic cable that adjusts its efficacies via raw
99
spike-timing-dependent plasticity (STDP).
@@ -92,17 +92,20 @@ def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10.,
9292
def _compute_update(Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike,
9393
pre_tols, post_tols, weights):
9494
## calculate time deltas matrix block --> (t_post - t_pre)
95-
post_m = (post_tols > 0.) #* 1.
96-
pre_m = (pre_tols > 0.).T # * 1.
95+
post_m = (post_tols > 0.) ## zero post-tols mask
96+
pre_m = (pre_tols > 0.).T ## zero pre-tols mask
9797
t_delta = ((weights * 0 + 1.) * post_tols) - pre_tols.T ## t_delta.shape = weights.shape
98-
t_delta = t_delta * post_m * pre_m ## mask out zero tols
98+
t_delta = t_delta * post_m * pre_m ## mask out zero tols and same-time spikes
99+
pos_t_delta_m = (t_delta > 0.) ## positive t-delta mask
100+
neg_t_delta_m = (t_delta < 0.) ## negative t-delta mask
101+
#t_delta = t_delta * pos_t_delta_m + t_delta * neg_t_delta_m ## mask out same time spikes
99102
## calculate post-synaptic term
100-
postTerm = jnp.exp(-t_delta/tau_plus) * (t_delta > 0.) #* post_m * pre_m
103+
postTerm = jnp.exp(-t_delta/tau_plus) * pos_t_delta_m
101104
dWpost = postTerm * (postSpike * Aplus)
102105
dWpre = 0.
103106
if Aminus > 0.:
104107
## calculate pre-synaptic term
105-
preTerm = jnp.exp(-t_delta / tau_minus) * (t_delta < 0.) #* post_m * pre_m
108+
preTerm = jnp.exp(-t_delta / tau_minus) * neg_t_delta_m
106109
dWpre = -preTerm * (preSpike.T * Aminus)
107110
## calc final weighted adjustment
108111
dW = (dWpost + dWpre)
@@ -121,7 +124,7 @@ def _evolve(dt, w_bound, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
121124
else: ## raw simple ascent-style update
122125
weights = weights + dWeights * eta
123126
## enforce non-negativity
124-
eps = 0.001 # 0.01
127+
eps = 0.01
125128
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
126129
return weights, dWeights
127130

0 commit comments

Comments
 (0)