Skip to content

Commit 86f09e0

Browse files
authored
Merge pull request #155 from NACLab/main
Nudge ngclearn to v3.1.1 minor update
2 parents a4ef0c4 + ad0cc85 commit 86f09e0

32 files changed

Lines changed: 1795 additions & 652 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ $ python install -e .
119119
</pre>
120120

121121
**Version:**<br>
122-
3.1.0 <!--1.2.3-Beta--> <!-- -Alpha -->
122+
3.1.1 <!--1.2.3-Beta--> <!-- -Alpha -->
123123

124124
Author:
125125
Alexander G. Ororbia II<br>

ngclearn/components/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
from .input_encoders.ganglionCell import RetinalGanglionCell
3131
from .input_encoders.latencyCell import LatencyCell
3232
from .input_encoders.phasorCell import PhasorCell
33+
#from .input_encoders.populationCoderCell import PopulationCoderCell
34+
#from .input_encoders.gridCell import GridCell
35+
#from .input_encoders.placeCell import PlaceCell
3336

3437
## point to synapse component types
3538
from .synapses.denseSynapse import DenseSynapse
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from .bernoulliCell import BernoulliCell
22
from .poissonCell import PoissonCell
33
from .latencyCell import LatencyCell
4-
from .ganglionCell import RetinalGanglionCell
54
from .phasorCell import PhasorCell
6-
5+
from .ganglionCell import RetinalGanglionCell
6+
#from .populationCoderCell import PopulationCoderCell
7+
#from .gridCell import GridCell
8+
#from .placeCell import PlaceCell
79

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ class BernoulliCell(JaxComponent):
2727
"""
2828

2929
def __init__(
30-
self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs
30+
self,
31+
name: str,
32+
n_units: int,
33+
batch_size: int = 1,
34+
key: Union[jax.Array, None] = None,
35+
**kwargs
3136
):
3237
super().__init__(name=name, key=key)
3338

ngclearn/components/input_encoders/ganglionCell.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,28 @@
88
def _create_gaussian_filter(patch_shape, sigma):
99
## Create a 2D Gaussian kernel centered on patch_shape with given sigma.
1010
px, py = patch_shape
11-
1211
x_ = jnp.linspace(0, px - 1, px)
1312
y_ = jnp.linspace(0, py - 1, py)
14-
1513
x, y = jnp.meshgrid(x_, y_)
16-
1714
xc = px // 2
1815
yc = py // 2
19-
20-
filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
21-
return filter / jnp.sum(filter)
16+
_filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
17+
return _filter / jnp.sum(_filter)
2218

2319
def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
2420
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
2521
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
26-
2722
dog = g1 - lmbda * g2
28-
2923
return dog #- jnp.mean(dog)
3024

25+
26+
def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6):
27+
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
28+
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
29+
rog = g1 / (g2 + 1e-8)
30+
return rog
31+
32+
3133
def _create_patches(obs, patch_shape, step_shape):
3234
"""
3335
Extract 2D patches from a batch of images using a sliding window.
@@ -67,6 +69,29 @@ def _create_patches(obs, patch_shape, step_shape):
6769

6870
return patches
6971

72+
def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape):
73+
# patches: (N, nx * ny, px, py)
74+
75+
B = len(patches)
76+
nx, ny = nx_ny
77+
ix, iy = area_shape
78+
px, py = patch_shape
79+
sx, sy = step_shape
80+
x = jnp.zeros((B, ix, iy))
81+
counts = jnp.zeros((ix, iy))
82+
83+
idx = 0
84+
for i in range(ny):
85+
for j in range(nx):
86+
di = i * sx
87+
dj = j * sy
88+
x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx])
89+
counts = counts.at[di:di + px, dj:dj + py].add(1.0)
90+
idx += 1
91+
92+
return x / counts[None, :, :]
93+
94+
7095

7196
class RetinalGanglionCell(JaxComponent):
7297
"""
@@ -126,23 +151,30 @@ def __init__(
126151
self.patch_shape = patch_shape
127152
self.step_shape = step_shape
128153

129-
filter = jnp.ones(self.patch_shape)
154+
_filter = jnp.ones(self.patch_shape)
130155

131-
if filter_type == 'gaussian':
132-
filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
133-
elif filter_type == 'difference_of_gaussian':
134-
filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
156+
if self.filter_type == 'gaussian':
157+
print("filter type is ", self.filter_type)
158+
_filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
159+
160+
elif self.filter_type in ["difference_of_gaussian", "DoG"]:
161+
print("filter type is difference of gaussian: f(x) = p1 - p2")
162+
_filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
163+
164+
elif self.filter_type in ["ratio_of_gaussian", "RoG"]:
165+
print("filter type is ratio of gaussian: f(x) = p1 / p2")
166+
_filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma)
135167

136168
# ═════════════════ compartments initial values ════════════════════
137-
in_restVals = jnp.zeros((batch_size,
138-
*self.area_shape)) ## input: (B | ix | iy)
169+
in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
139170

140-
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
141-
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
171+
out_restVals = jnp.zeros(
172+
(batch_size, self.n_cells * self.patch_shape[0] * self.patch_shape[1])
173+
) ## output.shape: (B | n_cells * px * py)
142174

143175
# ═══════════════════ set compartments ══════════════════════
144176
self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment
145-
self.filter = Compartment(filter, display_name="Filter") # Filter compartment
177+
self.filter = Compartment(_filter, display_name="Filter") # Filter compartment
146178
self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment
147179

148180
@compilable
@@ -165,27 +197,16 @@ def advance_state(self, t):
165197

166198
self.outputs.set(outputs)
167199

200+
201+
168202
@compilable
169203
def reset(self): ## reset core components/statistics
170-
# self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
171204
in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy)
172205
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
173206
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
174207
self.inputs.set(in_restVals)
175208
self.outputs.set(out_restVals)
176209

177-
# Viet: NOTE: we should not need this function since the reset function
178-
# one could set the batch size then do reset
179-
# @compilable
180-
# def batched_reset(self, batch_size):
181-
# in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
182-
183-
# out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
184-
# self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
185-
186-
# self.inputs.set(in_restVals)
187-
# self.outputs.set(out_restVals)
188-
189210
@classmethod
190211
def help(cls): ## component help function
191212
properties = {

ngclearn/components/input_encoders/poissonCell.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ class PoissonCell(JaxComponent):
3232

3333
@deprecate_args(max_freq="target_freq")
3434
def __init__(
35-
self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
36-
key: Union[jax.Array, None] = None, **kwargs
35+
self,
36+
name: str,
37+
n_units: int,
38+
target_freq: float = 63.75,
39+
batch_size: int = 1,
40+
key: Union[jax.Array, None] = None,
41+
**kwargs
3742
):
3843
super().__init__(name=name, key=key)
3944

0 commit comments

Comments
 (0)