88def _create_gaussian_filter (patch_shape , sigma ):
99 ## Create a 2D Gaussian kernel centered on patch_shape with given sigma.
1010 px , py = patch_shape
11-
1211 x_ = jnp .linspace (0 , px - 1 , px )
1312 y_ = jnp .linspace (0 , py - 1 , py )
14-
1513 x , y = jnp .meshgrid (x_ , y_ )
16-
1714 xc = px // 2
1815 yc = py // 2
19-
20- filter = jnp .exp (- ((x - xc ) ** 2 + (y - yc ) ** 2 ) / (2 * (sigma ** 2 )))
21- return filter / jnp .sum (filter )
16+ _filter = jnp .exp (- ((x - xc ) ** 2 + (y - yc ) ** 2 ) / (2 * (sigma ** 2 )))
17+ return _filter / jnp .sum (_filter )
2218
2319def _create_dog_filter (patch_shape , sigma , k = 1.6 , lmbda = 1 ):
2420 g1 = _create_gaussian_filter (patch_shape , sigma = sigma )
2521 g2 = _create_gaussian_filter (patch_shape , sigma = sigma * k )
26-
2722 dog = g1 - lmbda * g2
28-
2923 return dog #- jnp.mean(dog)
3024
25+
26+ def _create_ratio_of_gauss_filter (patch_shape , sigma , k = 1.6 ):
27+ g1 = _create_gaussian_filter (patch_shape , sigma = sigma )
28+ g2 = _create_gaussian_filter (patch_shape , sigma = sigma * k )
29+ rog = g1 / (g2 + 1e-8 )
30+ return rog
31+
32+
3133def _create_patches (obs , patch_shape , step_shape ):
3234 """
3335 Extract 2D patches from a batch of images using a sliding window.
@@ -67,6 +69,29 @@ def _create_patches(obs, patch_shape, step_shape):
6769
6870 return patches
6971
72+ def _reconstruct (patches , nx_ny , area_shape , patch_shape , step_shape ):
73+ # patches: (N, nx * ny, px, py)
74+
75+ B = len (patches )
76+ nx , ny = nx_ny
77+ ix , iy = area_shape
78+ px , py = patch_shape
79+ sx , sy = step_shape
80+ x = jnp .zeros ((B , ix , iy ))
81+ counts = jnp .zeros ((ix , iy ))
82+
83+ idx = 0
84+ for i in range (ny ):
85+ for j in range (nx ):
86+ di = i * sx
87+ dj = j * sy
88+ x = x .at [:, di :di + px , dj :dj + py ].add (patches [:, idx ])
89+ counts = counts .at [di :di + px , dj :dj + py ].add (1.0 )
90+ idx += 1
91+
92+ return x / counts [None , :, :]
93+
94+
7095
7196class RetinalGanglionCell (JaxComponent ):
7297 """
@@ -126,23 +151,30 @@ def __init__(
126151 self .patch_shape = patch_shape
127152 self .step_shape = step_shape
128153
129- filter = jnp .ones (self .patch_shape )
154+ _filter = jnp .ones (self .patch_shape )
130155
131- if filter_type == 'gaussian' :
132- filter = _create_gaussian_filter (patch_shape = self .patch_shape , sigma = self .sigma )
133- elif filter_type == 'difference_of_gaussian' :
134- filter = _create_dog_filter (patch_shape = self .patch_shape , sigma = sigma )
156+ if self .filter_type == 'gaussian' :
157+ print ("filter type is " , self .filter_type )
158+ _filter = _create_gaussian_filter (patch_shape = self .patch_shape , sigma = self .sigma )
159+
160+ elif self .filter_type in ["difference_of_gaussian" , "DoG" ]:
161+ print ("filter type is difference of gaussian: f(x) = p1 - p2" )
162+ _filter = _create_dog_filter (patch_shape = self .patch_shape , sigma = sigma )
163+
164+ elif self .filter_type in ["ratio_of_gaussian" , "RoG" ]:
165+ print ("filter type is ratio of gaussian: f(x) = p1 / p2" )
166+ _filter = _create_ratio_of_gauss_filter (patch_shape = self .patch_shape , sigma = sigma )
135167
136168 # ═════════════════ compartments initial values ════════════════════
137- in_restVals = jnp .zeros ((batch_size ,
138- * self .area_shape )) ## input: (B | ix | iy)
169+ in_restVals = jnp .zeros ((batch_size , * self .area_shape )) ## input: (B | ix | iy)
139170
140- out_restVals = jnp .zeros ((batch_size , ## output.shape: (B | n_cells * px * py)
141- self .n_cells * self .patch_shape [0 ] * self .patch_shape [1 ]))
171+ out_restVals = jnp .zeros (
172+ (batch_size , self .n_cells * self .patch_shape [0 ] * self .patch_shape [1 ])
173+ ) ## output.shape: (B | n_cells * px * py)
142174
143175 # ═══════════════════ set compartments ══════════════════════
144176 self .inputs = Compartment (in_restVals , display_name = "Input Stimulus" ) # input compartment
145- self .filter = Compartment (filter , display_name = "Filter" ) # Filter compartment
177+ self .filter = Compartment (_filter , display_name = "Filter" ) # Filter compartment
146178 self .outputs = Compartment (out_restVals , display_name = "Output Signal" ) # output compartment
147179
148180 @compilable
@@ -165,27 +197,16 @@ def advance_state(self, t):
165197
166198 self .outputs .set (outputs )
167199
200+
201+
168202 @compilable
169203 def reset (self ): ## reset core components/statistics
170- # self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
171204 in_restVals = jnp .zeros ((self .batch_size , * self .area_shape )) ## input: (B | ix | iy)
172205 out_restVals = jnp .zeros ((self .batch_size , ## output.shape: (B | n_cells * px * py)
173206 self .n_cells * self .patch_shape [0 ] * self .patch_shape [1 ]))
174207 self .inputs .set (in_restVals )
175208 self .outputs .set (out_restVals )
176209
177- # Viet: NOTE: we should not need this function since the reset function
178- # one could set the batch size then do reset
179- # @compilable
180- # def batched_reset(self, batch_size):
181- # in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
182-
183- # out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
184- # self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
185-
186- # self.inputs.set(in_restVals)
187- # self.outputs.set(out_restVals)
188-
189210 @classmethod
190211 def help (cls ): ## component help function
191212 properties = {
0 commit comments