@@ -22,6 +22,14 @@ def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
2222 dog = g1 - lmbda * g2
2323 return dog #- jnp.mean(dog)
2424
25+
26+ def _create_ratio_of_gauss_filter (patch_shape , sigma , k = 1.6 ):
27+ g1 = _create_gaussian_filter (patch_shape , sigma = sigma )
28+ g2 = _create_gaussian_filter (patch_shape , sigma = sigma * k )
29+ rog = g1 / (g2 + 1e-8 )
30+ return rog
31+
32+
2533def _create_patches (obs , patch_shape , step_shape ):
2634 """
2735 Extract 2D patches from a batch of images using a sliding window.
@@ -61,6 +69,29 @@ def _create_patches(obs, patch_shape, step_shape):
6169
6270 return patches
6371
72+ def _reconstruct (patches , nx_ny , area_shape , patch_shape , step_shape ):
73+ # patches: (N, nx * ny, px, py)
74+
75+ B = len (patches )
76+ nx , ny = nx_ny
77+ ix , iy = area_shape
78+ px , py = patch_shape
79+ sx , sy = step_shape
80+ x = jnp .zeros ((B , ix , iy ))
81+ counts = jnp .zeros ((ix , iy ))
82+
83+ idx = 0
84+ for i in range (ny ):
85+ for j in range (nx ):
86+ di = i * sx
87+ dj = j * sy
88+ x = x .at [:, di :di + px , dj :dj + py ].add (patches [:, idx ])
89+ counts = counts .at [di :di + px , dj :dj + py ].add (1.0 )
90+ idx += 1
91+
92+ return x / counts [None , :, :]
93+
94+
6495
6596class RetinalGanglionCell (JaxComponent ):
6697 """
@@ -121,10 +152,18 @@ def __init__(
121152 self .step_shape = step_shape
122153
123154 _filter = jnp .ones (self .patch_shape )
124- if filter_type == 'gaussian' :
155+
156+ if self .filter_type == 'gaussian' :
157+ print ("filter type is " , self .filter_type )
125158 _filter = _create_gaussian_filter (patch_shape = self .patch_shape , sigma = self .sigma )
126- elif filter_type == 'difference_of_gaussian' :
159+
160+ elif self .filter_type in ["difference_of_gaussian" , "DoG" ]:
161+ print ("filter type is difference of gaussian: f(x) = p1 - p2" )
127162 _filter = _create_dog_filter (patch_shape = self .patch_shape , sigma = sigma )
163+
164+ elif self .filter_type in ["ratio_of_gaussian" , "RoG" ]:
165+ print ("filter type is ratio of gaussian: f(x) = p1 / p2" )
166+ _filter = _create_ratio_of_gauss_filter (patch_shape = self .patch_shape , sigma = sigma )
128167
129168 # ═════════════════ compartments initial values ════════════════════
130169 in_restVals = jnp .zeros ((batch_size , * self .area_shape )) ## input: (B | ix | iy)
@@ -158,27 +197,16 @@ def advance_state(self, t):
158197
159198 self .outputs .set (outputs )
160199
200+
201+
161202 @compilable
162203 def reset (self ): ## reset core components/statistics
163- # self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
164204 in_restVals = jnp .zeros ((self .batch_size , * self .area_shape )) ## input: (B | ix | iy)
165205 out_restVals = jnp .zeros ((self .batch_size , ## output.shape: (B | n_cells * px * py)
166206 self .n_cells * self .patch_shape [0 ] * self .patch_shape [1 ]))
167207 self .inputs .set (in_restVals )
168208 self .outputs .set (out_restVals )
169209
170- # Viet: NOTE: we should not need this function since the reset function
171- # one could set the batch size then do reset
172- # @compilable
173- # def batched_reset(self, batch_size):
174- # in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
175-
176- # out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
177- # self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
178-
179- # self.inputs.set(in_restVals)
180- # self.outputs.set(out_restVals)
181-
182210 @classmethod
183211 def help (cls ): ## component help function
184212 properties = {
0 commit comments