|
| 1 | +import jax.numpy as jnp |
| 2 | +from jax import random, jit |
| 3 | + |
| 4 | +from ngclearn import compilable |
| 5 | +from ngclearn import Compartment |
| 6 | +from ngclearn.components.synapses import DenseSynapse |
| 7 | +from ngclearn.utils import tensorstats |
| 8 | +from ngcsimlib import deprecate_args |
| 9 | +#from ngclearn.utils.io_utils import save_pkl, load_pkl |
| 10 | + |
| 11 | +class GerstnerHebbianSynapse(DenseSynapse): |
| 12 | + """ |
| 13 | + A synapse component that implements Gerstner's general Hebbian |
| 14 | + learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002). |
| 15 | +
|
| 16 | + Note that this synpatic update model can recover several classical forms |
| 17 | + of Hebbian-like update rules, including the covariance rule. |
| 18 | +
|
| 19 | + There are other higher-order terms possible, i.e., \Theta(xy), such as |
| 20 | + x * y2 and y x^2, etc. |
| 21 | +
|
| 22 | + | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update |
| 23 | + | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update |
| 24 | + | c2_corr = 1 and c1_pre = -x_theta < 0 |
| 25 | +
|
| 26 | + """ |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + name, |
| 30 | + shape, ## (post_dim, pre_dim) |
| 31 | + eta=0.01, ## global step-size |
| 32 | + coeffs=None, ## these configure which kind of Hebb learning is done |
| 33 | + weight_init=None, |
| 34 | + p_conn=1., |
| 35 | + resist_scale=1., |
| 36 | + sign_value=1., |
| 37 | + batch_size=1, |
| 38 | + **kwargs |
| 39 | + ): |
| 40 | + bias_init = None ## no biases are included in Gerster's formulation |
| 41 | + super().__init__( |
| 42 | + name, |
| 43 | + shape=shape, |
| 44 | + weight_init=weight_init, |
| 45 | + bias_init=bias_init, |
| 46 | + resist_scale=resist_scale, |
| 47 | + p_conn=p_conn, |
| 48 | + batch_size=batch_size, |
| 49 | + **kwargs |
| 50 | + ) |
| 51 | + ## General Hebbian meta-parameters |
| 52 | + self.eta = eta |
| 53 | + self.sign_value = sign_value |
| 54 | + |
| 55 | + ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr) |
| 56 | + if coeffs is None: ## Default to standard bilinear Hebb |
| 57 | + self.coeffs = { |
| 58 | + 'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0 |
| 59 | + } |
| 60 | + else: |
| 61 | + self.coeffs = coeffs |
| 62 | + self.c0 = self.coeffs['c0'] |
| 63 | + self.c1_pre = self.coeffs['c1_pre'] |
| 64 | + self.c1_post = self.coeffs['c1_post'] |
| 65 | + self.c2_corr = self.coeffs['c2_corr'] |
| 66 | + |
| 67 | + # Initialize Weights (using JAX PRNG) |
| 68 | + #init_key, _ = random.split(self.key) |
| 69 | + #w_init = random.normal(init_key, shape) * 0.05 |
| 70 | + |
| 71 | + # Compartments (ngc-learn state management) |
| 72 | + #self.weights = Compartment(w_init) |
| 73 | + self.pre = Compartment(jnp.zeros((1, shape[1]))) |
| 74 | + self.post = Compartment(jnp.zeros((1, shape[0]))) |
| 75 | + |
| 76 | + @compilable |
| 77 | + def evolve(self, **kwargs): |
| 78 | + """ |
| 79 | + Updates weights using the Gerstner general expansion. |
| 80 | + Assumes pre_act and post_act compartments have been populated. |
| 81 | + """ |
| 82 | + # Retrieve current states |
| 83 | + W = self.weights.get() |
| 84 | + x = self.pre.get() # pre-synaptic activity (batch, pre_dim) |
| 85 | + y = self.post.get() # post-synaptic activity (batch, post_dim) |
| 86 | + batch_size = self.batch_size |
| 87 | + |
| 88 | + ## Bilinear Term (c2): correlation matrix |
| 89 | + ### (post_dim, batch) @ (batch, pre_dim) -> (post_dim, pre_dim) |
| 90 | + dW_corr = jnp.matmul(x.T, y) * (1./batch_size) |
| 91 | + ## Linear pre-synaptic term (c1_pre) |
| 92 | + ### Average over batch then broadcast to match weight matrix |
| 93 | + dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size) |
| 94 | + ## Linear post-synaptic term (c1_post) |
| 95 | + dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size) |
| 96 | + |
| 97 | + ## Apply Equation 3 Taylor expansion |
| 98 | + dW = (self.c0 * W + ## synaptic decay |
| 99 | + self.c1_pre * dW_pre + ## bilinear term |
| 100 | + self.c1_post * dW_post + ## pre-synaptic gating term |
| 101 | + self.c2_corr * dW_corr ## post-synpatic gating term |
| 102 | + ) |
| 103 | + ## perform a step of Hebbian ascent |
| 104 | + W = W + self.eta * dW |
| 105 | + ## Update weights |
| 106 | + self.weights.set(W) |
| 107 | + |
| 108 | + @compilable |
| 109 | + def reset(self, **kwargs): |
| 110 | + """Clears activity compartments""" |
| 111 | + self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) ) |
| 112 | + self.post.set( jnp.zeros((self.batch_size, self.shape[0])) ) |
| 113 | + |
0 commit comments