Skip to content

Commit 73e9cde

Browse files
author
Alexander Ororbia
committed
cleaned up stdp syn
1 parent 0e254a4 commit 73e9cde

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

ngclearn/components/synapses/hebbian/STDPSynapse.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ngclearn.components.synapses import DenseSynapse
44
from ngclearn.utils import tensorstats
55

6-
class STDPSynapse(DenseSynapse): # classical STDP
6+
class STDPSynapse(DenseSynapse): # power-law / trace-based STDP
77
"""
88
A synaptic cable that adjusts its efficacies via raw
99
spike-timing-dependent plasticity (STDP).
@@ -62,7 +62,7 @@ class STDPSynapse(DenseSynapse): # classical STDP
6262
"""
6363

6464
# Define Functions
65-
def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10.,
65+
def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10., w_decay=0.,
6666
eta=1., tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1.,
6767
batch_size=1, **kwargs):
6868
super().__init__(name, shape, weight_init, None, resist_scale,
@@ -77,6 +77,7 @@ def __init__(self, name, shape, A_plus, A_minus, tau_plus=10., tau_minus=10.,
7777
self.Rscale = resist_scale ## post-transformation scale factor
7878
self.w_bound = w_bound #1. ## soft weight constraint
7979
self.tau_w = tau_w ## synaptic update time constant
80+
self.w_decay = w_decay
8081

8182
## Compartment setup
8283
preVals = jnp.zeros((self.batch_size, shape[0]))
@@ -112,7 +113,7 @@ def _compute_update(Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike,
112113
return dW
113114

114115
@staticmethod
115-
def _evolve(dt, w_bound, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
116+
def _evolve(dt, w_bound, w_decay, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
116117
postSpike, pre_tols, post_tols, weights, eta):
117118
dWeights = STDPSynapse._compute_update(
118119
Aplus, Aminus, tau_plus, tau_minus, preSpike, postSpike, pre_tols,
@@ -122,9 +123,9 @@ def _evolve(dt, w_bound, tau_w, Aplus, Aminus, tau_plus, tau_minus, preSpike,
122123
if tau_w > 0.: ## triggers Euler-style synaptic update
123124
weights = weights + (-weights * dt/tau_w + dWeights * eta)
124125
else: ## raw simple ascent-style update
125-
weights = weights + dWeights * eta
126+
weights = weights + dWeights * eta - weights * w_decay
126127
## enforce non-negativity
127-
eps = 0.01
128+
eps = 0.001 # 0.01
128129
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
129130
return weights, dWeights
130131

0 commit comments

Comments
 (0)