Skip to content

Commit be47ab0

Browse files
author
Alexander Ororbia
committed
updates to art2a, cleanup of probes
1 parent 49ce5b9 commit be47ab0

File tree

6 files changed

+115
-34
lines changed

6 files changed

+115
-34
lines changed

ngclearn/components/synapses/competitive/ART2ASynapse.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax import random, numpy as jnp, jit
1+
from jax import random, numpy as jnp, jit, nn
22
from functools import partial
33
from ngclearn import compilable
44
from ngclearn import Compartment
@@ -69,7 +69,7 @@ def __init__(
6969
name,
7070
shape, ## determines memory matrix size
7171
eta=0.05, ## learning rate
72-
eta_decrement=0.,
72+
eta_decrement=0., ## linear scheduled decrement over eta
7373
vigilance=0.3, ## vigilance parameter (rho)
7474
weight_init=None,
7575
resist_scale=1.,
@@ -96,34 +96,64 @@ def __init__(
9696
self.i_tick = Compartment(jnp.zeros((1, 1)))
9797
#self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
9898
self.dWeights = Compartment(self.weights.get() * 0)
99-
self.misses = Compartment(jnp.zeros((batch_size, 1)))
99+
self.misses = Compartment(jnp.zeros((batch_size, 1))) ## marker for non-resonant patterns in a batch
100+
101+
self.weights.set(self.weights.get() * 0)
102+
self.used = Compartment(jnp.zeros((1, shape[1]))) ## marks if memory slot used
103+
104+
def insert(self, x, idx): ## manual memory insertion co-routine
105+
W = self.weights.get()
106+
z_m = jnp.expand_dims(nn.one_hot(idx, W.shape[1]), axis=0)
107+
dW = (W * 0 + x.T) * z_m
108+
W = W + dW
109+
self.weights.set(W)
110+
self.used.set(((self.used.get() + z_m) > 0.) * 1.)
111+
112+
def grow(self, n_memories): ## grow out memory matrix by fixed amount
113+
W = self.weights.get()
114+
used = self.used.get()
115+
## expand memory matrix by a fixed set of empty memory slots
116+
W = jnp.concat([W, jnp.zeros((W.shape[0], n_memories))], axis=1)
117+
n_unused = jnp.zeros((1, n_memories))
118+
used = jnp.concat([used, n_unused], axis=1)
119+
#print("used: ", used.shape)
120+
self.used.set(used)
121+
self.weights.set(W)
122+
self.dWeights.set(W * 0)
123+
self.shape = self.weights.get().shape
100124

101-
#@compilable
102-
def consolidate(self, wipe_mem=False): ## memory storage/consolidation routine
103-
## note that this co-routine needs to be non-compilable/non-jit-i-fied, as
104-
## its main purpose is to structurally alter the memory matrix W (dynamically)
105-
x_in = self.inputs.get()
106-
#xprobe = self.xprobe.get()
107-
xprobe = _normalize(x_in, norm_fx=self.norm_fx)
108-
if wipe_mem is False:
109-
W = self.weights.get() ## get current memory matrix
110-
miss_mask = self.misses.get()
111-
if jnp.sum(miss_mask) > 0: ## for non-resonant patterns
112-
r, c = jnp.nonzero(miss_mask)
113-
mem = xprobe[r, :]
114-
W = jnp.concat([W, mem.T], axis=1)
115-
self.weights.set(W)
116-
else:
117-
self.weights.set(xprobe.T)
125+
@compilable
126+
def consolidate(self): ## memory consolidation co-routine (for non-resonant signals)
127+
n_used = int(jnp.sum(self.used.get())) ## number of used memory slots
128+
x = self.xprobe.get()
129+
W = self.weights.get()
130+
nonresonants = self.misses.get()
131+
## we project non-resonant memories to empty slots in memory W
132+
S = jnp.eye(x.shape[0], self.shape[1], k=n_used)
133+
dWstore = jnp.matmul((x * nonresonants).T, S)
134+
W = W + dWstore ## Hebbian update to memory
135+
## re-compute number of used slots post-consolidation
136+
nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
137+
used = (nW > 0.) * 1
138+
139+
self.weights.set(W)
140+
self.used.set(used)
141+
## else, currently discard un-absorbed/non-resonant patterns
142+
## can add a function that "grows" out block matrix by a chunk (to control growth)
143+
## TODO: add pruning mechanism for low-usage slots
118144

119145
@compilable
120146
def advance_state(self): ## forward-inference step of ART2A
121147
x_in = self.inputs.get()
122148
W = self.weights.get() ## get (transposed) memory matrix
149+
used = self.used.get()
123150

124151
x = _normalize(x_in, norm_fx=self.norm_fx)
125152
self.xprobe.set(x)
126153
sims = jnp.matmul(x, W) ## compute similarities (parallel dot products)
154+
## we correct activities by masking out unused slots
155+
sims_min = jnp.amin(sims, axis=1, keepdims=True)
156+
sims = sims * used + (1. - used) * (sims_min - 1.)
127157
z_winners = sims * bkwta(sims, nWTA=self.K) ## get winner mask (hidden layer)
128158
self.outputs.set(z_winners)
129159

@@ -142,15 +172,17 @@ def evolve(self, t, dt): ## competitive Hebbian update step of ART2A
142172
wnew = (-jnp.matmul(z_winners, W.T) + x) * m ## B x D
143173
dW = jnp.matmul(wnew.T, hits) ## D x Z ## adjustment matrix
144174
W = W + dW * eta ## D x Z ## do a step of Hebbian ascent
175+
nonresonants = 1. - m ## mark non-resonant patterns in batch
145176

146177
## NOTE: is this post-weight-update normalization needed?
147-
nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
148-
mz = (jnp.sum(hits, axis=0, keepdims=True) > 0.) * 1.
149-
W = W / (nW * mz + (1. - mz))
150-
151-
self.weights.set(W) ## memory matrix advances forward to new state
152-
self.misses.set(1. - m) ## store unused/non-resonant pattern mask
153-
178+
#nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
179+
#used = (nW > 0.) * 1
180+
#mz = (jnp.sum(hits, axis=0, keepdims=True) > 0.) * 1.
181+
#W = W / (nW * mz + (1. - mz))
182+
183+
self.weights.set(W)
184+
self.misses.set(nonresonants) ## store unused/non-resonant pattern mask
185+
154186
#tmp_key, *subkeys = random.split(self.key.get(), 3)
155187
#self.key.set(tmp_key)
156188
## synaptic update noise
@@ -171,8 +203,8 @@ def reset(self):
171203
self.inputs.set(preVals)
172204
self.outputs.set(postVals)
173205
self.xprobe.set(preVals)
174-
self.misses.set(jnp.zeros((self.batch_size.get(), 1)))
175-
206+
#self.misses.set(jnp.zeros((self.batch_size.get(), 1)))
207+
self.misses.set(self.misses.get() * 0)
176208
self.dWeights.set(jnp.zeros(self.shape.get()))
177209

178210
@classmethod
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
from .SOMSynapse import SOMSynapse
22
from .hopfieldSynapse import HopfieldSynapse
3+
from .vectorQuantizeSynapse import VectorQuantizeSynapse ## LVQ
4+
from .ART2ASynapse import ART2ASynapse ## ART for continuous inputs
5+
## NOTE: add in ART1Synapse for processing binary/pulse values
6+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
## point to supported analysis probes
22
from .linear_probe import LinearProbe
33
from .attentive_probe import AttentiveProbe
4+
from .knn_probe import KNNProbe
5+

ngclearn/utils/analysis/attentive_probe.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,25 @@ class AttentiveProbe(Probe):
218218
219219
"""
220220
def __init__(
221-
self, dkey, source_seq_length, input_dim, out_dim, num_heads=8, attn_dim=64,
222-
target_seq_length=1, learnable_query_dim=32, batch_size=1, hid_dim=32,
223-
use_LN=True, use_LN_input=False, use_softmax=True, dropout=0.5, eta=0.0002,
224-
eta_decay=0.0, min_eta=1e-5, **kwargs
221+
self,
222+
dkey,
223+
source_seq_length,
224+
input_dim,
225+
out_dim,
226+
num_heads=8,
227+
attn_dim=64,
228+
target_seq_length=1,
229+
learnable_query_dim=32,
230+
batch_size=1,
231+
hid_dim=32,
232+
use_LN=True,
233+
use_LN_input=False,
234+
use_softmax=True,
235+
dropout=0.5,
236+
eta=0.0002,
237+
eta_decay=0.0,
238+
min_eta=1e-5,
239+
**kwargs
225240
):
226241
super().__init__(dkey, batch_size, **kwargs)
227242
assert attn_dim % num_heads == 0, f"`attn_dim` must be divisible by `num_heads`. Got {attn_dim} and {num_heads}."

ngclearn/utils/analysis/linear_probe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ class LinearProbe(Probe):
7373
7474
"""
7575
def __init__(
76-
self, dkey, source_seq_length, input_dim, out_dim, batch_size=1, use_LN=False, use_softmax=False, **kwargs
76+
self,
77+
dkey,
78+
source_seq_length,
79+
input_dim,
80+
out_dim,
81+
batch_size=1,
82+
use_LN=False,
83+
use_softmax=False,
84+
**kwargs
7785
):
7886
super().__init__(dkey, batch_size, **kwargs)
7987
self.dkey, *subkeys = random.split(self.dkey, 3)

ngclearn/utils/model_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,26 @@ def one_hot(P):
227227
p_t = jnp.argmax(P, axis=1)
228228
return nn.one_hot(p_t, num_classes=nC, dtype=jnp.float32)
229229

230+
@partial(jit, static_argnums=[1, 2])
231+
def chebyshev_norm(d, axis=-1, keepdims=False):
232+
"""
233+
Calculate the Chebyshev norm of a tensor-array, i.e., its Chebyshev distance from the origin.
234+
235+
Args:
236+
d: tensor d to measure against the origin
237+
238+
axis: axis to measure distance between the two tensors
239+
240+
keepdims: preserve dimensions of d
241+
242+
Returns:
243+
the Chebyshev distance (values) within d
244+
"""
245+
abs_diff = jnp.abs(d) ## d could be (a - b) externally
246+
dist_vals = jnp.max(abs_diff, axis=axis, keepdims=keepdims)
247+
return dist_vals
248+
249+
@jit
230250
def binarize(data, threshold=0.5):
231251
"""
232252
Converts the vector *data* to its binary equivalent

0 commit comments

Comments
 (0)