Skip to content

Commit 49ce5b9

Browse files
author
Alexander Ororbia
committed
wrote/integrated an ART2A synapse model, batch-generalized
1 parent 1362b11 commit 49ce5b9

File tree

1 file changed

+214
-0
lines changed

1 file changed

+214
-0
lines changed
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
from jax import random, numpy as jnp, jit
2+
from functools import partial
3+
from ngclearn import compilable
4+
from ngclearn import Compartment
5+
from ngclearn.utils.model_utils import softmax, bkwta
6+
7+
from ngclearn.components.synapses.denseSynapse import DenseSynapse
8+
9+
@partial(jit, static_argnums=[1])
10+
def _normalize(x_in, norm_fx=0):
11+
if norm_fx == 1:
12+
xmin = jnp.min(x, axis=1, keepdims=True)
13+
xmax = jnp.max(x, axis=1, keepdims=True)
14+
x = (x_in - xmin)/(xmax - xmin)
15+
else:
16+
x = x_in / jnp.linalg.norm(x_in, ord=2, axis=1, keepdims=True)
17+
return x
18+
19+
class ART2ASynapse(DenseSynapse): # Adaptive resonance theory (ART) 2A synaptic cable
20+
"""
21+
A synaptic cable that emulates a simplified form of adaptive resonance theory (ART)
22+
adapted for continuous input signals (specifically, the ART2A-C model that handles
23+
real-valued input values).
24+
25+
| --- Synapse Compartments: ---
26+
| inputs - input (takes in external signals)
27+
| outputs - output signals (transformation induced by synapses)
28+
| weights - current value matrix of synaptic efficacies
29+
| i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`)
30+
| eta - current learning rate value
31+
| key - JAX PRNG key
32+
| --- Synaptic Plasticity Compartments: ---
33+
| inputs - pre-synaptic signal/value to drive 1st term of ART2A update (x)
34+
| outputs - post-synaptic signal/value to drive 2nd term of ART2A update (y)
35+
| dWeights - current delta matrix containing changes to be applied to synapses
36+
37+
| References:
38+
| Carpenter, Gail A., and Stephen Grossberg. "ART 2: Self-organization of stable category
39+
| recognition codes for analog input patterns." Applied optics 26.23 (1987): 4919-4930.
40+
|
41+
| Ororbia, Alexander G. "Continual competitive memory: A neural system for online task-free
42+
| lifelong learning." arXiv preprint arXiv:2106.13300 (2021).
43+
44+
Args:
45+
name: the string name of this cell
46+
47+
shape: tuple specifying shape of this synaptic cable (usually a 2-tuple with number of
48+
inputs by number of outputs)
49+
50+
eta: (initial) learning rate / step-size for this ART2A model (initial condition value for `eta`)
51+
52+
eta_decrement: constant value to decrease `eta` by each call to this synapse's `evolve()`, i.e.,
53+
this triggers a linear schedule for decreasing `eta` by (Default: 0)
54+
55+
vigilance: vigilance parameter to decide if a memory vector is updated (rho)
56+
57+
weight_init: a kernel to drive initialization of this synaptic cable's values;
58+
typically a tuple with 1st element as a string calling the name of
59+
initialization to use
60+
61+
resist_scale: a fixed scaling factor to apply to synaptic transform (Default: 1.)
62+
63+
p_conn: probability of a connection existing (default: 1.); setting
64+
this to < 1. will result in a sparser synaptic structure
65+
"""
66+
67+
def __init__(
68+
self,
69+
name,
70+
shape, ## determines memory matrix size
71+
eta=0.05, ## learning rate
72+
eta_decrement=0.,
73+
vigilance=0.3, ## vigilance parameter (rho)
74+
weight_init=None,
75+
resist_scale=1.,
76+
p_conn=1.,
77+
batch_size=1,
78+
**kwargs
79+
):
80+
super().__init__(
81+
name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
82+
)
83+
84+
### Synapse and ART-2A hyper-parameters
85+
self.K = 1 ## number of winners for bmu calculation
86+
self.norm_fx = 0 ## 0 -> normalize via norm, 1 -> complement coding (min-max rescale)
87+
88+
self.shape = shape ## shape of synaptic efficacy matrix
89+
self.initial_eta = eta
90+
self.eta_decr = eta_decrement ## linear decrease to iteratively update eta by (each "tick")
91+
self.vigilance = vigilance ## (rho)
92+
93+
## ART-2A Compartment setup
94+
self.xprobe = Compartment(jnp.zeros((batch_size, shape[0])))
95+
self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta, display_name="Dynamic step size")
96+
self.i_tick = Compartment(jnp.zeros((1, 1)))
97+
#self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
98+
self.dWeights = Compartment(self.weights.get() * 0)
99+
self.misses = Compartment(jnp.zeros((batch_size, 1)))
100+
101+
#@compilable
102+
def consolidate(self, wipe_mem=False): ## memory storage/consolidation routine
103+
## note that this co-routine needs to be non-compilable/non-jit-i-fied, as
104+
## its main purpose is to structurally alter the memory matrix W (dynamically)
105+
x_in = self.inputs.get()
106+
#xprobe = self.xprobe.get()
107+
xprobe = _normalize(x_in, norm_fx=self.norm_fx)
108+
if wipe_mem is False:
109+
W = self.weights.get() ## get current memory matrix
110+
miss_mask = self.misses.get()
111+
if jnp.sum(miss_mask) > 0: ## for non-resonant patterns
112+
r, c = jnp.nonzero(miss_mask)
113+
mem = xprobe[r, :]
114+
W = jnp.concat([W, mem.T], axis=1)
115+
self.weights.set(W)
116+
else:
117+
self.weights.set(xprobe.T)
118+
119+
@compilable
120+
def advance_state(self): ## forward-inference step of ART2A
121+
x_in = self.inputs.get()
122+
W = self.weights.get() ## get (transposed) memory matrix
123+
124+
x = _normalize(x_in, norm_fx=self.norm_fx)
125+
self.xprobe.set(x)
126+
sims = jnp.matmul(x, W) ## compute similarities (parallel dot products)
127+
z_winners = sims * bkwta(sims, nWTA=self.K) ## get winner mask (hidden layer)
128+
self.outputs.set(z_winners)
129+
130+
@compilable
131+
def evolve(self, t, dt): ## competitive Hebbian update step of ART2A
132+
W = self.weights.get() ## D x Z
133+
x = self.xprobe.get() ## B x D
134+
z_winners = self.outputs.get() ## B x Z
135+
eta = self.eta.get()
136+
## Note: we refactor ART update into a leaky integrator equation:
137+
## W = W * (1 - b) + dW * b = W + b * (-W + dW); b = eta
138+
139+
## for resonant patterns, we perform a Hebbian storage update
140+
hits = (z_winners >= self.vigilance) * 1. ## B x Z
141+
m = (jnp.sum(hits, axis=1, keepdims=True) > 0.) * 1. ## B x 1
142+
wnew = (-jnp.matmul(z_winners, W.T) + x) * m ## B x D
143+
dW = jnp.matmul(wnew.T, hits) ## D x Z ## adjustment matrix
144+
W = W + dW * eta ## D x Z ## do a step of Hebbian ascent
145+
146+
## NOTE: is this post-weight-update normalization needed?
147+
nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
148+
mz = (jnp.sum(hits, axis=0, keepdims=True) > 0.) * 1.
149+
W = W / (nW * mz + (1. - mz))
150+
151+
self.weights.set(W) ## memory matrix advances forward to new state
152+
self.misses.set(1. - m) ## store unused/non-resonant pattern mask
153+
154+
#tmp_key, *subkeys = random.split(self.key.get(), 3)
155+
#self.key.set(tmp_key)
156+
## synaptic update noise
157+
#eps = random.normal(subkeys[0], W.shape) ## TODO: is this same size as tensor? or scalar?
158+
159+
## update learning rate eta
160+
eta_tp1 = jnp.maximum(1e-5, eta - self.eta_decr)
161+
self.eta.set(eta_tp1)
162+
163+
self.i_tick.set(self.i_tick.get() + 1)
164+
165+
@compilable
166+
def reset(self):
167+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
168+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
169+
170+
if not self.inputs.targeted:
171+
self.inputs.set(preVals)
172+
self.outputs.set(postVals)
173+
self.xprobe.set(preVals)
174+
self.misses.set(jnp.zeros((self.batch_size.get(), 1)))
175+
176+
self.dWeights.set(jnp.zeros(self.shape.get()))
177+
178+
@classmethod
179+
def help(cls): ## component help function
180+
properties = {
181+
"synapse_type": "ART2ASynapse - performs an adaptable synaptic transformation of inputs to produce output "
182+
"signals; synapses are adjusted via competitive Hebbian learning in accordance with "
183+
"adaptive resonance theory (2A)"
184+
}
185+
compartment_props = {
186+
"input_compartments":
187+
{"inputs": "Takes in external input signal values",
188+
"key": "JAX PRNG key"},
189+
"parameter_compartments":
190+
{"weights": "Synapse efficacy/strength parameter values"},
191+
"output_compartments":
192+
{"outputs": "Output of synaptic transformation"},
193+
}
194+
hyperparams = {
195+
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
196+
"batch_size": "Batch size dimension of this component",
197+
"weight_init": "Initialization conditions for synaptic weight (W) values",
198+
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
199+
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
200+
"eta": "Global learning rate",
201+
"eta_decrement": "Constant amount to decrease global learning by each call to `evolve`"
202+
}
203+
info = {cls.__name__: properties,
204+
"compartments": compartment_props,
205+
"dynamics": "outputs = [bmu_mask] ;"
206+
"dW = ART2A competitive Hebbian update",
207+
"hyperparameters": hyperparams}
208+
return info
209+
210+
# if __name__ == '__main__':
211+
# from ngcsimlib.context import Context
212+
# with Context("Bar") as bar:
213+
# Wab = ART2ASynapse("Wab", (2, 3), 4, 4, 1.)
214+
# print(Wab)

0 commit comments

Comments
 (0)