|
| 1 | +from jax import random, numpy as jnp, jit |
| 2 | +from functools import partial |
| 3 | +from ngclearn import compilable |
| 4 | +from ngclearn import Compartment |
| 5 | +from ngclearn.utils.model_utils import softmax, bkwta |
| 6 | + |
| 7 | +from ngclearn.components.synapses.denseSynapse import DenseSynapse |
| 8 | + |
| 9 | +@partial(jit, static_argnums=[1]) |
| 10 | +def _normalize(x_in, norm_fx=0): |
| 11 | + if norm_fx == 1: |
| 12 | + xmin = jnp.min(x, axis=1, keepdims=True) |
| 13 | + xmax = jnp.max(x, axis=1, keepdims=True) |
| 14 | + x = (x_in - xmin)/(xmax - xmin) |
| 15 | + else: |
| 16 | + x = x_in / jnp.linalg.norm(x_in, ord=2, axis=1, keepdims=True) |
| 17 | + return x |
| 18 | + |
| 19 | +class ART2ASynapse(DenseSynapse): # Adaptive resonance theory (ART) 2A synaptic cable |
| 20 | + """ |
| 21 | + A synaptic cable that emulates a simplified form of adaptive resonance theory (ART) |
| 22 | + adapted for continuous input signals (specifically, the ART2A-C model that handles |
| 23 | + real-valued input values). |
| 24 | +
|
| 25 | + | --- Synapse Compartments: --- |
| 26 | + | inputs - input (takes in external signals) |
| 27 | + | outputs - output signals (transformation induced by synapses) |
| 28 | + | weights - current value matrix of synaptic efficacies |
| 29 | + | i_tick - current internal tick / marker (gets incremented by 1 for each call to `evolve`) |
| 30 | + | eta - current learning rate value |
| 31 | + | key - JAX PRNG key |
| 32 | + | --- Synaptic Plasticity Compartments: --- |
| 33 | + | inputs - pre-synaptic signal/value to drive 1st term of ART2A update (x) |
| 34 | + | outputs - post-synaptic signal/value to drive 2nd term of ART2A update (y) |
| 35 | + | dWeights - current delta matrix containing changes to be applied to synapses |
| 36 | +
|
| 37 | + | References: |
| 38 | + | Carpenter, Gail A., and Stephen Grossberg. "ART 2: Self-organization of stable category |
| 39 | + | recognition codes for analog input patterns." Applied optics 26.23 (1987): 4919-4930. |
| 40 | + | |
| 41 | + | Ororbia, Alexander G. "Continual competitive memory: A neural system for online task-free |
| 42 | + | lifelong learning." arXiv preprint arXiv:2106.13300 (2021). |
| 43 | +
|
| 44 | + Args: |
| 45 | + name: the string name of this cell |
| 46 | +
|
| 47 | + shape: tuple specifying shape of this synaptic cable (usually a 2-tuple with number of |
| 48 | + inputs by number of outputs) |
| 49 | +
|
| 50 | + eta: (initial) learning rate / step-size for this ART2A model (initial condition value for `eta`) |
| 51 | +
|
| 52 | + eta_decrement: constant value to decrease `eta` by each call to this synapse's `evolve()`, i.e., |
| 53 | + this triggers a linear schedule for decreasing `eta` by (Default: 0) |
| 54 | +
|
| 55 | + vigilance: vigilance parameter to decide if a memory vector is updated (rho) |
| 56 | +
|
| 57 | + weight_init: a kernel to drive initialization of this synaptic cable's values; |
| 58 | + typically a tuple with 1st element as a string calling the name of |
| 59 | + initialization to use |
| 60 | +
|
| 61 | + resist_scale: a fixed scaling factor to apply to synaptic transform (Default: 1.) |
| 62 | +
|
| 63 | + p_conn: probability of a connection existing (default: 1.); setting |
| 64 | + this to < 1. will result in a sparser synaptic structure |
| 65 | + """ |
| 66 | + |
| 67 | + def __init__( |
| 68 | + self, |
| 69 | + name, |
| 70 | + shape, ## determines memory matrix size |
| 71 | + eta=0.05, ## learning rate |
| 72 | + eta_decrement=0., |
| 73 | + vigilance=0.3, ## vigilance parameter (rho) |
| 74 | + weight_init=None, |
| 75 | + resist_scale=1., |
| 76 | + p_conn=1., |
| 77 | + batch_size=1, |
| 78 | + **kwargs |
| 79 | + ): |
| 80 | + super().__init__( |
| 81 | + name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs |
| 82 | + ) |
| 83 | + |
| 84 | + ### Synapse and ART-2A hyper-parameters |
| 85 | + self.K = 1 ## number of winners for bmu calculation |
| 86 | + self.norm_fx = 0 ## 0 -> normalize via norm, 1 -> complement coding (min-max rescale) |
| 87 | + |
| 88 | + self.shape = shape ## shape of synaptic efficacy matrix |
| 89 | + self.initial_eta = eta |
| 90 | + self.eta_decr = eta_decrement ## linear decrease to iteratively update eta by (each "tick") |
| 91 | + self.vigilance = vigilance ## (rho) |
| 92 | + |
| 93 | + ## ART-2A Compartment setup |
| 94 | + self.xprobe = Compartment(jnp.zeros((batch_size, shape[0]))) |
| 95 | + self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta, display_name="Dynamic step size") |
| 96 | + self.i_tick = Compartment(jnp.zeros((1, 1))) |
| 97 | + #self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask") |
| 98 | + self.dWeights = Compartment(self.weights.get() * 0) |
| 99 | + self.misses = Compartment(jnp.zeros((batch_size, 1))) |
| 100 | + |
| 101 | + #@compilable |
| 102 | + def consolidate(self, wipe_mem=False): ## memory storage/consolidation routine |
| 103 | + ## note that this co-routine needs to be non-compilable/non-jit-i-fied, as |
| 104 | + ## its main purpose is to structurally alter the memory matrix W (dynamically) |
| 105 | + x_in = self.inputs.get() |
| 106 | + #xprobe = self.xprobe.get() |
| 107 | + xprobe = _normalize(x_in, norm_fx=self.norm_fx) |
| 108 | + if wipe_mem is False: |
| 109 | + W = self.weights.get() ## get current memory matrix |
| 110 | + miss_mask = self.misses.get() |
| 111 | + if jnp.sum(miss_mask) > 0: ## for non-resonant patterns |
| 112 | + r, c = jnp.nonzero(miss_mask) |
| 113 | + mem = xprobe[r, :] |
| 114 | + W = jnp.concat([W, mem.T], axis=1) |
| 115 | + self.weights.set(W) |
| 116 | + else: |
| 117 | + self.weights.set(xprobe.T) |
| 118 | + |
| 119 | + @compilable |
| 120 | + def advance_state(self): ## forward-inference step of ART2A |
| 121 | + x_in = self.inputs.get() |
| 122 | + W = self.weights.get() ## get (transposed) memory matrix |
| 123 | + |
| 124 | + x = _normalize(x_in, norm_fx=self.norm_fx) |
| 125 | + self.xprobe.set(x) |
| 126 | + sims = jnp.matmul(x, W) ## compute similarities (parallel dot products) |
| 127 | + z_winners = sims * bkwta(sims, nWTA=self.K) ## get winner mask (hidden layer) |
| 128 | + self.outputs.set(z_winners) |
| 129 | + |
| 130 | + @compilable |
| 131 | + def evolve(self, t, dt): ## competitive Hebbian update step of ART2A |
| 132 | + W = self.weights.get() ## D x Z |
| 133 | + x = self.xprobe.get() ## B x D |
| 134 | + z_winners = self.outputs.get() ## B x Z |
| 135 | + eta = self.eta.get() |
| 136 | + ## Note: we refactor ART update into a leaky integrator equation: |
| 137 | + ## W = W * (1 - b) + dW * b = W + b * (-W + dW); b = eta |
| 138 | + |
| 139 | + ## for resonant patterns, we perform a Hebbian storage update |
| 140 | + hits = (z_winners >= self.vigilance) * 1. ## B x Z |
| 141 | + m = (jnp.sum(hits, axis=1, keepdims=True) > 0.) * 1. ## B x 1 |
| 142 | + wnew = (-jnp.matmul(z_winners, W.T) + x) * m ## B x D |
| 143 | + dW = jnp.matmul(wnew.T, hits) ## D x Z ## adjustment matrix |
| 144 | + W = W + dW * eta ## D x Z ## do a step of Hebbian ascent |
| 145 | + |
| 146 | + ## NOTE: is this post-weight-update normalization needed? |
| 147 | + nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True) |
| 148 | + mz = (jnp.sum(hits, axis=0, keepdims=True) > 0.) * 1. |
| 149 | + W = W / (nW * mz + (1. - mz)) |
| 150 | + |
| 151 | + self.weights.set(W) ## memory matrix advances forward to new state |
| 152 | + self.misses.set(1. - m) ## store unused/non-resonant pattern mask |
| 153 | + |
| 154 | + #tmp_key, *subkeys = random.split(self.key.get(), 3) |
| 155 | + #self.key.set(tmp_key) |
| 156 | + ## synaptic update noise |
| 157 | + #eps = random.normal(subkeys[0], W.shape) ## TODO: is this same size as tensor? or scalar? |
| 158 | + |
| 159 | + ## update learning rate eta |
| 160 | + eta_tp1 = jnp.maximum(1e-5, eta - self.eta_decr) |
| 161 | + self.eta.set(eta_tp1) |
| 162 | + |
| 163 | + self.i_tick.set(self.i_tick.get() + 1) |
| 164 | + |
| 165 | + @compilable |
| 166 | + def reset(self): |
| 167 | + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) |
| 168 | + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) |
| 169 | + |
| 170 | + if not self.inputs.targeted: |
| 171 | + self.inputs.set(preVals) |
| 172 | + self.outputs.set(postVals) |
| 173 | + self.xprobe.set(preVals) |
| 174 | + self.misses.set(jnp.zeros((self.batch_size.get(), 1))) |
| 175 | + |
| 176 | + self.dWeights.set(jnp.zeros(self.shape.get())) |
| 177 | + |
| 178 | + @classmethod |
| 179 | + def help(cls): ## component help function |
| 180 | + properties = { |
| 181 | + "synapse_type": "ART2ASynapse - performs an adaptable synaptic transformation of inputs to produce output " |
| 182 | + "signals; synapses are adjusted via competitive Hebbian learning in accordance with " |
| 183 | + "adaptive resonance theory (2A)" |
| 184 | + } |
| 185 | + compartment_props = { |
| 186 | + "input_compartments": |
| 187 | + {"inputs": "Takes in external input signal values", |
| 188 | + "key": "JAX PRNG key"}, |
| 189 | + "parameter_compartments": |
| 190 | + {"weights": "Synapse efficacy/strength parameter values"}, |
| 191 | + "output_compartments": |
| 192 | + {"outputs": "Output of synaptic transformation"}, |
| 193 | + } |
| 194 | + hyperparams = { |
| 195 | + "shape": "Shape of synaptic weight value matrix; number inputs x number outputs", |
| 196 | + "batch_size": "Batch size dimension of this component", |
| 197 | + "weight_init": "Initialization conditions for synaptic weight (W) values", |
| 198 | + "resist_scale": "Resistance level scaling factor (applied to output of transformation)", |
| 199 | + "p_conn": "Probability of a connection existing (otherwise, it is masked to zero)", |
| 200 | + "eta": "Global learning rate", |
| 201 | + "eta_decrement": "Constant amount to decrease global learning by each call to `evolve`" |
| 202 | + } |
| 203 | + info = {cls.__name__: properties, |
| 204 | + "compartments": compartment_props, |
| 205 | + "dynamics": "outputs = [bmu_mask] ;" |
| 206 | + "dW = ART2A competitive Hebbian update", |
| 207 | + "hyperparameters": hyperparams} |
| 208 | + return info |
| 209 | + |
| 210 | +# if __name__ == '__main__': |
| 211 | +# from ngcsimlib.context import Context |
| 212 | +# with Context("Bar") as bar: |
| 213 | +# Wab = ART2ASynapse("Wab", (2, 3), 4, 4, 1.) |
| 214 | +# print(Wab) |
0 commit comments