Skip to content

Commit 64eb27d

Browse files
author
Alexander Ororbia
committed
claned up ganglion-cell, added batched_reset
1 parent 969bb1b commit 64eb27d

1 file changed

Lines changed: 15 additions & 25 deletions

File tree

ngclearn/components/input_encoders/ganglionCell.py

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import jax
66
from typing import Union, Tuple
77

8-
9-
10-
11-
def create_gaussian_filter(patch_shape, sigma):
8+
def _create_gaussian_filter(patch_shape, sigma):
129
"""
1310
Create a 2D Gaussian kernel centered on patch_shape with given sigma.
1411
"""
@@ -25,23 +22,15 @@ def create_gaussian_filter(patch_shape, sigma):
2522
filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
2623
return filter / jnp.sum(filter)
2724

28-
29-
30-
31-
32-
def create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
33-
g1 = create_gaussian_filter(patch_shape, sigma=sigma)
34-
g2 = create_gaussian_filter(patch_shape, sigma=sigma * k)
25+
def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
26+
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
27+
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
3528

3629
dog = g1 - lmbda * g2
3730

3831
return dog #- jnp.mean(dog)
3932

40-
41-
42-
43-
44-
def create_patches(obs, patch_shape, step_shape):
33+
def _create_patches(obs, patch_shape, step_shape):
4534
"""
4635
Extract 2D patches from a batch of images using a sliding window.
4736
@@ -78,9 +67,6 @@ def create_patches(obs, patch_shape, step_shape):
7867
return patches
7968

8069

81-
82-
83-
8470
class RetinalGanglionCell(JaxComponent):
8571
"""
8672
A groupd of retinal ganglion cell that senses the input
@@ -131,16 +117,17 @@ def __init__(self, name: str,
131117
self.n_cells = n_cells
132118
self.sigma = sigma
133119

120+
self.batch_size = batch_size
134121
self.area_shape = area_shape
135122
self.patch_shape = patch_shape
136123
self.step_shape = step_shape
137124

138125
filter = jnp.ones(self.patch_shape)
139126

140127
if filter_type == 'gaussian':
141-
filter = create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
128+
filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
142129
elif filter_type == 'difference_of_gaussian':
143-
filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
130+
filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
144131

145132
# ═════════════════ compartments initial values ════════════════════
146133
in_restVals = jnp.zeros((batch_size,
@@ -161,7 +148,7 @@ def advance_state(self, t):
161148
px, py = self.patch_shape
162149

163150
# ═══════════════════ extract pathches for filters ══════════════════
164-
input_patches = create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)
151+
input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)
165152

166153
# ═══════════════════ apply filter to all pathches ══════════════════
167154
filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
@@ -175,9 +162,12 @@ def advance_state(self, t):
175162
self.outputs.set(outputs)
176163

177164
@compilable
178-
def reset(self, batch_size):
179-
in_restVals = jnp.zeros((batch_size,
180-
*self.area_shape)) ## input: (B | ix | iy)
165+
def reset(self): ## reset core components/statistics
166+
self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
167+
168+
@compilable
169+
def batched_reset(self, batch_size):
170+
in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
181171

182172
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
183173
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))

0 commit comments

Comments
 (0)