55import jax
66from typing import Union , Tuple
77
8-
9-
10-
11- def create_gaussian_filter (patch_shape , sigma ):
8+ def _create_gaussian_filter (patch_shape , sigma ):
129 """
1310 Create a 2D Gaussian kernel centered on patch_shape with given sigma.
1411 """
@@ -25,23 +22,15 @@ def create_gaussian_filter(patch_shape, sigma):
2522 filter = jnp .exp (- ((x - xc ) ** 2 + (y - yc ) ** 2 ) / (2 * (sigma ** 2 )))
2623 return filter / jnp .sum (filter )
2724
28-
29-
30-
31-
32- def create_dog_filter (patch_shape , sigma , k = 1.6 , lmbda = 1 ):
33- g1 = create_gaussian_filter (patch_shape , sigma = sigma )
34- g2 = create_gaussian_filter (patch_shape , sigma = sigma * k )
25+ def _create_dog_filter (patch_shape , sigma , k = 1.6 , lmbda = 1 ):
26+ g1 = _create_gaussian_filter (patch_shape , sigma = sigma )
27+ g2 = _create_gaussian_filter (patch_shape , sigma = sigma * k )
3528
3629 dog = g1 - lmbda * g2
3730
3831 return dog #- jnp.mean(dog)
3932
40-
41-
42-
43-
44- def create_patches (obs , patch_shape , step_shape ):
33+ def _create_patches (obs , patch_shape , step_shape ):
4534 """
4635 Extract 2D patches from a batch of images using a sliding window.
4736
@@ -78,9 +67,6 @@ def create_patches(obs, patch_shape, step_shape):
7867 return patches
7968
8069
81-
82-
83-
8470class RetinalGanglionCell (JaxComponent ):
8571 """
8672 A groupd of retinal ganglion cell that senses the input
@@ -131,16 +117,17 @@ def __init__(self, name: str,
131117 self .n_cells = n_cells
132118 self .sigma = sigma
133119
120+ self .batch_size = batch_size
134121 self .area_shape = area_shape
135122 self .patch_shape = patch_shape
136123 self .step_shape = step_shape
137124
138125 filter = jnp .ones (self .patch_shape )
139126
140127 if filter_type == 'gaussian' :
141- filter = create_gaussian_filter (patch_shape = self .patch_shape , sigma = self .sigma )
128+ filter = _create_gaussian_filter (patch_shape = self .patch_shape , sigma = self .sigma )
142129 elif filter_type == 'difference_of_gaussian' :
143- filter = create_dog_filter (patch_shape = self .patch_shape , sigma = sigma )
130+ filter = _create_dog_filter (patch_shape = self .patch_shape , sigma = sigma )
144131
145132 # ═════════════════ compartments initial values ════════════════════
146133 in_restVals = jnp .zeros ((batch_size ,
@@ -161,7 +148,7 @@ def advance_state(self, t):
161148 px , py = self .patch_shape
162149
163150 # ═══════════════════ extract pathches for filters ══════════════════
164- input_patches = create_patches (inputs , patch_shape = self .patch_shape , step_shape = self .step_shape )
151+ input_patches = _create_patches (inputs , patch_shape = self .patch_shape , step_shape = self .step_shape )
165152
166153 # ═══════════════════ apply filter to all pathches ══════════════════
167154 filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
@@ -175,9 +162,12 @@ def advance_state(self, t):
175162 self .outputs .set (outputs )
176163
177164 @compilable
178- def reset (self , batch_size ):
179- in_restVals = jnp .zeros ((batch_size ,
180- * self .area_shape )) ## input: (B | ix | iy)
165+ def reset (self ): ## reset core components/statistics
166+ self .batched_reset (batch_size = self .batch_size ) ## arg = batch_size data-member
167+
168+ @compilable
169+ def batched_reset (self , batch_size ):
170+ in_restVals = jnp .zeros ((batch_size , * self .area_shape )) ## input: (B | ix | iy)
181171
182172 out_restVals = jnp .zeros ((batch_size , ## output.shape: (B | n_cells * px * py)
183173 self .n_cells * self .patch_shape [0 ] * self .patch_shape [1 ]))
0 commit comments