Skip to content

Commit d527ab6

Browse files
authored
Refactor ganglion cell (#151)
* Add patch reconstruction function to ganglionCell.py Added a new function to reconstruct patches from input data, improving data handling in the RetinalGanglionCell class. * Add ratio of Gaussian filter function
1 parent abe7dfa commit d527ab6

1 file changed

Lines changed: 43 additions & 15 deletions

File tree

ngclearn/components/input_encoders/ganglionCell.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
2222
dog = g1 - lmbda * g2
2323
return dog #- jnp.mean(dog)
2424

25+
26+
def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6):
27+
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
28+
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
29+
rog = g1 / (g2 + 1e-8)
30+
return rog
31+
32+
2533
def _create_patches(obs, patch_shape, step_shape):
2634
"""
2735
Extract 2D patches from a batch of images using a sliding window.
@@ -61,6 +69,29 @@ def _create_patches(obs, patch_shape, step_shape):
6169

6270
return patches
6371

72+
def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape):
73+
# patches: (N, nx * ny, px, py)
74+
75+
B = len(patches)
76+
nx, ny = nx_ny
77+
ix, iy = area_shape
78+
px, py = patch_shape
79+
sx, sy = step_shape
80+
x = jnp.zeros((B, ix, iy))
81+
counts = jnp.zeros((ix, iy))
82+
83+
idx = 0
84+
for i in range(ny):
85+
for j in range(nx):
86+
di = i * sx
87+
dj = j * sy
88+
x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx])
89+
counts = counts.at[di:di + px, dj:dj + py].add(1.0)
90+
idx += 1
91+
92+
return x / counts[None, :, :]
93+
94+
6495

6596
class RetinalGanglionCell(JaxComponent):
6697
"""
@@ -121,10 +152,18 @@ def __init__(
121152
self.step_shape = step_shape
122153

123154
_filter = jnp.ones(self.patch_shape)
124-
if filter_type == 'gaussian':
155+
156+
if self.filter_type == 'gaussian':
157+
print("filter type is ", self.filter_type)
125158
_filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
126-
elif filter_type == 'difference_of_gaussian':
159+
160+
elif self.filter_type in ["difference_of_gaussian", "DoG"]:
161+
print("filter type is difference of gaussian: f(x) = p1 - p2")
127162
_filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
163+
164+
elif self.filter_type in ["ratio_of_gaussian", "RoG"]:
165+
print("filter type is ratio of gaussian: f(x) = p1 / p2")
166+
_filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma)
128167

129168
# ═════════════════ compartments initial values ════════════════════
130169
in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
@@ -158,27 +197,16 @@ def advance_state(self, t):
158197

159198
self.outputs.set(outputs)
160199

200+
201+
161202
@compilable
162203
def reset(self): ## reset core components/statistics
163-
# self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
164204
in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy)
165205
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
166206
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
167207
self.inputs.set(in_restVals)
168208
self.outputs.set(out_restVals)
169209

170-
# Viet: NOTE: we should not need this function since the reset function
171-
# one could set the batch size then do reset
172-
# @compilable
173-
# def batched_reset(self, batch_size):
174-
# in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
175-
176-
# out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
177-
# self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
178-
179-
# self.inputs.set(in_restVals)
180-
# self.outputs.set(out_restVals)
181-
182210
@classmethod
183211
def help(cls): ## component help function
184212
properties = {

0 commit comments

Comments
 (0)