Skip to content

Commit 81dcb6e

Browse files
author
Alexander Ororbia
committed
cleanup of ganglion cell
1 parent 496cf1f commit 81dcb6e

1 file changed

Lines changed: 38 additions & 31 deletions

File tree

ngclearn/components/input_encoders/ganglionCell.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def _create_patches(obs, patch_shape, step_shape):
7070

7171
class RetinalGanglionCell(JaxComponent):
7272
"""
73-
A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain.
73+
A group of retinal ganglion cell that sense input stimuli and send out filtered
74+
signals (as output). Note that these simulated cells employ internal generalized
75+
filters based on either Gaussian or difference-of-Gaussian kernels) to recover
76+
historical receptive field processing effects.
7477
7578
| --- Cell Input Compartments: ---
7679
| inputs - input (takes in external signals)
@@ -85,32 +88,34 @@ class RetinalGanglionCell(JaxComponent):
8588
filter_type: string name of filter function (Default: identity)
8689
:Note: supported filters include "gaussian", "difference_of_gaussian"
8790
88-
sigma: standard deviation of gaussian kernel
91+
sigma: standard deviation of (gaussian) kernel
8992
90-
area_shape: receptive field area of ganglion cells in this module all together
93+
area_shape: shape of receptive field area of ganglion cells in this module (all together)
9194
9295
n_cells: number of ganglion cells in this module
9396
94-
patch_shape: each ganglion cell receptive field area
97+
patch_shape: shape of each ganglion cell's receptive field area
9598
96-
step_shape: the non-overlapping area between each two ganglion cells
99+
step_shape: the non-overlapping area between each pair (two) of ganglion cells
97100
98-
batch_size: batch size dimension of this cell (Default: 1)
101+
batch_size: batch size dimension of this cell/module (Default: 1)
99102
"""
100103

101-
def __init__(self, name: str,
102-
filter_type: str,
103-
area_shape: Tuple[int, int],
104-
n_cells: int,
105-
patch_shape: Tuple[int, int],
106-
step_shape: Tuple[int, int],
107-
batch_size: int = 1,
108-
sigma: float = 1.0,
109-
key: Union[jax.Array, None] = None,
110-
**kwargs):
104+
def __init__(
105+
self,
106+
name: str,
107+
filter_type: str,
108+
area_shape: Tuple[int, int],
109+
n_cells: int,
110+
patch_shape: Tuple[int, int],
111+
step_shape: Tuple[int, int],
112+
batch_size: int = 1,
113+
sigma: float = 1.0,
114+
key: Union[jax.Array, None] = None,
115+
**kwargs
116+
):
111117
super().__init__(name=name, key=key)
112118

113-
114119
## Layer Size Setup
115120
self.filter_type = filter_type
116121
self.n_cells = n_cells
@@ -143,14 +148,14 @@ def __init__(self, name: str,
143148
@compilable
144149
def advance_state(self, t):
145150
inputs = self.inputs.get()
146-
filter = self.filter.get()
151+
_filter = self.filter.get()
147152
px, py = self.patch_shape
148153

149154
# ═══════════════════ extract pathches for filters ══════════════════
150155
input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)
151156

152157
# ═══════════════════ apply filter to all pathches ══════════════════
153-
filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
158+
filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py)
154159

155160
# ════════════ reshape all cells responses to a single input to brain ════════════
156161
filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py)
@@ -184,7 +189,7 @@ def reset(self): ## reset core components/statistics
184189
@classmethod
185190
def help(cls): ## component help function
186191
properties = {
187-
"cell_type": "RetinalGanglionCell - filters the input stimuli, "
192+
"cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics"
188193
}
189194
compartment_props = {
190195
"inputs":
@@ -196,11 +201,11 @@ def help(cls): ## component help function
196201
}
197202
hyperparams = {
198203
"filter_type": "Type of the filter for preprocessing the input",
199-
"sigma": "Standard deviation of gaussian kernel",
204+
"sigma": "Standard deviation of gaussian kernel/filter",
200205
"area_shape": "Effective receptive field area shape of ganglion cells in this module",
201-
"n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer",
202-
"patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module",
203-
"step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module",
206+
"n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer",
207+
"patch_shape": "Classical receptive field area shape of individual ganglion cells in this module",
208+
"step_shape": "Extra-classical receptive field area shape each ganglion cell in this module",
204209
"batch_size": "Batch size dimension of this component"
205210
}
206211
info = {cls.__name__: properties,
@@ -212,13 +217,15 @@ def help(cls): ## component help function
212217
if __name__ == '__main__':
213218
from ngcsimlib.context import Context
214219
with Context("Bar") as bar:
215-
X = RetinalGanglionCell("RGC", filter_type="gaussian",
216-
sigma=2.3,
217-
area_shape=(16, 26),
218-
n_cells = 3,
219-
patch_shape=(16, 16),
220-
step_shape=(0, 5)
221-
)
220+
X = RetinalGanglionCell(
221+
"RGC",
222+
filter_type="gaussian",
223+
sigma=2.3,
224+
area_shape=(16, 26),
225+
n_cells = 3,
226+
patch_shape=(16, 16),
227+
step_shape=(0, 5)
228+
)
222229
print(X)
223230

224231

0 commit comments

Comments
 (0)