@@ -70,7 +70,10 @@ def _create_patches(obs, patch_shape, step_shape):
7070
7171class RetinalGanglionCell (JaxComponent ):
7272 """
73- A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain.
73+ A group of retinal ganglion cell that sense input stimuli and send out filtered
74+ signals (as output). Note that these simulated cells employ internal generalized
75+ filters based on either Gaussian or difference-of-Gaussian kernels) to recover
76+ historical receptive field processing effects.
7477
7578 | --- Cell Input Compartments: ---
7679 | inputs - input (takes in external signals)
@@ -85,32 +88,34 @@ class RetinalGanglionCell(JaxComponent):
8588 filter_type: string name of filter function (Default: identity)
8689 :Note: supported filters include "gaussian", "difference_of_gaussian"
8790
88- sigma: standard deviation of gaussian kernel
91+ sigma: standard deviation of ( gaussian) kernel
8992
90- area_shape: receptive field area of ganglion cells in this module all together
93+ area_shape: shape of receptive field area of ganglion cells in this module ( all together)
9194
9295 n_cells: number of ganglion cells in this module
9396
94- patch_shape: each ganglion cell receptive field area
97+ patch_shape: shape of each ganglion cell's receptive field area
9598
96- step_shape: the non-overlapping area between each two ganglion cells
99+ step_shape: the non-overlapping area between each pair ( two) of ganglion cells
97100
98- batch_size: batch size dimension of this cell (Default: 1)
101+ batch_size: batch size dimension of this cell/module (Default: 1)
99102 """
100103
101- def __init__ (self , name : str ,
102- filter_type : str ,
103- area_shape : Tuple [int , int ],
104- n_cells : int ,
105- patch_shape : Tuple [int , int ],
106- step_shape : Tuple [int , int ],
107- batch_size : int = 1 ,
108- sigma : float = 1.0 ,
109- key : Union [jax .Array , None ] = None ,
110- ** kwargs ):
104+ def __init__ (
105+ self ,
106+ name : str ,
107+ filter_type : str ,
108+ area_shape : Tuple [int , int ],
109+ n_cells : int ,
110+ patch_shape : Tuple [int , int ],
111+ step_shape : Tuple [int , int ],
112+ batch_size : int = 1 ,
113+ sigma : float = 1.0 ,
114+ key : Union [jax .Array , None ] = None ,
115+ ** kwargs
116+ ):
111117 super ().__init__ (name = name , key = key )
112118
113-
114119 ## Layer Size Setup
115120 self .filter_type = filter_type
116121 self .n_cells = n_cells
@@ -143,14 +148,14 @@ def __init__(self, name: str,
143148 @compilable
144149 def advance_state (self , t ):
145150 inputs = self .inputs .get ()
146- filter = self .filter .get ()
151+ _filter = self .filter .get ()
147152 px , py = self .patch_shape
148153
149154 # ═══════════════════ extract pathches for filters ══════════════════
150155 input_patches = _create_patches (inputs , patch_shape = self .patch_shape , step_shape = self .step_shape )
151156
152157 # ═══════════════════ apply filter to all pathches ══════════════════
153- filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
158+ filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py)
154159
155160 # ════════════ reshape all cells responses to a single input to brain ════════════
156161 filtered_input = filtered_input .reshape (- 1 , self .n_cells * (px * py )) ## shape: (B | n_cells * px * py)
@@ -184,7 +189,7 @@ def reset(self): ## reset core components/statistics
184189 @classmethod
185190 def help (cls ): ## component help function
186191 properties = {
187- "cell_type" : "RetinalGanglionCell - filters the input stimuli, "
192+ "cell_type" : "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics "
188193 }
189194 compartment_props = {
190195 "inputs" :
@@ -196,11 +201,11 @@ def help(cls): ## component help function
196201 }
197202 hyperparams = {
198203 "filter_type" : "Type of the filter for preprocessing the input" ,
199- "sigma" : "Standard deviation of gaussian kernel" ,
204+ "sigma" : "Standard deviation of gaussian kernel/filter " ,
200205 "area_shape" : "Effective receptive field area shape of ganglion cells in this module" ,
201- "n_cells" : "Number of Retinal Ganglion (center-surround) cells to model in this layer" ,
202- "patch_shape" : "Classical Receptive field area shape of individual ganglion cells in this module" ,
203- "step_shape" : "Extra-Classical Receptive field area shape each ganglion cell in this module" ,
206+ "n_cells" : "Number of retinal ganglion (center-surround) cells to model in this layer" ,
207+ "patch_shape" : "Classical receptive field area shape of individual ganglion cells in this module" ,
208+ "step_shape" : "Extra-classical receptive field area shape each ganglion cell in this module" ,
204209 "batch_size" : "Batch size dimension of this component"
205210 }
206211 info = {cls .__name__ : properties ,
@@ -212,13 +217,15 @@ def help(cls): ## component help function
212217if __name__ == '__main__' :
213218 from ngcsimlib .context import Context
214219 with Context ("Bar" ) as bar :
215- X = RetinalGanglionCell ("RGC" , filter_type = "gaussian" ,
216- sigma = 2.3 ,
217- area_shape = (16 , 26 ),
218- n_cells = 3 ,
219- patch_shape = (16 , 16 ),
220- step_shape = (0 , 5 )
221- )
220+ X = RetinalGanglionCell (
221+ "RGC" ,
222+ filter_type = "gaussian" ,
223+ sigma = 2.3 ,
224+ area_shape = (16 , 26 ),
225+ n_cells = 3 ,
226+ patch_shape = (16 , 16 ),
227+ step_shape = (0 , 5 )
228+ )
222229 print (X )
223230
224231
0 commit comments