Skip to content

Commit abe7dfa

Browse files
author
Alexander Ororbia
committed
minor edits/updates
1 parent 22ba7aa commit abe7dfa

2 files changed

Lines changed: 132 additions & 8 deletions

File tree

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import jax.numpy as jnp
2+
from jax import random, jit
3+
4+
from ngclearn import compilable
5+
from ngclearn import Compartment
6+
from ngclearn.components.synapses import DenseSynapse
7+
from ngclearn.utils import tensorstats
8+
from ngcsimlib import deprecate_args
9+
#from ngclearn.utils.io_utils import save_pkl, load_pkl
10+
11+
class GerstnerHebbianSynapse(DenseSynapse):
    """
    A synapse component that implements Gerstner's general Hebbian
    learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002),
    truncated after the bilinear term:

    | dW = c0 * W + c1_pre * x + c1_post * y + c2_corr * (x^T y)

    where `x` is the (batch-averaged) pre-synaptic activity and `y` is the
    (batch-averaged) post-synaptic activity.

    Note that this synaptic update model can recover several classical forms
    of Hebbian-like update rules, including the covariance rule. Higher-order
    terms of the expansion, such as x * y^2 and y * x^2, are possible but are
    not implemented by this component.

    | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update
    | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update
    | c2_corr = 1 and c1_pre = -x_theta < 0 => update of form x * (y - x_theta)

    | --- Compartments written/read by this component: ---
    | pre - pre-synaptic activity, shape (batch_size, shape[0])
    | post - post-synaptic activity, shape (batch_size, shape[1])
    | weights - synaptic efficacies (created by the parent DenseSynapse)

    Args:
        name: the string name of this synapse component

        shape: 2-tuple (pre_dim, post_dim), i.e., number of inputs by number
            of outputs; NOTE(review): this assumes the parent DenseSynapse
            allocates `weights` with exactly this shape -- confirm

        eta: global step-size applied to the synaptic update (Default: 0.01)

        coeffs: dictionary of expansion coefficients with keys `c0`, `c1_pre`,
            `c1_post`, `c2_corr`; any key that is omitted falls back to the
            standard bilinear Hebb default (`c2_corr` = 1, all others 0)

        weight_init: weight initialization scheme, forwarded to the parent

        p_conn: connectivity probability, forwarded to the parent (Default: 1)

        resist_scale: resistance scaling factor, forwarded to the parent
            (Default: 1)

        sign_value: multiplicative factor applied to the final update before
            it is added to the weights; 1 yields Hebbian ascent, -1 yields
            descent (Default: 1)

        batch_size: batch size dimension of the activity compartments
            (Default: 1)
    """

    def __init__(
            self,
            name,
            shape,  ## (pre_dim, post_dim)
            eta=0.01,  ## global step-size
            coeffs=None,  ## these configure which kind of Hebb learning is done
            weight_init=None,
            p_conn=1.,
            resist_scale=1.,
            sign_value=1.,
            batch_size=1,
            **kwargs
    ):
        bias_init = None  ## no biases are included in Gerstner's formulation
        super().__init__(
            name,
            shape=shape,
            weight_init=weight_init,
            bias_init=bias_init,
            resist_scale=resist_scale,
            p_conn=p_conn,
            batch_size=batch_size,
            **kwargs
        )
        ## General Hebbian meta-parameters
        self.eta = eta
        self.sign_value = sign_value

        ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr); user-supplied
        ## values override the standard bilinear Hebb defaults, so a partially
        ## specified dictionary no longer triggers a KeyError
        default_coeffs = {
            'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0
        }
        if coeffs is None:
            self.coeffs = default_coeffs
        else:
            self.coeffs = {**default_coeffs, **coeffs}
        self.c0 = self.coeffs['c0']
        self.c1_pre = self.coeffs['c1_pre']
        self.c1_post = self.coeffs['c1_post']
        self.c2_corr = self.coeffs['c2_corr']

        ## Compartments (ngc-learn state management); these are allocated with
        ## the same (batch_size, dim) shapes that reset() restores, and with
        ## pre/post dims ordered so that x^T y matches the weight matrix shape
        self.pre = Compartment(jnp.zeros((batch_size, shape[0])))
        self.post = Compartment(jnp.zeros((batch_size, shape[1])))

    @compilable
    def evolve(self, **kwargs):
        """
        Updates weights using the (truncated) Gerstner general expansion.
        Assumes the `pre` and `post` compartments have been populated.
        """
        ## Retrieve current states
        W = self.weights.get()
        x = self.pre.get()  ## pre-synaptic activity (batch, pre_dim)
        y = self.post.get()  ## post-synaptic activity (batch, post_dim)
        batch_size = self.batch_size

        ## Bilinear term (c2_corr): correlation matrix
        ### (pre_dim, batch) @ (batch, post_dim) -> (pre_dim, post_dim)
        dW_corr = jnp.matmul(x.T, y) * (1./batch_size)
        ## Linear pre-synaptic term (c1_pre)
        ### batch-average -> (pre_dim, 1); broadcasts across post-synaptic columns
        dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size)
        ## Linear post-synaptic term (c1_post)
        ### batch-average -> (1, post_dim); broadcasts across pre-synaptic rows
        dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size)

        ## Apply Equation 3 Taylor expansion
        dW = (self.c0 * W +  ## weight-dependent term (decay if c0 < 0)
              self.c1_pre * dW_pre +  ## linear pre-synaptic term
              self.c1_post * dW_post +  ## linear post-synaptic term
              self.c2_corr * dW_corr  ## bilinear (correlation) term
              )
        ## perform a step of Hebbian ascent (or descent, if sign_value < 0)
        W = W + self.eta * self.sign_value * dW
        ## Update weights
        self.weights.set(W)

    @compilable
    def reset(self, **kwargs):
        """Clears the `pre` and `post` activity compartments (sets them to zeros)"""
        self.pre.set(jnp.zeros((self.batch_size, self.shape[0])))
        self.post.set(jnp.zeros((self.batch_size, self.shape[1])))
113+

ngclearn/utils/filters/gauss_filter.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from jax import lax, jit
33
from functools import partial
44

5-
65
def _calc_gaussian_kernel_2D( ## internal co-routine
76
sigma: float,
87
radius: int
@@ -21,7 +20,8 @@ def gaussian_filter(
2120
sigma_center: float, ## sigma1
2221
sigma_surround: float, ## sigma2
2322
kernel_size : int, ## radius
24-
use_ratio=False ## if True, this becomes a ratio-of-Gaussians
23+
use_ratio=False, ## if True, this becomes a ratio-of-Gaussians
24+
edge_pad_mode="edge" ## "reflect"
2525
) -> jnp.ndarray:
2626
"""
2727
Applies a difference-of-Gaussians filter to a batch of 2D images (of CxHxW tensor shape).
@@ -40,27 +40,38 @@ def gaussian_filter(
4040
Returns:
4141
An output tensor of shape (B, C, H, W)
4242
"""
43-
x = images
43+
## Pad spatial dimensions (H, W) using edge-clamping to remove artifacts
44+
# Format for 4D (B, C, H, W): ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W))
45+
padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size))
46+
padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode)
47+
4448
## Construct two 2D Gaussian kernels
4549
k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel
4650
k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel
51+
4752
## Define dimension ordering for lax.conv ('NCHW' standard layout)
4853
dn = lax.ConvDimensionNumbers(
4954
lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width)
5055
rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width)
5156
out_spec=(0, 1, 2, 3) ## (batch, channel, height, width)
5257
)
53-
## Perform spatial convolutions w/ edge padding to emulate 'SAME' behavior
58+
59+
## Extract channel count dynamically for independent channel-wise filtering
60+
num_channels = images.shape[1]
61+
62+
## Perform spatial convolutions w/ 'VALID' padding on the edge-padded input
5463
blur_center = lax.conv_general_dilated(
55-
x, k1, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn
64+
padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels
5665
)
5766
blur_surround = lax.conv_general_dilated(
58-
x, k2, window_strides=(1, 1), padding=[(kernel_size, kernel_size), (kernel_size, kernel_size)], dimension_numbers=dn
67+
padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels
5968
)
69+
6070
## Perform final filter calculation
6171
if use_ratio:
6272
eps = 1e-5
63-
output = blur_center / (blur_surround + eps) ## Compute kernel difference
73+
output = blur_center / (blur_surround + eps) ## Compute kernel ratio
6474
else:
65-
output = blur_center - blur_surround ## Compute kernel ratio
75+
output = blur_center - blur_surround ## Compute kernel difference
6676
return output ## shape: (B, C, H, W)
77+

0 commit comments

Comments
 (0)