@@ -23,6 +23,10 @@ class GerstnerHebbianSynapse(DenseSynapse):
2323 | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update
2424 | c2_corr = 1 and c1_pre = -x_theta < 0
2525
26+ | References:
27+ | Gerstner, W. and Kistler, W.M., 2002. Mathematical formulations of Hebbian
28+ | learning. Biological cybernetics, 87(5), pp.404-415.
29+
2630 """
2731 def __init__ (
2832 self ,
@@ -37,7 +41,7 @@ def __init__(
3741 batch_size = 1 ,
3842 ** kwargs
3943 ):
40- bias_init = None ## no biases are included in Gerster's formulation
44+ bias_init = None ## NOTE: no biases are included in Gerster's formulation
4145 super ().__init__ (
4246 name ,
4347 shape = shape ,
@@ -48,11 +52,11 @@ def __init__(
4852 batch_size = batch_size ,
4953 ** kwargs
5054 )
51- ## General Hebbian meta-parameters
55+ ## general Hebbian meta-parameters
5256 self .eta = eta
5357 self .sign_value = sign_value
5458
55- ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr)
59+ ## Gerstner and Kisler's expansion coefficients (c0, c1_pre, c1_post, c2_corr)
5660 if coeffs is None : ## Default to standard bilinear Hebb
5761 self .coeffs = {
5862 'c0' : 0. , 'c1_pre' : 0. , 'c1_post' : 0. , 'c2_corr' : 1.0
@@ -64,50 +68,43 @@ def __init__(
6468 self .c1_post = self .coeffs ['c1_post' ]
6569 self .c2_corr = self .coeffs ['c2_corr' ]
6670
67- # Initialize Weights (using JAX PRNG)
68- #init_key, _ = random.split(self.key)
69- #w_init = random.normal(init_key, shape) * 0.05
70-
71- # Compartments (ngc-learn state management)
72- #self.weights = Compartment(w_init)
71+ ## set up relevant compartments
7372 self .pre = Compartment (jnp .zeros ((1 , shape [1 ])))
7473 self .post = Compartment (jnp .zeros ((1 , shape [0 ])))
74+ self .dWeights = Compartment (jnp .zeros (shape ))
7575
7676 @compilable
77- def evolve (self , ** kwargs ):
78- """
79- Updates weights using the Gerstner general expansion.
80- Assumes pre_act and post_act compartments have been populated.
81- """
82- # Retrieve current states
77+ def evolve (self , ** kwargs ): ## perform update via Gerstner's general expansion
78+ ## retrieve current compartment state values
8379 W = self .weights .get ()
84- x = self .pre .get () # pre-synaptic activity (batch, pre_dim)
85- y = self .post .get () # post-synaptic activity (batch, post_dim)
80+ x = self .pre .get () ## pre-synaptic activity (batch, pre_dim)
81+ y = self .post .get () ## post-synaptic activity (batch, post_dim)
8682 batch_size = self .batch_size
8783
88- ## Bilinear Term (c2): correlation matrix
89- ### (post_dim , batch) @ (batch, pre_dim ) -> (post_dim, pre_dim )
84+ ## calculate bilinear Term (c2), i.e., correlation matrix
85+ ### (pre_dim , batch) @ (batch, post_dim ) -> (pre_dim, post_dim )
9086 dW_corr = jnp .matmul (x .T , y ) * (1. / batch_size )
91- ## Linear pre-synaptic term (c1_pre)
92- ### Average over batch then broadcast to match weight matrix
87+ ## linear pre-synaptic term (c1_pre)
88+ ### get mean over batch then broadcast to match weight matrix
9389 dW_pre = jnp .sum (x , axis = 0 , keepdims = True ).T * (1. / batch_size )
94- ## Linear post-synaptic term (c1_post)
95- dW_post = jnp .sum (y , axis = 0 , keepdims = True ) * (1. / batch_size )
90+ ## linear post-synaptic term (c1_post), mean over post-syn values
91+ dW_post = jnp .sum (y , axis = 0 , keepdims = True ) * (1. / batch_size )
9692
97- ## Apply Equation 3 Taylor expansion
93+ ## apply Taylor expansion from Equation 3 (Gerstner and Kistler)
9894 dW = (self .c0 * W + ## synaptic decay
9995 self .c1_pre * dW_pre + ## bilinear term
10096 self .c1_post * dW_post + ## pre-synaptic gating term
10197 self .c2_corr * dW_corr ## post-synpatic gating term
10298 )
99+ self .dWeights .set (dW )
100+
103101 ## perform a step of Hebbian ascent
104- W = W + self .eta * dW
105- ## Update weights
102+ W = W + self .eta * dW ## update synaptic efficacies
106103 self .weights .set (W )
107104
108105 @compilable
109- def reset (self , ** kwargs ):
110- """Clears activity compartments"""
106+ def reset (self , ** kwargs ): ## clear compartment values
111107 self .pre .set ( jnp .zeros ((self .batch_size , self .shape [1 ])) )
112108 self .post .set ( jnp .zeros ((self .batch_size , self .shape [0 ])) )
109+ self .dWeights .set (self .dWeights .get () * 0 )
113110
0 commit comments