Skip to content

Commit 8fbe8e2

Browse files
author
Alexander Ororbia
committed
updates to gerstner syn, data-loader
1 parent 1a0d96f commit 8fbe8e2

2 files changed

Lines changed: 42 additions & 37 deletions

File tree

ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ class GerstnerHebbianSynapse(DenseSynapse):
2323
| c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update
2424
| c2_corr = 1 and c1_pre = -x_theta < 0
2525
26+
| References:
27+
| Gerstner, W. and Kistler, W.M., 2002. Mathematical formulations of Hebbian
28+
| learning. Biological cybernetics, 87(5), pp.404-415.
29+
2630
"""
2731
def __init__(
2832
self,
@@ -37,7 +41,7 @@ def __init__(
3741
batch_size=1,
3842
**kwargs
3943
):
40-
bias_init = None ## no biases are included in Gerster's formulation
44+
bias_init = None ## NOTE: no biases are included in Gerster's formulation
4145
super().__init__(
4246
name,
4347
shape=shape,
@@ -48,11 +52,11 @@ def __init__(
4852
batch_size=batch_size,
4953
**kwargs
5054
)
51-
## General Hebbian meta-parameters
55+
## general Hebbian meta-parameters
5256
self.eta = eta
5357
self.sign_value = sign_value
5458

55-
## Expansion coefficients (c0, c1_pre, c1_post, c2_corr)
59+
## Gerstner and Kisler's expansion coefficients (c0, c1_pre, c1_post, c2_corr)
5660
if coeffs is None: ## Default to standard bilinear Hebb
5761
self.coeffs = {
5862
'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0
@@ -64,50 +68,43 @@ def __init__(
6468
self.c1_post = self.coeffs['c1_post']
6569
self.c2_corr = self.coeffs['c2_corr']
6670

67-
# Initialize Weights (using JAX PRNG)
68-
#init_key, _ = random.split(self.key)
69-
#w_init = random.normal(init_key, shape) * 0.05
70-
71-
# Compartments (ngc-learn state management)
72-
#self.weights = Compartment(w_init)
71+
## set up relevant compartments
7372
self.pre = Compartment(jnp.zeros((1, shape[1])))
7473
self.post = Compartment(jnp.zeros((1, shape[0])))
74+
self.dWeights = Compartment(jnp.zeros(shape))
7575

7676
@compilable
77-
def evolve(self, **kwargs):
78-
"""
79-
Updates weights using the Gerstner general expansion.
80-
Assumes pre_act and post_act compartments have been populated.
81-
"""
82-
# Retrieve current states
77+
def evolve(self, **kwargs): ## perform update via Gerstner's general expansion
78+
## retrieve current compartment state values
8379
W = self.weights.get()
84-
x = self.pre.get() # pre-synaptic activity (batch, pre_dim)
85-
y = self.post.get() # post-synaptic activity (batch, post_dim)
80+
x = self.pre.get() ## pre-synaptic activity (batch, pre_dim)
81+
y = self.post.get() ## post-synaptic activity (batch, post_dim)
8682
batch_size = self.batch_size
8783

88-
## Bilinear Term (c2): correlation matrix
89-
### (post_dim, batch) @ (batch, pre_dim) -> (post_dim, pre_dim)
84+
## calculate bilinear Term (c2), i.e., correlation matrix
85+
### (pre_dim, batch) @ (batch, post_dim) -> (pre_dim, post_dim)
9086
dW_corr = jnp.matmul(x.T, y) * (1./batch_size)
91-
## Linear pre-synaptic term (c1_pre)
92-
### Average over batch then broadcast to match weight matrix
87+
## linear pre-synaptic term (c1_pre)
88+
### get mean over batch then broadcast to match weight matrix
9389
dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size)
94-
## Linear post-synaptic term (c1_post)
95-
dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size)
90+
## linear post-synaptic term (c1_post), mean over post-syn values
91+
dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size)
9692

97-
## Apply Equation 3 Taylor expansion
93+
## apply Taylor expansion from Equation 3 (Gerstner and Kistler)
9894
dW = (self.c0 * W + ## synaptic decay
9995
self.c1_pre * dW_pre + ## bilinear term
10096
self.c1_post * dW_post + ## pre-synaptic gating term
10197
self.c2_corr * dW_corr ## post-synpatic gating term
10298
)
99+
self.dWeights.set(dW)
100+
103101
## perform a step of Hebbian ascent
104-
W = W + self.eta * dW
105-
## Update weights
102+
W = W + self.eta * dW ## update synaptic efficacies
106103
self.weights.set(W)
107104

108105
@compilable
109-
def reset(self, **kwargs):
110-
"""Clears activity compartments"""
106+
def reset(self, **kwargs): ## clear compartment values
111107
self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) )
112108
self.post.set( jnp.zeros((self.batch_size, self.shape[0])) )
109+
self.dWeights.set(self.dWeights.get() * 0)
113110

ngclearn/utils/data_loader.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,14 @@ class DataLoader(object):
2525
key: PRNG key to control determinism of any underlying random values
2626
associated with this synaptic cable
2727
"""
28-
def __init__(self, design_matrices, batch_size, disable_shuffle=False,
29-
ensure_equal_batches=True, key=None):
28+
def __init__(
29+
self,
30+
design_matrices,
31+
batch_size,
32+
disable_shuffle=False,
33+
ensure_equal_batches=True,
34+
key=None
35+
):
3036
self.key = key
3137
if self.key is None:
3238
self.key = random.PRNGKey(time.time_ns())
@@ -47,23 +53,25 @@ def __init__(self, design_matrices, batch_size, disable_shuffle=False,
4753

4854
def __iter__(self):
4955
"""
50-
Yields a mini-batch of the form: [("name", batch),("name",batch),...]
56+
Yields a mini-batch of the form:
57+
58+
| batch = [("name", batchx), ("name", batchy),...("name", batchz)]
5159
"""
52-
if self.disable_shuffle == False:
60+
if not self.disable_shuffle: #self.disable_shuffle == False:
5361
self.key, *subkeys = random.split(self.key, 2)
5462
self.ptrs = random.permutation(subkeys[0], self.data_len)
5563
idx = 0
56-
while idx < len(self.ptrs): # go through each sample via the sampling pointer
64+
while idx < len(self.ptrs): ## go through each sample via the sampling pointer
5765
e_idx = idx + self.batch_size
58-
if e_idx > len(self.ptrs): # prevents reaching beyond length of dataset
66+
if e_idx > len(self.ptrs): ## prevents reaching beyond length of dataset
5967
e_idx = len(self.ptrs)
60-
# extract sampling integer pointers
68+
## extract sampling integer pointers
6169
indices = self.ptrs[idx:e_idx]
62-
if self.ensure_equal_batches == True:
70+
if self.ensure_equal_batches: # == True:
6371
if indices.shape[0] < self.batch_size:
6472
diff = self.batch_size - indices.shape[0]
6573
indices = jnp.concatenate((indices, self.ptrs[0:diff]))
66-
# create the actual pattern vector batch block matrices
74+
## create the actual pattern vector batch block matrices
6775
data_batch = []
6876
for dname, dmatrix in self.design_matrices:
6977
x_batch = dmatrix[indices]

0 commit comments

Comments
 (0)