Skip to content

Commit a9cf886

Browse files
authored
Flexible batch size (#142)
* Modify reset method to accept batch_size parameter for flexible test set size * Modify reset method to accept batch_size parameter * Refactor RateCell class reset function for flexible batch size * Refactor GaussianErrorCell class functions for flexible batch size * flexible batch_size
1 parent e9e6c75 commit a9cf886

5 files changed

Lines changed: 17 additions & 19 deletions

File tree

ngclearn/components/input_encoders/ganglionCell.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def __init__(self, name: str,
131131
self.n_cells = n_cells
132132
self.sigma = sigma
133133

134-
self.batch_size = batch_size
135134
self.area_shape = area_shape
136135
self.patch_shape = patch_shape
137136
self.step_shape = step_shape
@@ -144,10 +143,10 @@ def __init__(self, name: str,
144143
filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
145144

146145
# ═════════════════ compartments initial values ════════════════════
147-
in_restVals = jnp.zeros((self.batch_size,
146+
in_restVals = jnp.zeros((batch_size,
148147
*self.area_shape)) ## input: (B | ix | iy)
149148

150-
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
149+
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
151150
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
152151

153152
# ═══════════════════ set compartments ══════════════════════
@@ -176,11 +175,11 @@ def advance_state(self, t):
176175
self.outputs.set(outputs)
177176

178177
@compilable
179-
def reset(self):
180-
in_restVals = jnp.zeros((self.batch_size,
178+
def reset(self, batch_size):
179+
in_restVals = jnp.zeros((batch_size,
181180
*self.area_shape)) ## input: (B | ix | iy)
182181

183-
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
182+
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
184183
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
185184

186185
self.inputs.set(in_restVals)

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
4848
self.sigma_shape = sigma_shape
4949
self.shape = shape
5050
self.n_units = n_units
51-
self.batch_size = batch_size
5251

5352
## Convolution shape setup
5453
self.width = self.height = n_units
@@ -108,10 +107,10 @@ def advance_state(self, dt): ## compute Gaussian error cell output
108107
# @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
109108
# @staticmethod
110109
@compilable
111-
def reset(self): ## reset core components/statistics
112-
_shape = (self.batch_size, self.shape[0])
110+
def reset(self, batch_size): ## reset core components/statistics
111+
_shape = (batch_size, self.shape[0])
113112
if len(self.shape) > 1:
114-
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
113+
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
115114
restVals = jnp.zeros(_shape)
116115
dmu = restVals
117116
dtarget = restVals

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ def advance_state(self, dt):
252252
self.zF.set(zF)
253253

254254
@compilable
255-
def reset(self): #, batch_size, shape): #n_units
256-
_shape = (self.batch_size, self.shape[0])
255+
def reset(self, batch_size):
256+
_shape = (batch_size, self.shape[0])
257257
if len(self.shape) > 1:
258-
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
258+
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
259259
restVals = jnp.zeros(_shape)
260260
self.j.set(restVals)
261261
self.j_td.set(restVals)

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ def evolve(self):
281281
self.dBiases.set(dBiases)
282282

283283
@compilable
284-
def reset(self):
285-
preVals = jnp.zeros((self.batch_size, self.shape[0]))
286-
postVals = jnp.zeros((self.batch_size, self.shape[1]))
284+
def reset(self, batch_size):
285+
preVals = jnp.zeros((batch_size, self.shape[0]))
286+
postVals = jnp.zeros((batch_size, self.shape[1]))
287287
# BUG: the self.inputs here does not have the targeted field
288288
# NOTE: Quick workaround is to check if targeted is in the input or not
289289
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def advance_state(self):
167167
self.pre_out.set(pre_out)
168168

169169
@compilable
170-
def reset(self):
171-
preVals = jnp.zeros((self.batch_size, self.shape[0]))
172-
postVals = jnp.zeros((self.batch_size, self.shape[1]))
170+
def reset(self, batch_size):
171+
preVals = jnp.zeros((batch_size, self.shape[0]))
172+
postVals = jnp.zeros((batch_size, self.shape[1]))
173173

174174
# BUG: the self.inputs here does not have the targeted field
175175
# NOTE: Quick workaround is to check if targeted is in the input or not

0 commit comments

Comments
 (0)