Skip to content

Commit 36c56a6

Browse files
author
Alexander Ororbia
committed
minor patches to components, including hebb-syn/conv/deconv and reward-cell
1 parent c01a619 commit 36c56a6

7 files changed

Lines changed: 15 additions & 59 deletions

File tree

ngclearn/components/neurons/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .graded.bernoulliErrorCell import BernoulliErrorCell
66
from .graded.rewardErrorCell import RewardErrorCell
77
## point to standard spiking cell component types
8-
from .spiking.sLIFCell import SLIFCell
8+
#from .spiking.sLIFCell import SLIFCell
99
from .spiking.IFCell import IFCell
1010
from .spiking.LIFCell import LIFCell
1111
from .spiking.WTASCell import WTASCell

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,6 @@ def help(cls): ## component help function
317317
"hyperparameters": hyperparams}
318318
return info
319319

320-
def __repr__(self):
321-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
322-
maxlen = max(len(c) for c in comps) + 5
323-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
324-
for c in comps:
325-
stats = tensorstats(getattr(self, c).get())
326-
if stats is not None:
327-
line = [f"{k}: {v}" for k, v in stats.items()]
328-
line = ", ".join(line)
329-
else:
330-
line = "None"
331-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
332-
return lines
333-
334320
if __name__ == '__main__':
335321
from ngcsimlib.context import Context
336322
with Context("Bar") as bar:

ngclearn/components/neurons/graded/rewardErrorCell.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,6 @@ def help(cls): ## component help function
134134
"hyperparameters": hyperparams}
135135
return info
136136

137-
def __repr__(self):
138-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
139-
maxlen = max(len(c) for c in comps) + 5
140-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
141-
for c in comps:
142-
stats = tensorstats(getattr(self, c).get())
143-
if stats is not None:
144-
line = [f"{k}: {v}" for k, v in stats.items()]
145-
line = ", ".join(line)
146-
else:
147-
line = "None"
148-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
149-
return lines
150-
151137
if __name__ == '__main__':
152138
from ngcsimlib.context import Context
153139
with Context("Bar") as bar:

ngclearn/components/synapses/convolution/hebbianConvSynapse.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,8 @@ def __init__(
127127
########################################################################
128128

129129
## set up outer optimization compartments
130-
self.opt_params = Compartment(get_opt_init_fn(optim_type)(
131-
[self.weights.get(), self.biases.get()]
132-
if bias_init else [self.weights.get()])
130+
self.opt_params = Compartment(
131+
get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
133132
)
134133

135134
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):

ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ def __init__(
115115
########################################################################
116116

117117
## set up outer optimization compartments
118-
self.opt_params = Compartment(get_opt_init_fn(optim_type)(
119-
[self.weights.get(), self.biases.get()]
120-
if bias_init else [self.weights.get()])
118+
self.opt_params = Compartment(
119+
get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
121120
)
122121

123122
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):

ngclearn/components/synapses/hebbian/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#from .hebbianSynapse import HebbianSynapse
1+
from .hebbianSynapse import HebbianSynapse
22
from .traceSTDPSynapse import TraceSTDPSynapse
33
from .expSTDPSynapse import ExpSTDPSynapse
44
from .eventSTDPSynapse import EventSTDPSynapse

ngclearn/components/synapses/hebbian/hebbianSynapse.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,7 @@ def __init__(
172172
prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
173173
resist_scale=1., batch_size=1, **kwargs
174174
):
175-
super().__init__(name, shape, weight_init, bias_init, resist_scale,
176-
p_conn, batch_size=batch_size, **kwargs)
175+
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs)
177176

178177
if w_decay > 0.:
179178
prior = ('l2', w_decay)
@@ -209,13 +208,14 @@ def __init__(
209208
self.dBiases = Compartment(jnp.zeros(shape[1]))
210209

211210
#key, subkey = random.split(self.key.value)
212-
self.opt_params = Compartment(get_opt_init_fn(optim_type)(
213-
[self.weights.get(), self.biases.get()]
214-
if bias_init else [self.weights.get()]))
211+
self.opt_params = Compartment(
212+
get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
213+
)
215214

216215
@staticmethod
217-
def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
218-
post_wght, pre, post, weights):
216+
def _compute_update(
217+
w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights
218+
):
219219
## calculate synaptic update values
220220
dW, db = _calc_update(
221221
pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
@@ -257,8 +257,8 @@ def evolve(self):
257257
def reset(self): #, batch_size, shape):
258258
preVals = jnp.zeros((self.batch_size, self.shape[0]))
259259
postVals = jnp.zeros((self.batch_size, self.shape[1]))
260-
#not self.inputs.targeted and self.inputs.set(preVals) # inputs
261-
self.inputs.set(preVals)
260+
if not self.inputs.targeted:
261+
self.inputs.set(preVals)
262262
self.outputs.set(postVals) # outputs
263263
self.pre.set(preVals) # pre
264264
self.post.set(postVals) # post
@@ -310,20 +310,6 @@ def help(cls): ## component help function
310310
"hyperparameters": hyperparams}
311311
return info
312312

313-
def __repr__(self):
314-
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
315-
maxlen = max(len(c) for c in comps) + 5
316-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
317-
for c in comps:
318-
stats = tensorstats(getattr(self, c).get())
319-
if stats is not None:
320-
line = [f"{k}: {v}" for k, v in stats.items()]
321-
line = ", ".join(line)
322-
else:
323-
line = "None"
324-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
325-
return lines
326-
327313
if __name__ == '__main__':
328314
from ngcsimlib.context import Context
329315
with Context("Bar") as bar:

0 commit comments

Comments (0)