From f8395e8f0dbbf315b128f1493f29ea9d01ed356a Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Thu, 24 Jul 2025 11:35:41 -0400 Subject: [PATCH 001/121] Working v3 --- ngclearn/__init__.py | 47 ++-- ngclearn/commands/__init__.py | 1 - ngclearn/components/__init__.py | 130 +++++----- .../input_encoders/bernoulliCell.py | 92 +++---- .../components/input_encoders/latencyCell.py | 149 +++++------ .../components/input_encoders/phasorCell.py | 128 +++------- .../components/input_encoders/poissonCell.py | 103 ++------ ngclearn/components/jaxComponent.py | 49 +++- ngclearn/components/neurons/__init__.py | 34 +-- .../components/neurons/graded/rateCell.py | 1 - .../components/neurons/spiking/LIFCell.py | 234 ++++++------------ .../components/neurons/spiking/__init__.py | 22 +- ngclearn/components/other/__init__.py | 6 +- ngclearn/components/other/varTrace.py | 91 +++---- ngclearn/components/synapses/__init__.py | 76 +++--- ngclearn/components/synapses/denseSynapse.py | 72 ++---- .../components/synapses/hebbian/__init__.py | 12 +- .../synapses/hebbian/traceSTDPSynapse.py | 165 +++++------- ngclearn/utils/jaxProcess.py | 4 +- ngclearn/utils/model_utils.py | 73 +++--- 20 files changed, 609 insertions(+), 880 deletions(-) delete mode 100644 ngclearn/commands/__init__.py diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index e457cb52..d404651e 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -1,9 +1,6 @@ import sys -import subprocess import pkg_resources from pkg_resources import get_distribution -#from pathlib import Path -#from sys import argv __version__ = get_distribution('ngclearn').version @@ -31,33 +28,27 @@ import ngcsimlib -from ngcsimlib.context import Context -from ngcsimlib.component import Component +from ngcsimlib import Component, MethodProcess, JointProcess +from ngcsimlib.context import Context, ContextObjectTypes from ngcsimlib.compartment import Compartment -from ngcsimlib.resolver import resolver -from ngcsimlib import utils as sim_utils from ngclearn.utils.jaxProcess import JaxProcess -from ngcsimlib.compilers.process import transition, Process - - -from ngcsimlib import configure, preload_modules from ngcsimlib import logger -if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": - if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload - configure() - logger.init_logging() - from ngcsimlib.configManager import get_config - pkg_config = get_config("packages") - if pkg_config is not None: - use_base_numpy = pkg_config.get("use_base_numpy", False) - if use_base_numpy: - import numpy as numpy - else: - from jax import numpy - else: - from jax import numpy - - - preload_modules() +# if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": +# if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload +# configure() +# logger.init_logging() +# from ngcsimlib.configManager import get_config +# pkg_config = get_config("packages") +# if pkg_config is not None: +# use_base_numpy = pkg_config.get("use_base_numpy", False) +# if use_base_numpy: +# import numpy as numpy +# else: +# from jax import numpy +# else: +# from jax import numpy +# +# +# preload_modules() diff --git a/ngclearn/commands/__init__.py b/ngclearn/commands/__init__.py deleted file mode 100644 index 74eb06b3..00000000 --- a/ngclearn/commands/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ngcsimlib.commands import * diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index af856c1a..f69f30a5 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -1,65 +1,65 @@ -from .jaxComponent import JaxComponent - -## point to rate-coded cell component types -from .neurons.graded.rateCell import RateCell -from .neurons.graded.gaussianErrorCell import GaussianErrorCell -from .neurons.graded.laplacianErrorCell import LaplacianErrorCell -from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell -from .neurons.graded.rewardErrorCell import RewardErrorCell - -## point to standard spiking cell component types -from .neurons.spiking.sLIFCell import SLIFCell -from .neurons.spiking.IFCell import IFCell -from .neurons.spiking.LIFCell import LIFCell -from .neurons.spiking.WTASCell import WTASCell -from .neurons.spiking.quadLIFCell import QuadLIFCell -from .neurons.spiking.adExCell import AdExCell -from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell -from .neurons.spiking.izhikevichCell import IzhikevichCell -from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell -from .neurons.spiking.RAFCell import RAFCell - -## point to transformer/operator component types -from .other.varTrace import VarTrace -from .other.expKernel import ExpKernel - -## point to input encoder component types -from .input_encoders.bernoulliCell import BernoulliCell -from .input_encoders.poissonCell import PoissonCell -from .input_encoders.latencyCell import LatencyCell -from .input_encoders.phasorCell import PhasorCell - -## point to synapse component types -from .synapses.denseSynapse import DenseSynapse -from .synapses.staticSynapse import StaticSynapse -from .synapses.hebbian.hebbianSynapse import HebbianSynapse -from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse -from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse -from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse -from .synapses.hebbian.BCMSynapse import BCMSynapse -from .synapses.STPDenseSynapse import STPDenseSynapse -from .synapses.exponentialSynapse import ExponentialSynapse -from .synapses.doubleExpSynapse import DoupleExpSynapse -from .synapses.alphaSynapse import AlphaSynapse - -## point to convolutional component types -from .synapses.convolution.convSynapse import ConvSynapse -from .synapses.convolution.staticConvSynapse import StaticConvSynapse -from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse -from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -from .synapses.convolution.deconvSynapse import DeconvSynapse -from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse -from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse -## point to modulated component types -from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse -from .synapses.modulated.REINFORCESynapse import REINFORCESynapse - -## point to monitors -from .monitor import Monitor - -## point to patched component types -from .synapses.patched.patchedSynapse import PatchedSynapse -from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse -from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse - +# from .jaxComponent import JaxComponent +# +# ## point to rate-coded cell component types +# from .neurons.graded.rateCell import RateCell +# from .neurons.graded.gaussianErrorCell import GaussianErrorCell +# from .neurons.graded.laplacianErrorCell import LaplacianErrorCell +# from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell +# from .neurons.graded.rewardErrorCell import RewardErrorCell +# +# ## point to standard spiking cell component types +# from .neurons.spiking.sLIFCell import SLIFCell +# from .neurons.spiking.IFCell import IFCell +# from .neurons.spiking.LIFCell import LIFCell +# from .neurons.spiking.WTASCell import WTASCell +# from .neurons.spiking.quadLIFCell import QuadLIFCell +# from .neurons.spiking.adExCell import AdExCell +# from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell +# from .neurons.spiking.izhikevichCell import IzhikevichCell +# from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell +# from .neurons.spiking.RAFCell import RAFCell +# +# ## point to transformer/operator component types +# from .other.varTrace import VarTrace +# from .other.expKernel import ExpKernel +# +# ## point to input encoder component types +# from .input_encoders.bernoulliCell import BernoulliCell +# from .input_encoders.poissonCell import PoissonCell +# from .input_encoders.latencyCell import LatencyCell +# from .input_encoders.phasorCell import PhasorCell +# +# ## point to synapse component types +# from .synapses.denseSynapse import DenseSynapse +# from .synapses.staticSynapse import StaticSynapse +# from .synapses.hebbian.hebbianSynapse import HebbianSynapse +# from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse +# from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse +# from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse +# from .synapses.hebbian.BCMSynapse import BCMSynapse +# from .synapses.STPDenseSynapse import STPDenseSynapse +# from .synapses.exponentialSynapse import ExponentialSynapse +# from .synapses.doubleExpSynapse import DoupleExpSynapse +# from .synapses.alphaSynapse import AlphaSynapse +# +# ## point to convolutional component types +# from .synapses.convolution.convSynapse import ConvSynapse +# from .synapses.convolution.staticConvSynapse import StaticConvSynapse +# from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse +# from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +# from .synapses.convolution.deconvSynapse import DeconvSynapse +# from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse +# from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +# from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse +# ## point to modulated component types +# from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse +# from .synapses.modulated.REINFORCESynapse import REINFORCESynapse +# +# ## point to monitors +# from .monitor import Monitor +# +# ## point to patched component types +# from .synapses.patched.patchedSynapse import PatchedSynapse +# from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse +# from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse +# diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index d240de64..6d1ca713 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -1,12 +1,9 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn - -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable +import jax +from typing import Union class BernoulliCell(JaxComponent): """ @@ -29,51 +26,31 @@ class BernoulliCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - def __init__(self, name, n_units, batch_size=1, **kwargs): - super().__init__(name, **kwargs) - #super(BernoulliCell, self).__init__(name, **kwargs) + def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None): + super().__init__(name=name, key=key) ## Layer Size Setup - self.batch_size = batch_size - self.n_units = n_units + self.batch_size = Compartment(batch_size, fixed=True) + self.n_units = Compartment(n_units, fixed=True) - # Compartments (state of the cell, parameters, will be updated through stateless calls) - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((batch_size, n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike - @transition(output_compartments=["outputs", "tols", "key"]) - @staticmethod - def advance_state(t, key, inputs, tols): - ## NOTE: should `inputs` be checked if bounded to [0,1]? - # print(key) - # print(t) - # print(inputs.shape) - # print(tols.shape) - # print("-----") - key, *subkeys = random.split(key, 3) - outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32) - # Updates time-of-last-spike (tols) variable: - # output = s = binary spike vector - # tols = current time-of-last-spike variable - tols = (1. - outputs) * tols + (outputs * t) - return outputs, tols, key - - @transition(output_compartments=["inputs", "outputs", "tols"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - return restVals, restVals, restVals - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, key=self.key.value) + @compilable + def advance_state(self, t): + key, subkey = random.split(self.key.get(), 2) + self.outputs.set(random.bernoulli(subkey, p=self.inputs.get()).astype(jnp.float32)) + self.tols.set((1. - self.outputs.get()) * self.tols.get() + (self.outputs.get() * t)) + self.key.set(key) - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.key.set(data['key']) + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + self.inputs.set(restVals) + self.outputs.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function @@ -101,22 +78,23 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines + # def __repr__(self): + # comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + # maxlen = max(len(c) for c in comps) + 5 + # lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + # for c in comps: + # stats = tensorstats(getattr(self, c).value) + # if stats is not None: + # line = [f"{k}: {v}" for k, v in stats.items()] + # line = ", ".join(line) + # else: + # line = "None" + # lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + # return lines if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: X = BernoulliCell("X", 9) - print(X) + + X.batch_size.set(10) diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index c7343cfa..374bea78 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -1,16 +1,13 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from functools import partial -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn - -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment +import jax +from typing import Union from ngclearn.utils.model_utils import clamp_min, clamp_max +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable @partial(jit, static_argnums=[5]) def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1., @@ -148,88 +145,78 @@ class LatencyCell(JaxComponent): # Define Functions def __init__( - self, name, n_units, tau=1., threshold=0.01, first_spike_time=0., linearize=False, normalize=False, - clip_spikes=False, num_steps=1., batch_size=1, **kwargs + self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, + first_spike_time: float = 0., linearize: bool = False, + normalize: bool = False, clip_spikes: bool = False, + num_steps: float = 1., batch_size: int = 1, + key: Union[jax.Array, None] = None ): - super().__init__(name, **kwargs) + super().__init__(name=name, key=key) ## latency meta-parameters - self.first_spike_time = first_spike_time - self.tau = tau - self.threshold = threshold - self.linearize = linearize - self.clip_spikes = clip_spikes + self.first_spike_time = Compartment(first_spike_time, fixed=True) + self.tau = Compartment(tau, fixed=True) + self.threshold = Compartment(threshold, fixed=True) + self.linearize = Compartment(linearize, fixed=True) + self.clip_spikes = Compartment(clip_spikes, fixed=True) ## normalize latency code s.t. final spike(s) occur w/in num_steps - self.normalize = normalize - self.num_steps = num_steps + self.normalize = Compartment(normalize, fixed=True) + self.num_steps = Compartment(num_steps, fixed=True) ## Layer Size Setup - self.batch_size = batch_size - self.n_units = n_units + self.batch_size = Compartment(batch_size, fixed=True) + self.n_units = Compartment(n_units, fixed=True) ## Compartment setup - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((batch_size, n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.mask = Compartment(restVals, display_name="Spike Time Mask") self.clip_mask = Compartment(restVals, display_name="Clip Mask") self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms") - #self.reset() - @transition(output_compartments=["targ_sp_times", "clip_mask"]) - @staticmethod - def calc_spike_times( - linearize, tau, threshold, first_spike_time, num_steps, normalize, clip_spikes, inputs - ): - ## would call this function before processing a spike train (at start) - data = inputs - if clip_spikes: - clip_mask = (data < threshold) * 1. ## find values under threshold + @compilable + def calc_spike_times(self): + if self.clip_spikes.get(): + self.clip_mask.set((self.inputs.get() < self.threshold) * 1.) else: - clip_mask = data * 0. ## all values allowed to fire spikes - if linearize: ## linearize spike time calculation - stimes = _calc_spike_times_linear(data, tau, threshold, - first_spike_time, - num_steps, normalize) - targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent) - else: ## standard nonlinear spike time calculation - stimes = _calc_spike_times_nonlinear(data, tau, threshold, - first_spike_time, - num_steps=num_steps, - normalize=normalize) - targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent) - return targ_sp_times, clip_mask - - @transition(output_compartments=["outputs", "tols", "mask", "targ_sp_times", "key"]) - @staticmethod - def advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols): - key, *subkeys = random.split(key, 2) - data = inputs ## get sensory pattern data / features - spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t - - # Updates time-of-last-spike (tols) variable: - # output = s = binary spike vector - # tols = current time-of-last-spike variable - tols = (1. - spikes) * tols + (spikes * t) - - spikes = spikes * (1. - clip_mask) - return spikes, tols, spk_mask, targ_sp_times, key - - @transition(output_compartments=["inputs", "outputs", "tols", "mask", "clip_mask", "targ_sp_times"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - return (restVals, restVals, restVals, restVals, restVals, restVals) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, key=self.key.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.key.set(data['key']) + self.clip_mask.set(self.inputs.get() * 0.) + + if self.linearize.get(): + self.targ_sp_times.set( + _calc_spike_times_linear(self.inputs.get(), + self.tau.get(), + self.threshold.get(), + self.first_spike_time.get(), + self.num_steps.get(), + self.normalize.get())) + else: + self.targ_sp_times.set( + _calc_spike_times_nonlinear(self.inputs.get(), + self.tau.get(), + self.threshold.get(), + self.first_spike_time.get(), + self.num_steps.get(), + self.normalize.get())) + + + @compilable + def advance_state(self, t): + spikes, spike_mask = _extract_spike(self.targ_sp_times.get(), t, self.mask.get()) + self.tols.set((1. - spikes) * self.tols.get() + (spikes * t)) + self.outputs.set(spikes * (1. - self.clip_mask.get())) + self.mask.set(spike_mask) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + self.inputs.set(restVals) + self.outputs.set(restVals) + self.tols.set(restVals) + self.mask.set(restVals) + self.clip_mask.set(restVals) + self.targ_sp_times.set(restVals) @classmethod def help(cls): ## component help function @@ -266,22 +253,10 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: X = LatencyCell("X", 9) print(X) + print(X.calc_spike_times.compiled.code) + print(X.advance_state.compiled.code) \ No newline at end of file diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 9eaa16a7..a9ca1425 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -1,12 +1,10 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn +import jax +from typing import Union -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class PhasorCell(JaxComponent): @@ -35,19 +33,25 @@ class PhasorCell(JaxComponent): # Define Functions def __init__( - self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs): - super().__init__(name, **kwargs) + self, name: str, n_units: int, target_freq: float = 63.75, + batch_size: int = 1, key: Union[jax.Array, None] = None): + super().__init__(name=name, key=key) + + _key, subkey = random.split(self.key.get(), 2) + self.key.set(_key) ## Phasor meta-parameters - self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) + self.target_freq = Compartment(target_freq, fixed=True) ## maximum frequency (in Hertz/Hz) + self.base_scale = Compartment(random.poisson(subkey[0], lam=target_freq, shape=(batch_size, n_units)) / target_freq, fixed=True) ## Layer Size Setup - self.batch_size = batch_size - self.n_units = n_units - _key, *subkey = random.split(self.key.value, 3) - self.key.set(_key) + self.batch_size = Compartment(batch_size, fixed=True) + self.n_units = Compartment(n_units, fixed=True) + + + ## Compartment setup - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((batch_size, n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input # compartment @@ -56,80 +60,41 @@ def __init__( self.tols = Compartment(initial_value=restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike self.angles = Compartment(restVals, display_name="Angles", units="deg") - # self.base_scale = random.uniform(subkey, self.angles.value.shape, - # minval=0.75, maxval=1.25) - # self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1) - # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1) - # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq - - self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.value.shape) / target_freq - self.disable_phasor = disable_phasor - - def validate(self, dt=None, **validation_kwargs): - valid = super().validate(**validation_kwargs) - if dt is None: - warn(f"{self.name} requires a validation kwarg of `dt`") - return False - ## check for unstable combinations of dt and target-frequency - # meta-params - events_per_timestep = (dt / 1000.) * self.target_freq ## - # compute scaled probability - if events_per_timestep > 1.: - valid = False - warn( - f"{self.name} will be unable to make as many temporal events " - f"as " - f"requested! ({events_per_timestep} events/timestep) Unstable " - f"combination of dt = {dt} and target_freq = " - f"{self.target_freq} " - f"being used!" - ) - return valid - - @transition(output_compartments=["outputs", "tols", "key", "angles"]) - @staticmethod - def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale, disable_phasor): + + @compilable + def advance_state(self, t, dt): ms_per_second = 1000 # ms/s - events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms + events_per_ms = self.target_freq.get() / ms_per_second # e/s s/ms -> e/ms ms_per_event = 1 / events_per_ms # ms/e time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e angle_per_event = 2 * jnp.pi # rad / e angle_per_timestep = angle_per_event / time_step_per_event # rad / e # * e/ts -> rad / ts - key, *subkey = random.split(key, 3) - # scatter = random.uniform(subkey, angles.shape, minval=0.5, - # maxval=1.5) * base_scale + key, *subkey = random.split(self.key.get(), 3) - scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * base_scale + scatter = ((random.normal(subkey[0], self.angles.get().shape) * 0.2) + 1) * self.base_scale.get() scattered_update = angle_per_timestep * scatter - scaled_scattered_update = scattered_update * inputs + scaled_scattered_update = scattered_update * self.inputs.get() - updated_angles = angles + scaled_scattered_update - outputs = jnp.where(updated_angles > angle_per_event, 1., 0.) - updated_angles = jnp.where(updated_angles > angle_per_event, - updated_angles - angle_per_event, - updated_angles) - if disable_phasor: - outputs = inputs + 0 - tols = tols * (1. - outputs) + t * outputs + updated_angles = self.angles.get() + scaled_scattered_update + self.outputs.set(jnp.where(updated_angles > angle_per_event, 1., 0.)) - return outputs, tols, key, updated_angles + self.angles.set(jnp.where(updated_angles > angle_per_event, + updated_angles - angle_per_event, + updated_angles)) - @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"]) - @staticmethod - def reset(batch_size, n_units, key, target_freq): - restVals = jnp.zeros((batch_size, n_units)) - key, *subkey = random.split(key, 3) - return restVals, restVals, restVals, restVals, key + self.tols.set(self.tols.get() * (1. - self.outputs.get()) + t * self.outputs.get()) - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, key=self.key.value) + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + self.inputs.set(restVals) + self.outputs.set(restVals) + self.tols.set(restVals) + self.angles.set(restVals) + key, _ = random.split(self.key.get(), 2) + self.key.set(key) - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.key.set(data['key']) @classmethod def help(cls): ## component help function @@ -157,19 +122,4 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if - Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 5f385951..65022156 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -1,11 +1,9 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn +import jax +from typing import Union -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment class PoissonCell(JaxComponent): @@ -31,72 +29,41 @@ class PoissonCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - @deprecate_args(max_freq="target_freq") - def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs): - super().__init__(name, **kwargs) + def __init__(self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1, + key: Union[jax.Array, None] = None): + super().__init__(name=name, key=key) ## Constrained Bernoulli meta-parameters - self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) + self.target_freq = Compartment(target_freq, fixed=True) ## maximum frequency (in Hertz/Hz) ## Layer Size Setup - self.batch_size = batch_size - self.n_units = n_units + self.batch_size = Compartment(batch_size, fixed=True) + self.n_units = Compartment(n_units, fixed=True) # Compartments (state of the cell, parameters, will be updated through stateless calls) - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike - def validate(self, dt=None, **validation_kwargs): - valid = super().validate(**validation_kwargs) - if dt is None: - warn(f"{self.name} requires a validation kwarg of `dt`") - return False - ## check for unstable combinations of dt and target-frequency meta-params - events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability - if events_per_timestep > 1.: - valid = False - warn( - f"{self.name} will be unable to make as many temporal events as " - f"requested! ({events_per_timestep} events/timestep) Unstable " - f"combination of dt = {dt} and target_freq = {self.target_freq} " - f"being used!" - ) - return valid - - @transition(output_compartments=["outputs", "tols", "key"]) - @staticmethod - def advance_state(t, dt, target_freq, key, inputs, tols): - key, *subkeys = random.split(key, 2) - pspike = inputs * (dt / 1000.) * target_freq - eps = random.uniform(subkeys[0], inputs.shape, minval=0., maxval=1., + @compilable + def advance_state(self, t, dt): + key, subkey = random.split(self.key.get(), 2) + pspike = self.inputs.get() * (dt / 1000.) * self.target_freq.get() + eps = random.uniform(subkey, self.inputs.get().shape, minval=0., maxval=1., dtype=jnp.float32) - outputs = (eps < pspike).astype(jnp.float32) - - # Updates time-of-last-spike (tols) variable: - # output = s = binary spike vector - # tols = current time-of-last-spike variable - tols = (1. - outputs) * tols + (outputs * t) - return outputs, tols, key - - @transition(output_compartments=["inputs", "outputs", "tols"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - return restVals, restVals, restVals - - def save(self, directory, **kwargs): - target_freq = (self.target_freq if isinstance(self.target_freq, float) - else jnp.ones([[self.target_freq]])) - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, key=self.key.value, target_freq=target_freq) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.key.set(data['key']) - self.target_freq = data['target_freq'] + + self.outputs.set((eps < pspike).astype(jnp.float32)) + self.tols.set((1. - self.outputs.get()) * self.tols.get() + (self.outputs.get() * t)) + self.key.set(key) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + if not self.inputs.targeted: + self.inputs.set(restVals) + self.outputs.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function @@ -126,22 +93,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if - Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - - if __name__ == '__main__': from ngcsimlib.context import Context diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 0488c47c..6d8f08ae 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -1,8 +1,12 @@ import time + +from typing import Union +import jax +from jax import numpy as jnp from jax import random -#from ngclearn import resolver, Component, Compartment -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib import Component + class JaxComponent(Component): """ @@ -14,12 +18,43 @@ class JaxComponent(Component): key: PRNG key to control determinism of any underlying random values associated with this cell - directory: string indicating directory on disk to save component parameter - values to """ - def __init__(self, name, key=None, directory=None, **kwargs): - super().__init__(name, **kwargs) - self.directory = directory + def __init__(self, name: str, key: Union[jax.Array, None] = None): + super().__init__(name) self.key = Compartment( random.PRNGKey(time.time_ns()) if key is None else key) + + + def save(self, directory: str): + """ + The default save method for JaxComponents, it stores the values of all + non-targeted (non-wired) compartments into a .npz file. + + Args: + directory: The directory to save the .npz file. + """ + file_name = directory + "/" + self.name + ".npz" + data = {} + for comp_name, comp in self.compartments: + if not comp.targeted: + data[comp_name] = comp.get() + jnp.savez(file_name, **data) + + + def load(self, directory: str): + """ + The default load method for JaxComponents, it is expected to work with + the default save. If the save method is modified this one will need to + be modified too. + + Args: + directory: The directory to load the .npz file. + """ + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + for comp_name, comp in self.compartments: + d = data.get(comp_name, None) + if d is not None: + comp.set(d) + diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index e7165d7e..2398f011 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -1,17 +1,17 @@ -## point to rate-coded cell componet types -from .graded.rateCell import RateCell -from .graded.gaussianErrorCell import GaussianErrorCell -from .graded.laplacianErrorCell import LaplacianErrorCell -from .graded.bernoulliErrorCell import BernoulliErrorCell -from .graded.rewardErrorCell import RewardErrorCell -## point to standard spiking cell component types -from .spiking.sLIFCell import SLIFCell -from .spiking.IFCell import IFCell -from .spiking.LIFCell import LIFCell -from .spiking.WTASCell import WTASCell -from .spiking.quadLIFCell import QuadLIFCell -from .spiking.adExCell import AdExCell -from .spiking.fitzhughNagumoCell import FitzhughNagumoCell -from .spiking.izhikevichCell import IzhikevichCell -from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell -from .spiking.RAFCell import RAFCell +# ## point to rate-coded cell componet types +# from .graded.rateCell import RateCell +# from .graded.gaussianErrorCell import GaussianErrorCell +# from .graded.laplacianErrorCell import LaplacianErrorCell +# from .graded.bernoulliErrorCell import BernoulliErrorCell +# from .graded.rewardErrorCell import RewardErrorCell +# ## point to standard spiking cell component types +# from .spiking.sLIFCell import SLIFCell +# from .spiking.IFCell import IFCell +# from .spiking.LIFCell import LIFCell +# from .spiking.WTASCell import WTASCell +# from .spiking.quadLIFCell import QuadLIFCell +# from .spiking.adExCell import AdExCell +# from .spiking.fitzhughNagumoCell import FitzhughNagumoCell +# from .spiking.izhikevichCell import IzhikevichCell +# from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell +# from .spiking.RAFCell import RAFCell diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index e55ce8ff..63a9fe3b 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -5,7 +5,6 @@ from ngclearn.utils import tensorstats # from ngclearn import resolver, Component, Compartment from ngcsimlib.compartment import Compartment -from ngcsimlib.compilers.process import transition from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.model_utils import create_function, threshold_soft, \ threshold_cauchy diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 371e8058..96217e46 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -1,17 +1,12 @@ from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn +from jax import numpy as jnp, random, nn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, triangular_estimator, straight_through_estimator) -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment def _dfv(t, v, params): ## voltage dynamics wrapper @@ -112,53 +107,53 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell v_min: minimum voltage to clamp dynamics to (Default: None) """ ## batch_size arg? - @deprecate_args(thr_jitter=None, v_decay="conduct_leak") def __init__( self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through", - v_min=None, max_one_spike=False, **kwargs + v_min=None, max_one_spike=False, key=None ): - super().__init__(name, **kwargs) + super().__init__(name, key) ## Integration properties self.integrationType = integration_type self.intgFlag = get_integrator_code(self.integrationType) - - ## membrane parameter setup (affects ODE integration) - self.tau_m = tau_m ## membrane time constant - self.resist_m = resist_m ## resistance value self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step self.max_one_spike = max_one_spike - self.v_min = v_min ## ensures voltage is never < v_min - self.v_rest = v_rest #-65. # mV - self.v_reset = v_reset # -60. # -65. # mV (milli-volts) - self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF) + ## membrane parameter setup (affects ODE integration) + self.tau_m = Compartment(tau_m, fixed=True) ## membrane time constant + self.resist_m = Compartment(resist_m, fixed=True) ## resistance value + + self.v_min = Compartment(v_min, fixed=True) ## ensures voltage is never < v_min + + self.v_rest = Compartment(v_rest, fixed=True) #-65. # mV + self.v_reset = Compartment(v_reset, fixed=True) # -60. # -65. # mV (milli-volts) + self.g_L = Compartment(conduct_leak, fixed=True) ## controls strength of voltage leak (1 -> LIF, 0 => IF) ## basic asserts to prevent neuronal dynamics breaking... #assert (self.conduct_leak * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify... - assert self.resist_m > 0. - self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off) - self.theta_plus = theta_plus #0.05 ## threshold increment - self.refract_T = refract_time #5. # 2. ## refractory period # ms - self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV + assert self.resist_m.get() > 0. + self.tau_theta = Compartment(tau_theta, fixed=True) ## threshold time constant # ms (0 turns off) + self.theta_plus = Compartment(theta_plus, fixed=True) #0.05 ## threshold increment + self.refract_T = Compartment(refract_time, fixed=True) #5. # 2. ## refractory period # ms + self.thr = Compartment(thr, fixed=True) ## (fixed) base value for threshold #-52 # -72. # mV ## Layer Size Setup - self.batch_size = 1 - self.n_units = n_units + self.batch_size = Compartment(1, fixed=True) + self.n_units = Compartment(n_units, fixed=True) - ## set up surrogate function for spike emission - if surrogate_type == "secant_lif": - self.spike_fx, self.d_spike_fx = secant_lif_estimator() - elif surrogate_type == "arctan": - self.spike_fx, self.d_spike_fx = arctan_estimator() - elif surrogate_type == "triangular": - self.spike_fx, self.d_spike_fx = triangular_estimator() - else: ## default: straight_through - self.spike_fx, self.d_spike_fx = straight_through_estimator() + # ## set up surrogate function for spike emission + # if surrogate_type == "secant_lif": + # spike_fx, d_spike_fx = secant_lif_estimator() + # elif surrogate_type == "arctan": + # spike_fx, d_spike_fx = arctan_estimator() + # elif surrogate_type == "triangular": + # spike_fx, d_spike_fx = triangular_estimator() + # else: ## default: straight_through + # spike_fx, d_spike_fx = straight_through_estimator() ## Compartment setup - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) self.j = Compartment(restVals, display_name="Current", units="mA") self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV") @@ -170,122 +165,70 @@ def __init__( units="mV") self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike - self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") + # self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") - @transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"]) - @staticmethod - def advance_state( - t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus, one_spike, max_one_spike, - v_min, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols - ): - skey = None ## this is an empty dkey if single_spike mode turned off - if one_spike and not max_one_spike: - key, skey = random.split(key, 2) - ## run one integration step for neuronal dynamics - j = j * resist_m - ############################################################################ - ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics. - _v_thr = thr_theta + thr ## calc present voltage threshold - #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask - ## update voltage / membrane potential - v_params = (j, rfr, tau_m, refract_T, v_rest, g_L) - if intgFlag == 1: - _, _v = step_rk2(0., v, _dfv, dt, v_params) + @compilable + def advance_state(self, dt, t): + j = self.j.get() * self.resist_m.get() + + _v_thr = self.thr_theta.get() + self.thr.get() ## calc present voltage threshold + + v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T.get(), self.v_rest.get(), self.g_L.get()) + + if self.intgFlag == 1: + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) else: - _, _v = step_euler(0., v, _dfv, dt, v_params) - ## obtain action potentials/spikes/pulses + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) + s = (_v > _v_thr) * 1. - v_prespike = v - ## update refractory variables - _rfr = (rfr + dt) * (1. - s) - ## perform hyper-polarization of neuronal cells - _v = _v * (1. - s) + s * v_reset - - raw_s = s + 0 ## preserve un-altered spikes - ############################################################################ - ## this is a spike post-processing step - if skey is not None: + _rfr = (self.rfr.get() + dt) * (1. - s) + _v = _v * (1. - s) + s * self.v_reset.get() + + raw_s = s + + if self.one_spike and not self.max_one_spike: + key, skey = random.split(self.key.get(), 2) + m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able rS = s * random.uniform(skey, s.shape) rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32) s = s * (1. - m_switch) + rS * m_switch - if max_one_spike: - rS = nn.one_hot(jnp.argmax(v_prespike, axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike + self.key.set(key) + + if self.max_one_spike: + rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike s = s * rS ## mask out non-max volt spikes - ############################################################################ - raw_spikes = raw_s - v = _v - rfr = _rfr - surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta) - if tau_theta > 0.: + if self.tau_theta.get() > 0.: ## run one integration step for threshold dynamics - thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) + thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta.get(), self.theta_plus.get()) + self.thr_theta.set(thr_theta) + ## update tols - tols = (1. - s) * tols + (s * t) - if v_min is not None: ## ensures voltage never < v_rest - v = jnp.maximum(v, v_min) - return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate - - @transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"]) - @staticmethod - def reset(batch_size, n_units, v_rest, refract_T): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals #+ 0 - v = restVals + v_rest - s = restVals #+ 0 - s_raw = restVals - rfr = restVals + refract_T - #thr_theta = restVals ## do not reset thr_theta - tols = restVals #+ 0 - surrogate = restVals + 1. - return j, v, s, s_raw, rfr, tols, surrogate - - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.asarray([[self.tau_m * 1.]])) - thr = (self.thr if isinstance(self.thr, float) - else jnp.asarray([[self.thr * 1.]])) - v_rest = (self.v_rest if isinstance(self.v_rest, float) - else jnp.asarray([[self.v_rest * 1.]])) - v_reset = (self.v_reset if isinstance(self.v_reset, float) - else jnp.asarray([[self.v_reset * 1.]])) - g_L = (self.g_L if isinstance(self.g_L, float) - else jnp.asarray([[self.g_L * 1.]])) - resist_m = (self.resist_m if isinstance(self.resist_m, float) - else jnp.asarray([[self.resist_m * 1.]])) - tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) - else jnp.asarray([[self.tau_theta * 1.]])) - theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) - else jnp.asarray([[self.theta_plus * 1.]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - threshold_theta=self.thr_theta.value, - tau_m=tau_m, thr=thr, v_rest=v_rest, - v_reset=v_reset, g_L=g_L, - resist_m=resist_m, tau_theta=tau_theta, - theta_plus=theta_plus, - key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.thr_theta.set(data['threshold_theta']) - ## constants loaded in - self.tau_m = data['tau_m'] - self.thr = data['thr'] - self.v_rest = data['v_rest'] - self.v_reset = data['v_reset'] - self.g_L = data['g_L'] - self.resist_m = data['resist_m'] - self.tau_theta = data['tau_theta'] - self.theta_plus = data['theta_plus'] - - if seeded: - self.key.set(data['key']) + self.tols.set((1. - s) * self.tols.get() + (s * t)) + + if self.v_min.get() is not None: ## ensures voltage never < v_rest + _v = jnp.maximum(_v, self.v_min.get()) + + + self.v.set(_v) + self.s.set(s) + self.s_raw.set(raw_s) + self.rfr.set(_rfr) + + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v_rest.get()) + self.s.set(restVals) + self.s_raw.set(restVals) + self.rfr.set(restVals + self.refract_T.get()) + self.tols.set(restVals) + @classmethod def help(cls): ## component help function @@ -333,19 +276,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines if __name__ == '__main__': from ngcsimlib.context import Context diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py index 690087b7..6687c56a 100644 --- a/ngclearn/components/neurons/spiking/__init__.py +++ b/ngclearn/components/neurons/spiking/__init__.py @@ -1,11 +1,11 @@ -## point to standard spiking cell component types -from .sLIFCell import SLIFCell -from .LIFCell import LIFCell -from .IFCell import IFCell -from .WTASCell import WTASCell -from .quadLIFCell import QuadLIFCell -from .adExCell import AdExCell -from .fitzhughNagumoCell import FitzhughNagumoCell -from .izhikevichCell import IzhikevichCell -from .RAFCell import RAFCell -from .hodgkinHuxleyCell import HodgkinHuxleyCell +# ## point to standard spiking cell component types +# from .sLIFCell import SLIFCell +# from .LIFCell import LIFCell +# from .IFCell import IFCell +# from .WTASCell import WTASCell +# from .quadLIFCell import QuadLIFCell +# from .adExCell import AdExCell +# from .fitzhughNagumoCell import FitzhughNagumoCell +# from .izhikevichCell import IzhikevichCell +# from .RAFCell import RAFCell +# from .hodgkinHuxleyCell import HodgkinHuxleyCell diff --git a/ngclearn/components/other/__init__.py b/ngclearn/components/other/__init__.py index cff092d9..14d46a49 100644 --- a/ngclearn/components/other/__init__.py +++ b/ngclearn/components/other/__init__.py @@ -1,3 +1,3 @@ -from .varTrace import VarTrace -from .expKernel import ExpKernel - +# from .varTrace import VarTrace +# from .expKernel import ExpKernel +# diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index 94510e75..936b9f10 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -2,11 +2,9 @@ from jax import numpy as jnp, random, jit from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib.parser import compilable from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment @partial(jit, static_argnums=[4]) @@ -79,52 +77,56 @@ class VarTrace(JaxComponent): ## low-pass filter # Define Functions def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp", - n_nearest_spikes=0, batch_size=1, **kwargs): - super().__init__(name, **kwargs) + n_nearest_spikes=0, batch_size=1, key=None): + super().__init__(name, key) ## Trace control coefficients - self.tau_tr = tau_tr ## trace time constant - self.a_delta = a_delta ## trace increment (if spike occurred) - self.P_scale = P_scale ## trace scale if non-additive trace to be used - self.gamma_tr = gamma_tr self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay - self.n_nearest_spikes = n_nearest_spikes + + self.tau_tr = Compartment(tau_tr, fixed=True) ## trace time constant + self.a_delta = Compartment(a_delta, fixed=True) ## trace increment (if spike occurred) + self.P_scale = Compartment(P_scale, fixed=True) ## trace scale if non-additive trace to be used + self.gamma_tr = Compartment(gamma_tr, fixed=True) + self.n_nearest_spikes = Compartment(n_nearest_spikes, fixed=True) ## Layer Size Setup - self.batch_size = batch_size - self.n_units = n_units + self.batch_size = Compartment(batch_size, fixed=True) + self.n_units = Compartment(n_units, fixed=True) - restVals = jnp.zeros((self.batch_size, self.n_units)) + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) self.inputs = Compartment(restVals) # input compartment self.outputs = Compartment(restVals) # output compartment self.trace = Compartment(restVals) - @transition(output_compartments=["outputs", "trace"]) - @staticmethod - def advance_state( - dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace, n_nearest_spikes - ): - decayFactor = 0. - if "exp" in decay_type: - decayFactor = jnp.exp(-dt/tau_tr) - elif "lin" in decay_type: - decayFactor = (1. - dt/tau_tr) - _x_tr = gamma_tr * trace * decayFactor - if n_nearest_spikes > 0: ## run k-nearest neighbor trace - _x_tr = _x_tr + inputs * (a_delta - (trace/n_nearest_spikes)) + @compilable + def advance_state(self, dt): + if "exp" in self.decay_type: + decayFactor = jnp.exp(-dt/self.tau_tr.get()) + elif "lin" in self.decay_type: + decayFactor = (1. - dt/self.tau_tr.get()) else: - if a_delta > 0.: ## run full convolution trace - _x_tr = _x_tr + inputs * a_delta - else: ## run simple max-clamped trace - _x_tr = _x_tr * (1. - inputs) + inputs * P_scale - trace = _x_tr - return trace, trace - - @transition(output_compartments=["inputs", "outputs", "trace"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - return restVals, restVals, restVals + decayFactor = 0. + + + _x_tr = self.gamma_tr.get() * self.trace.get() * decayFactor + if self.n_nearest_spikes.get() > 0: + _x_tr = _x_tr + self.inputs.get() * (self.a_delta.get() - (self.trace.get() / self.n_nearest_spikes.get())) + else: + if self.a_delta.get() > 0.: + _x_tr = _x_tr + self.inputs.get() * self.a_delta.get() + else: + _x_tr = _x_tr * (1. - self.inputs.get()) + self.inputs.get() * self.P_scale.get() + + self.trace.set(_x_tr) + self.outputs.set(_x_tr) + + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + self.inputs.set(restVals) + self.outputs.set(restVals) + self.trace.set(restVals) @classmethod def help(cls): ## component help function @@ -159,19 +161,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines if __name__ == '__main__': from ngcsimlib.context import Context diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 2c21c231..fd701c25 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -1,38 +1,38 @@ -from .denseSynapse import DenseSynapse -from .staticSynapse import StaticSynapse - - -## short-term plasticity components -from .STPDenseSynapse import STPDenseSynapse -from .exponentialSynapse import ExponentialSynapse -from .doubleExpSynapse import DoupleExpSynapse -from .alphaSynapse import AlphaSynapse - -## dense synaptic components -from .hebbian.hebbianSynapse import HebbianSynapse -from .hebbian.traceSTDPSynapse import TraceSTDPSynapse -from .hebbian.expSTDPSynapse import ExpSTDPSynapse -from .hebbian.eventSTDPSynapse import EventSTDPSynapse -from .hebbian.BCMSynapse import BCMSynapse - - -## conv/deconv synaptic components -from .convolution.convSynapse import ConvSynapse -from .convolution.staticConvSynapse import StaticConvSynapse -from .convolution.hebbianConvSynapse import HebbianConvSynapse -from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -from .convolution.deconvSynapse import DeconvSynapse -from .convolution.staticDeconvSynapse import StaticDeconvSynapse -from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse - - -## modulated synaptic components -from .modulated.MSTDPETSynapse import MSTDPETSynapse -from .modulated.REINFORCESynapse import REINFORCESynapse - -## patched synaptic components -from .patched.patchedSynapse import PatchedSynapse -from .patched.staticPatchedSynapse import StaticPatchedSynapse -from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse - +# from .denseSynapse import DenseSynapse +# from .staticSynapse import StaticSynapse +# +# +# ## short-term plasticity components +# from .STPDenseSynapse import STPDenseSynapse +# from .exponentialSynapse import ExponentialSynapse +# from .doubleExpSynapse import DoupleExpSynapse +# from .alphaSynapse import AlphaSynapse +# +# ## dense synaptic components +# from .hebbian.hebbianSynapse import HebbianSynapse +# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse +# from .hebbian.expSTDPSynapse import ExpSTDPSynapse +# from .hebbian.eventSTDPSynapse import EventSTDPSynapse +# from .hebbian.BCMSynapse import BCMSynapse +# +# +# ## conv/deconv synaptic components +# from .convolution.convSynapse import ConvSynapse +# from .convolution.staticConvSynapse import StaticConvSynapse +# from .convolution.hebbianConvSynapse import HebbianConvSynapse +# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +# from .convolution.deconvSynapse import DeconvSynapse +# from .convolution.staticDeconvSynapse import StaticDeconvSynapse +# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse +# +# +# ## modulated synaptic components +# from .modulated.MSTDPETSynapse import MSTDPETSynapse +# from .modulated.REINFORCESynapse import REINFORCESynapse +# +# ## patched synaptic components +# from .patched.patchedSynapse import PatchedSynapse +# from .patched.staticPatchedSynapse import StaticPatchedSynapse +# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse +# diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index fc4e7ea0..075471b3 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -4,9 +4,8 @@ from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class DenseSynapse(JaxComponent): ## base dense synaptic cable """ @@ -47,28 +46,30 @@ def __init__( ): super().__init__(name, **kwargs) - self.batch_size = batch_size + self.batch_size = Compartment(batch_size, fixed=True) self.weight_init = weight_init self.bias_init = bias_init ## Synapse meta-parameters - self.shape = shape - self.Rscale = resist_scale + self.shape = Compartment(shape, fixed=True) + self.resist_scale = Compartment(resist_scale, fixed=True) ## Set up synaptic weight values - tmp_key, *subkeys = random.split(self.key.value, 4) + tmp_key, *subkeys = random.split(self.key.get(), 4) + if self.weight_init is None: info(self.name, "is using default weight initializer!") self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} weights = initialize_params(subkeys[0], self.weight_init, shape) + if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape) weights = weights * p_mask ## sparsify matrix - self.batch_size = batch_size #1 ## Compartment setup - preVals = jnp.zeros((self.batch_size, shape[0])) - postVals = jnp.zeros((self.batch_size, shape[1])) + preVals = jnp.zeros((self.batch_size.get(), shape[0])) + postVals = jnp.zeros((self.batch_size.get(), shape[1])) + self.inputs = Compartment(preVals) self.outputs = Compartment(postVals) self.weights = Compartment(weights) @@ -80,35 +81,16 @@ def __init__( (1, shape[1])) if bias_init else 0.0) - @transition(output_compartments=["outputs"]) - @staticmethod - def advance_state(Rscale, inputs, weights, biases): - outputs = (jnp.matmul(inputs, weights) * Rscale) + biases - return outputs - - @transition(output_compartments=["inputs", "outputs"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - return inputs, outputs - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, - biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + @compilable + def advance_state(self): + self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale.get()) + self.biases.get()) + + @compilable + def reset(self): + if not self.inputs.targeted: + self.inputs.set(jnp.zeros((self.batch_size.get(), self.shape.get()[0]))) + + self.outputs.set(jnp.zeros((self.batch_size.get(), self.shape.get()[1]))) @classmethod def help(cls): ## component help function @@ -141,20 +123,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index f39d556f..61b33f17 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,6 +1,6 @@ -from .hebbianSynapse import HebbianSynapse -from .traceSTDPSynapse import TraceSTDPSynapse -from .expSTDPSynapse import ExpSTDPSynapse -from .eventSTDPSynapse import EventSTDPSynapse -from .BCMSynapse import BCMSynapse - +# from .hebbianSynapse import HebbianSynapse +# from .traceSTDPSynapse import TraceSTDPSynapse +# from .expSTDPSynapse import ExpSTDPSynapse +# from .eventSTDPSynapse import EventSTDPSynapse +# from .BCMSynapse import BCMSynapse +# diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index 777c26cc..cd6dd86f 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -1,9 +1,8 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable -from ngclearn.components.synapses import DenseSynapse +from ngclearn.components.synapses.denseSynapse import DenseSynapse from ngclearn.utils import tensorstats @@ -73,102 +72,81 @@ def __init__( super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) - ## Synaptic hyper-parameters - self.shape = shape ## shape of synaptic efficacy matrix - self.tau_w = tau_w - self.mu = mu ## controls power-scaling of STDP rule - self.preTrace_target = pretrace_target ## target (pre-synaptic) trace activity value # 0.7 - self.Aplus = A_plus ## LTP strength - self.Aminus = A_minus ## LTD strength - self.Rscale = resist_scale ## post-transformation scale factor - self.w_bound = w_bound #1. ## soft weight constraint - self.w_eps = 0. ## w_eps = 0.01 - self.weight_mask = weight_mask - if self.weight_mask is None: - self.weight_mask = jnp.ones((1, 1)) - self.weights.set(self.weights.value * self.weight_mask) + self.tau_w = Compartment(tau_w, fixed=True) + self.mu = Compartment(mu, fixed=True) ## controls power-scaling of STDP rule + self.preTrace_target = Compartment(pretrace_target, fixed=True) ## target (pre-synaptic) trace activity value # 0.7 + self.Aplus = Compartment(A_plus, fixed=True) ## LTP strength + self.Aminus = Compartment(A_minus, fixed=True) ## LTD strength + self.w_bound = Compartment(w_bound, fixed=True) #1. ## soft weight constraint + self.w_eps = Compartment(0., fixed=True) ## w_eps = 0.01 + + if weight_mask is None: + self.weight_mask = Compartment(jnp.ones((1, 1)), fixed=True) + else: + self.weight_mask = Compartment(self.weight_mask, fixed=True) + + self.weights.set(self.weights.get() * self.weight_mask.get()) ## Compartment setup - preVals = jnp.zeros((self.batch_size, shape[0])) - postVals = jnp.zeros((self.batch_size, shape[1])) + preVals = jnp.zeros((self.batch_size.get(), shape[0])) + postVals = jnp.zeros((self.batch_size.get(), shape[1])) self.preSpike = Compartment(preVals) self.postSpike = Compartment(postVals) self.preTrace = Compartment(preVals) self.postTrace = Compartment(postVals) - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate - #@transition(output_compartments=["outputs"]) - #@staticmethod - #def advance_state(Rscale, inputs, weights, biases, weight_mask): - # outputs = (jnp.matmul(inputs, weights * weight_mask) * Rscale) + biases - # return outputs + def _compute_update(self): + if self.mu.get() > 0.: + post_shift = jnp.power(self.w_bound.get() - self.weights.get(), self.mu.get()) + pre_shift = jnp.power(self.weights.get(), self.mu.get()) + dWpost = (post_shift * jnp.matmul((self.preSpike.get() - self.preTrace_target.get()).T, self.postSpike.get())) * self.Aplus.get() + + if self.Aminus.get() > 0.: + dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus.get() + else: + dWpre = 0. - @staticmethod - def _compute_update( - dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights - ): - pre = preSpike - x_pre = preTrace - post = postSpike - x_post = postTrace - W = weights - x_tar = preTrace_target - if mu > 0.: - ## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP - post_shift = jnp.power(w_bound - W, mu) - pre_shift = jnp.power(W, mu) - dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus - dWpre = 0. - if Aminus > 0.: - dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus else: - ## calculate post-synaptic term - dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus) - - dWpre = 0. - if Aminus > 0.: - ## calculate pre-synaptic term - dWpre = -jnp.matmul(pre.T, x_post * Aminus) - ## calc final weighted adjustment - dW = (dWpost + dWpre) + dWpost = jnp.matmul((self.preSpike.get() - self.preTrace_target.get()).T, self.postSpike.get() * self.Aplus.get()) + if self.Aminus.get() > 0.: + dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus.get()) + else: + dWpre = 0. + + dW = (dWpost - dWpre) return dW - @transition(output_compartments=["weights", "dWeights"]) - @staticmethod - def evolve( - dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_w, preSpike, postSpike, preTrace, - postTrace, weights, eta, weight_mask - ): - #_wm = weight_mask # - _wm = (weight_mask != 0.) - dWeights = TraceSTDPSynapse._compute_update( - dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights - ) - ## do a gradient ascent update/shift - decayTerm = 0. - if tau_w > 0.: - decayTerm = weights / tau_w - weights = weights + (dWeights * eta) - decayTerm #weight_mask * eta) - ## enforce non-negativity - #w_eps = 0. # 0.01 # 0.001 - weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound)) - weights = weights * _wm # weight_mask - return weights, dWeights - - @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - preSpike = preVals - postSpike = postVals - preTrace = preVals - postTrace = postVals - dWeights = jnp.zeros(shape) - return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights + @compilable + def evolve(self): + dWeights = self._compute_update() + if self.tau_w.get() > 0.: + decayTerm = self.weights.get() / self.tau_w.get() + else: + decayTerm = 0. + + # print(jnp.nonzero(dWeights)) + w = self.weights.get() + (dWeights * self.eta.get()) - decayTerm + w = jnp.clip(w, self.w_eps.get(), self.w_bound.get() - self.w_eps.get()) + w = jnp.where(self.weight_mask.get() != 0., w, 0.) + self.weights.set(w) + self.dWeights.set(dWeights) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.preSpike.set(preVals) + self.postSpike.set(postVals) + self.preTrace.set(preVals) + self.postTrace.set(postVals) + self.dWeights.set(jnp.zeros(self.shape.get())) + @classmethod def help(cls): ## component help function @@ -214,19 +192,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines if __name__ == '__main__': from ngcsimlib.context import Context diff --git a/ngclearn/utils/jaxProcess.py b/ngclearn/utils/jaxProcess.py index dd1dabc3..8c3de576 100644 --- a/ngclearn/utils/jaxProcess.py +++ b/ngclearn/utils/jaxProcess.py @@ -1,11 +1,11 @@ from ngcsimlib.compartment import Compartment -from ngcsimlib.compilers.process import Process +from ngcsimlib import MethodProcess from jax.lax import scan as _scan from ngcsimlib.logger import warn from jax import numpy as jnp -class JaxProcess(Process): +class JaxProcess(MethodProcess): """ The JaxProcess is a subclass of the ngcsimlib Process class. The functionality added by this subclass is the use of the jax scanner to run a diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index a7b9f141..64e28f21 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -6,7 +6,6 @@ import jax from jax import numpy as jnp, grad, jit, vmap, random, lax, nn from jax.lax import scan as _scan -from ngcsimlib.utils import Get_Compartment_Batch, Set_Compartment_Batch, get_current_context import os, sys from functools import partial import numpy as np @@ -716,39 +715,39 @@ def d_clip(x, min_val, max_val): return jnp.where((x < min_val) | (x > max_val), 0.0, 1.0) -def scanner(fn): - """ - A wrapper for Jax's scanner that handles the "getting" of the current - state and "setting" of the final state to and from the model. - - | @scanner - | def process(current_state, args): - | t = args[0] - | dt = args[1] - | current_state = model.advance_state(current_state, t, dt) - | current_state = model.evolve(current_state, t, dt) - | return current_state, (current_state[COMPONENT.COMPARTMENT.path], ...) - | - | outputs = models.process(jnp.array([[ARG0, ARG1] for i in range(NUM_LOOPS)])) - - | Notes on the scanner function call: - | 1) `current_state` is a hash-map mapped to all compartment values by path - | 2) `args` is the external arguments defined in the passed Jax array - | 3) `outputs` is a tuple containing time-concatenated Jax arrays of the - | compartment statistics you want tracked - - Args: - fn: function that is executed at every time step of a Jax-unrolled loop, - it must take in the current state and external arguments - - Returns: - wrapped (fast) function that is Jax-scanned/jit-i-fied - """ - def _scanned(_xs): - vals, stacked = _scan(fn, init=Get_Compartment_Batch(), xs=_xs) - Set_Compartment_Batch(vals) - return stacked - - if get_current_context() is not None: - get_current_context().__setattr__(fn.__name__, _scanned) - return _scanned +# def scanner(fn): +# """ +# A wrapper for Jax's scanner that handles the "getting" of the current +# state and "setting" of the final state to and from the model. +# +# | @scanner +# | def process(current_state, args): +# | t = args[0] +# | dt = args[1] +# | current_state = model.advance_state(current_state, t, dt) +# | current_state = model.evolve(current_state, t, dt) +# | return current_state, (current_state[COMPONENT.COMPARTMENT.path], ...) +# | +# | outputs = models.process(jnp.array([[ARG0, ARG1] for i in range(NUM_LOOPS)])) +# +# | Notes on the scanner function call: +# | 1) `current_state` is a hash-map mapped to all compartment values by path +# | 2) `args` is the external arguments defined in the passed Jax array +# | 3) `outputs` is a tuple containing time-concatenated Jax arrays of the +# | compartment statistics you want tracked +# +# Args: +# fn: function that is executed at every time step of a Jax-unrolled loop, +# it must take in the current state and external arguments +# +# Returns: +# wrapped (fast) function that is Jax-scanned/jit-i-fied +# """ +# def _scanned(_xs): +# vals, stacked = _scan(fn, init=Get_Compartment_Batch(), xs=_xs) +# Set_Compartment_Batch(vals) +# return stacked +# +# if get_current_context() is not None: +# get_current_context().__setattr__(fn.__name__, _scanned) +# return _scanned From ea724e96a7ace65b622c68d811ee7aa93fb423b9 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Thu, 24 Jul 2025 14:04:10 -0400 Subject: [PATCH 002/121] Undid fixed compartemts Undid the fixed compartments to work with new global constant tracking --- .../components/input_encoders/poissonCell.py | 12 ++--- .../components/neurons/spiking/LIFCell.py | 50 ++++++++--------- ngclearn/components/other/varTrace.py | 34 ++++++------ ngclearn/components/synapses/denseSynapse.py | 16 +++--- .../synapses/hebbian/traceSTDPSynapse.py | 54 +++++++++---------- 5 files changed, 83 insertions(+), 83 deletions(-) diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 65022156..5eeb057b 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -34,14 +34,14 @@ def __init__(self, name: str, n_units: int, target_freq: float = 63.75, batch_si super().__init__(name=name, key=key) ## Constrained Bernoulli meta-parameters - self.target_freq = Compartment(target_freq, fixed=True) ## maximum frequency (in Hertz/Hz) + self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) ## Layer Size Setup - self.batch_size = Compartment(batch_size, fixed=True) - self.n_units = Compartment(n_units, fixed=True) + self.batch_size = batch_size + self.n_units = n_units # Compartments (state of the cell, parameters, will be updated through stateless calls) - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment self.outputs = Compartment(restVals, display_name="Spikes") # output compartment self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike @@ -49,7 +49,7 @@ def __init__(self, name: str, n_units: int, target_freq: float = 63.75, batch_si @compilable def advance_state(self, t, dt): key, subkey = random.split(self.key.get(), 2) - pspike = self.inputs.get() * (dt / 1000.) * self.target_freq.get() + pspike = self.inputs.get() * (dt / 1000.) * self.target_freq eps = random.uniform(subkey, self.inputs.get().shape, minval=0., maxval=1., dtype=jnp.float32) @@ -59,7 +59,7 @@ def advance_state(self, t, dt): @compilable def reset(self): - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) if not self.inputs.targeted: self.inputs.set(restVals) self.outputs.set(restVals) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 96217e46..aab7ba33 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -121,25 +121,25 @@ def __init__( self.max_one_spike = max_one_spike ## membrane parameter setup (affects ODE integration) - self.tau_m = Compartment(tau_m, fixed=True) ## membrane time constant - self.resist_m = Compartment(resist_m, fixed=True) ## resistance value + self.tau_m = tau_m ## membrane time constant + self.resist_m = resist_m ## resistance value - self.v_min = Compartment(v_min, fixed=True) ## ensures voltage is never < v_min + self.v_min = v_min ## ensures voltage is never < v_min - self.v_rest = Compartment(v_rest, fixed=True) #-65. # mV - self.v_reset = Compartment(v_reset, fixed=True) # -60. # -65. # mV (milli-volts) - self.g_L = Compartment(conduct_leak, fixed=True) ## controls strength of voltage leak (1 -> LIF, 0 => IF) + self.v_rest = v_rest #-65. # mV + self.v_reset = v_reset # -60. # -65. # mV (milli-volts) + self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF) ## basic asserts to prevent neuronal dynamics breaking... #assert (self.conduct_leak * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify... - assert self.resist_m.get() > 0. - self.tau_theta = Compartment(tau_theta, fixed=True) ## threshold time constant # ms (0 turns off) - self.theta_plus = Compartment(theta_plus, fixed=True) #0.05 ## threshold increment - self.refract_T = Compartment(refract_time, fixed=True) #5. # 2. ## refractory period # ms - self.thr = Compartment(thr, fixed=True) ## (fixed) base value for threshold #-52 # -72. # mV + assert self.resist_m > 0. + self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off) + self.theta_plus = theta_plus #0.05 ## threshold increment + self.refract_T = refract_time #5. # 2. ## refractory period # ms + self.thr = thr ## (fixed) base value for threshold #-52 # -72. # mV ## Layer Size Setup - self.batch_size = Compartment(1, fixed=True) - self.n_units = Compartment(n_units, fixed=True) + self.batch_size = 1 + self.n_units = n_units # ## set up surrogate function for spike emission # if surrogate_type == "secant_lif": @@ -153,7 +153,7 @@ def __init__( ## Compartment setup - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) self.j = Compartment(restVals, display_name="Current", units="mA") self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV") @@ -169,11 +169,11 @@ def __init__( @compilable def advance_state(self, dt, t): - j = self.j.get() * self.resist_m.get() + j = self.j.get() * self.resist_m - _v_thr = self.thr_theta.get() + self.thr.get() ## calc present voltage threshold + _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold - v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T.get(), self.v_rest.get(), self.g_L.get()) + v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.g_L) if self.intgFlag == 1: _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) @@ -182,7 +182,7 @@ def advance_state(self, dt, t): s = (_v > _v_thr) * 1. _rfr = (self.rfr.get() + dt) * (1. - s) - _v = _v * (1. - s) + s * self.v_reset.get() + _v = _v * (1. - s) + s * self.v_reset raw_s = s @@ -200,16 +200,16 @@ def advance_state(self, dt, t): rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike s = s * rS ## mask out non-max volt spikes - if self.tau_theta.get() > 0.: + if self.tau_theta > 0.: ## run one integration step for threshold dynamics - thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta.get(), self.theta_plus.get()) + thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus.get()) self.thr_theta.set(thr_theta) ## update tols self.tols.set((1. - s) * self.tols.get() + (s * t)) - if self.v_min.get() is not None: ## ensures voltage never < v_rest - _v = jnp.maximum(_v, self.v_min.get()) + if self.v_min is not None: ## ensures voltage never < v_rest + _v = jnp.maximum(_v, self.v_min) self.v.set(_v) @@ -220,13 +220,13 @@ def advance_state(self, dt, t): @compilable def reset(self): - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) if not self.j.targeted: self.j.set(restVals) - self.v.set(restVals + self.v_rest.get()) + self.v.set(restVals + self.v_rest) self.s.set(restVals) self.s_raw.set(restVals) - self.rfr.set(restVals + self.refract_T.get()) + self.rfr.set(restVals + self.refract_T) self.tols.set(restVals) diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index 936b9f10..8e83bc2d 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -83,17 +83,17 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay ## Trace control coefficients self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay - self.tau_tr = Compartment(tau_tr, fixed=True) ## trace time constant - self.a_delta = Compartment(a_delta, fixed=True) ## trace increment (if spike occurred) - self.P_scale = Compartment(P_scale, fixed=True) ## trace scale if non-additive trace to be used - self.gamma_tr = Compartment(gamma_tr, fixed=True) - self.n_nearest_spikes = Compartment(n_nearest_spikes, fixed=True) + self.tau_tr = tau_tr ## trace time constant + self.a_delta = a_delta ## trace increment (if spike occurred) + self.P_scale = P_scale ## trace scale if non-additive trace to be used + self.gamma_tr = gamma_tr + self.n_nearest_spikes = n_nearest_spikes ## Layer Size Setup - self.batch_size = Compartment(batch_size, fixed=True) - self.n_units = Compartment(n_units, fixed=True) + self.batch_size = batch_size + self.n_units = n_units - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) self.inputs = Compartment(restVals) # input compartment self.outputs = Compartment(restVals) # output compartment self.trace = Compartment(restVals) @@ -101,21 +101,21 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay @compilable def advance_state(self, dt): if "exp" in self.decay_type: - decayFactor = jnp.exp(-dt/self.tau_tr.get()) + decayFactor = jnp.exp(-dt/self.tau_tr) elif "lin" in self.decay_type: - decayFactor = (1. - dt/self.tau_tr.get()) + decayFactor = (1. - dt/self.tau_tr) else: decayFactor = 0. - _x_tr = self.gamma_tr.get() * self.trace.get() * decayFactor - if self.n_nearest_spikes.get() > 0: - _x_tr = _x_tr + self.inputs.get() * (self.a_delta.get() - (self.trace.get() / self.n_nearest_spikes.get())) + _x_tr = self.gamma_tr * self.trace.get() * decayFactor + if self.n_nearest_spikes > 0: + _x_tr = _x_tr + self.inputs.get() * (self.a_delta - (self.trace.get() / self.n_nearest_spikes)) else: - if self.a_delta.get() > 0.: - _x_tr = _x_tr + self.inputs.get() * self.a_delta.get() + if self.a_delta > 0.: + _x_tr = _x_tr + self.inputs.get() * self.a_delta else: - _x_tr = _x_tr * (1. - self.inputs.get()) + self.inputs.get() * self.P_scale.get() + _x_tr = _x_tr * (1. - self.inputs.get()) + self.inputs.get() * self.P_scale self.trace.set(_x_tr) self.outputs.set(_x_tr) @@ -123,7 +123,7 @@ def advance_state(self, dt): @compilable def reset(self): - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) self.inputs.set(restVals) self.outputs.set(restVals) self.trace.set(restVals) diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 075471b3..91a4fda3 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -46,13 +46,13 @@ def __init__( ): super().__init__(name, **kwargs) - self.batch_size = Compartment(batch_size, fixed=True) + self.batch_size = batch_size self.weight_init = weight_init self.bias_init = bias_init ## Synapse meta-parameters - self.shape = Compartment(shape, fixed=True) - self.resist_scale = Compartment(resist_scale, fixed=True) + self.shape = shape + self.resist_scale = resist_scale ## Set up synaptic weight values tmp_key, *subkeys = random.split(self.key.get(), 4) @@ -67,8 +67,8 @@ def __init__( weights = weights * p_mask ## sparsify matrix ## Compartment setup - preVals = jnp.zeros((self.batch_size.get(), shape[0])) - postVals = jnp.zeros((self.batch_size.get(), shape[1])) + preVals = jnp.zeros((self.batch_size, shape[0])) + postVals = jnp.zeros((self.batch_size, shape[1])) self.inputs = Compartment(preVals) self.outputs = Compartment(postVals) @@ -83,14 +83,14 @@ def __init__( @compilable def advance_state(self): - self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale.get()) + self.biases.get()) + self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get()) @compilable def reset(self): if not self.inputs.targeted: - self.inputs.set(jnp.zeros((self.batch_size.get(), self.shape.get()[0]))) + self.inputs.set(jnp.zeros((self.batch_size, self.shape[0]))) - self.outputs.set(jnp.zeros((self.batch_size.get(), self.shape.get()[1]))) + self.outputs.set(jnp.zeros((self.batch_size, self.shape[1]))) @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index cd6dd86f..b941c83d 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -72,46 +72,46 @@ def __init__( super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) - self.tau_w = Compartment(tau_w, fixed=True) - self.mu = Compartment(mu, fixed=True) ## controls power-scaling of STDP rule - self.preTrace_target = Compartment(pretrace_target, fixed=True) ## target (pre-synaptic) trace activity value # 0.7 - self.Aplus = Compartment(A_plus, fixed=True) ## LTP strength - self.Aminus = Compartment(A_minus, fixed=True) ## LTD strength - self.w_bound = Compartment(w_bound, fixed=True) #1. ## soft weight constraint - self.w_eps = Compartment(0., fixed=True) ## w_eps = 0.01 + self.tau_w = tau_w + self.mu = mu ## controls power-scaling of STDP rule + self.preTrace_target = pretrace_target ## target (pre-synaptic) trace activity value # 0.7 + self.Aplus = A_plus ## LTP strength + self.Aminus = A_minus ## LTD strength + self.w_bound = w_bound #1. ## soft weight constraint + self.w_eps = 0. ## w_eps = 0.01 if weight_mask is None: - self.weight_mask = Compartment(jnp.ones((1, 1)), fixed=True) + self.weight_mask = jnp.ones((1, 1)) else: - self.weight_mask = Compartment(self.weight_mask, fixed=True) + self.weight_mask = self.weight_mask - self.weights.set(self.weights.get() * self.weight_mask.get()) + self.weights.set(self.weights.get() * self.weight_mask) ## Compartment setup - preVals = jnp.zeros((self.batch_size.get(), shape[0])) - postVals = jnp.zeros((self.batch_size.get(), shape[1])) + preVals = jnp.zeros((self.batch_size, shape[0])) + postVals = jnp.zeros((self.batch_size, shape[1])) self.preSpike = Compartment(preVals) self.postSpike = Compartment(postVals) self.preTrace = Compartment(preVals) self.postTrace = Compartment(postVals) self.dWeights = Compartment(self.weights.get() * 0) - self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate + self.eta = jnp.ones((1, 1)) * eta ## global learning rate def _compute_update(self): - if self.mu.get() > 0.: - post_shift = jnp.power(self.w_bound.get() - self.weights.get(), self.mu.get()) - pre_shift = jnp.power(self.weights.get(), self.mu.get()) - dWpost = (post_shift * jnp.matmul((self.preSpike.get() - self.preTrace_target.get()).T, self.postSpike.get())) * self.Aplus.get() + if self.mu > 0.: + post_shift = jnp.power(self.w_bound - self.weights.get(), self.mu) + pre_shift = jnp.power(self.weights.get(), self.mu) + dWpost = (post_shift * jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus - if self.Aminus.get() > 0.: - dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus.get() + if self.Aminus > 0.: + dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus else: dWpre = 0. else: - dWpost = jnp.matmul((self.preSpike.get() - self.preTrace_target.get()).T, self.postSpike.get() * self.Aplus.get()) - if self.Aminus.get() > 0.: - dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus.get()) + dWpost = jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus) + if self.Aminus > 0.: + dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus) else: dWpre = 0. @@ -121,15 +121,15 @@ def _compute_update(self): @compilable def evolve(self): dWeights = self._compute_update() - if self.tau_w.get() > 0.: - decayTerm = self.weights.get() / self.tau_w.get() + if self.tau_w > 0.: + decayTerm = self.weights.get() / self.tau_w else: decayTerm = 0. # print(jnp.nonzero(dWeights)) - w = self.weights.get() + (dWeights * self.eta.get()) - decayTerm - w = jnp.clip(w, self.w_eps.get(), self.w_bound.get() - self.w_eps.get()) - w = jnp.where(self.weight_mask.get() != 0., w, 0.) + w = self.weights.get() + (dWeights * self.eta) - decayTerm + w = jnp.clip(w, self.w_eps, self.w_bound - self.w_eps) + w = jnp.where(self.weight_mask != 0., w, 0.) self.weights.set(w) self.dWeights.set(dWeights) From d2d4331e4e67f2666236ef9c932a4da175154cb7 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Fri, 8 Aug 2025 11:29:02 -0400 Subject: [PATCH 003/121] Fixed an execution bug --- ngclearn/components/jaxComponent.py | 3 +-- .../components/synapses/hebbian/traceSTDPSynapse.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 6d8f08ae..858a09c3 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -1,6 +1,6 @@ import time -from typing import Union +from typing import Union, Dict, Any import jax from jax import numpy as jnp from jax import random @@ -25,7 +25,6 @@ def __init__(self, name: str, key: Union[jax.Array, None] = None): self.key = Compartment( random.PRNGKey(time.time_ns()) if key is None else key) - def save(self, directory: str): """ The default save method for JaxComponents, it stores the values of all diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index b941c83d..66d3137c 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -83,7 +83,7 @@ def __init__( if weight_mask is None: self.weight_mask = jnp.ones((1, 1)) else: - self.weight_mask = self.weight_mask + self.weight_mask = weight_mask self.weights.set(self.weights.get() * self.weight_mask) @@ -95,13 +95,13 @@ def __init__( self.preTrace = Compartment(preVals) self.postTrace = Compartment(postVals) self.dWeights = Compartment(self.weights.get() * 0) - self.eta = jnp.ones((1, 1)) * eta ## global learning rate + self.eta = eta ## global learning rate def _compute_update(self): if self.mu > 0.: post_shift = jnp.power(self.w_bound - self.weights.get(), self.mu) pre_shift = jnp.power(self.weights.get(), self.mu) - dWpost = (post_shift * jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus + dWpost = (post_shift * jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus if self.Aminus > 0.: dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus @@ -109,13 +109,13 @@ def _compute_update(self): dWpre = 0. else: - dWpost = jnp.matmul((self.preSpike.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus) + dWpost = jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus) if self.Aminus > 0.: dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus) else: dWpre = 0. - dW = (dWpost - dWpre) + dW = (dWpost + dWpre) return dW @compilable From 2fc230042c1427eef5c372298a8a0c3b9aad71f5 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 22 Sep 2025 18:01:37 -0400 Subject: [PATCH 004/121] ported over quad-lif to v3 - needs testing --- .../components/neurons/spiking/LIFCell.py | 2 +- .../components/neurons/spiking/quadLIFCell.py | 127 +++++++++--------- 2 files changed, 63 insertions(+), 66 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index aab7ba33..24d8cb3f 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -202,7 +202,7 @@ def advance_state(self, dt, t): if self.tau_theta > 0.: ## run one integration step for threshold dynamics - thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus.get()) + thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get()) self.thr_theta.set(thr_theta) ## update tols diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index ec7bbd32..c240f1a3 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -1,17 +1,16 @@ from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, random, jit, nn +from jax import numpy as jnp, random, jit, nn, Array from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, - triangular_estimator, - straight_through_estimator) +# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, +# triangular_estimator, +# straight_through_estimator) -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment from ngclearn.components.neurons.spiking.LIFCell import LIFCell @@ -30,7 +29,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper return dv_dt #@partial(jit, static_argnums=[3, 4]) -def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05): +def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array | float=0.05): ### Runs homeostatic threshold update dynamics one step (via Euler integration). #theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7) #theta_plus = 0.05 @@ -133,71 +132,69 @@ def __init__( self.v_c = v_scale self.a0 = critical_v - @transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"]) - @staticmethod - def advance_state( - t, dt, tau_m, resist_m, v_rest, v_reset, v_c, a0, refract_T, thr, tau_theta, theta_plus, - one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols - ): - skey = None ## this is an empty dkey if single_spike mode turned off - if one_spike: - key, skey = random.split(key, 2) - ## run one integration step for neuronal dynamics - j = j * resist_m - ############################################################################ - ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics. - _v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold - #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask - ## update voltage / membrane potential - v_params = (j, rfr, tau_m, refract_T, v_rest, v_c, a0) - if intgFlag == 1: - _, _v = step_rk2(0., v, _dfv, dt, v_params) - else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask) - _, _v = step_euler(0., v, _dfv, dt, v_params) - ## obtain action potentials/spikes + @compilable + def advance_state(self, dt, t): + j = self.j.get() * self.resist_m + + _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold + + v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.v_c, self.a0) + + if self.intgFlag == 1: + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) + else: + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) + s = (_v > _v_thr) * 1. - ## update refractory variables - _rfr = (rfr + dt) * (1. - s) - ## perform hyper-polarization of neuronal cells - _v = _v * (1. - s) + s * v_reset - - raw_s = s + 0 ## preserve un-altered spikes - ############################################################################ - ## this is a spike post-processing step - if skey is not None: - m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able + _rfr = (self.rfr.get() + dt) * (1. - s) + _v = _v * (1. - s) + s * self.v_reset + + raw_s = s + + #surrogate = d_spike_fx(v, _v_thr) # d_spike_fx(v, thr + thr_theta) + + if self.one_spike and not self.max_one_spike: + key, skey = random.split(self.key.get(), 2) + + m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able rS = s * random.uniform(skey, s.shape) rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32) s = s * (1. - m_switch) + rS * m_switch - ############################################################################ - raw_spikes = raw_s - v = _v - rfr = _rfr + self.key.set(key) - surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta) - if tau_theta > 0.: + if self.max_one_spike: + rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1], + dtype=jnp.float32) ## get max-volt spike + s = s * rS ## mask out non-max volt spikes + + if self.tau_theta > 0.: ## run one integration step for threshold dynamics - thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus) + thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) # .get()) + self.thr_theta.set(thr_theta) + ## update tols - tols = (1. - s) * tols + (s * t) - if lower_clamp_voltage: ## ensure voltage never < v_rest - v = jnp.maximum(v, v_rest) - return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate - - @transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"]) - @staticmethod - def reset(batch_size, n_units, v_rest, refract_T): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals #+ 0 - v = restVals + v_rest - s = restVals #+ 0 - s_raw = restVals - rfr = restVals + refract_T - #thr_theta = restVals ## do not reset thr_theta - tols = restVals #+ 0 - surrogate = restVals + 1. - return j, v, s, s_raw, rfr, tols, surrogate + self.tols.set((1. - s) * self.tols.get() + (s * t)) + + if self.v_min is not None: ## ensures voltage never < v_rest + _v = jnp.maximum(_v, self.v_min) + + self.v.set(_v) + self.s.set(s) + self.s_raw.set(raw_s) + self.rfr.set(_rfr) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v_rest) + self.s.set(restVals) + self.s_raw.set(restVals) + self.rfr.set(restVals + self.refract_T) + self.tols.set(restVals) + #self.surrogate.set(restVals) def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays From 92b6940da9c6341fbd531aa5d9169e264d8e9f4f Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 26 Sep 2025 11:31:27 -0400 Subject: [PATCH 005/121] ported over IF/quadLIF cells, minor revision to LIF cell --- ngclearn/components/neurons/spiking/IFCell.py | 92 +++++++++---------- .../components/neurons/spiking/LIFCell.py | 2 +- .../components/neurons/spiking/quadLIFCell.py | 4 +- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index 08416f6d..42814a3d 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -2,16 +2,15 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, - triangular_estimator, - straight_through_estimator) +# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, +# triangular_estimator, +# straight_through_estimator) -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @@ -35,7 +34,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell The specific differential equation that characterizes this cell is (for adjusting v, given current j, over time) is: - | tau_m * dv/dt = (v_rest - v) + j * R + | tau_m * dv/dt = j * R | where R is the membrane resistance and v_rest is the resting potential | also, if a spike occurs, v is set to v_reset @@ -91,10 +90,10 @@ class IFCell(JaxComponent): ## integrate-and-fire cell """ @deprecate_args(thr_jitter=None) - def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., - v_reset=-60., refract_time=0., integration_type="euler", - surrogate_type="straight_through", lower_clamp_voltage=True, - **kwargs): + def __init__( + self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0., + integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs + ): super().__init__(name, **kwargs) ## Integration properties @@ -118,12 +117,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., self.n_units = n_units ## set up surrogate function for spike emission - if surrogate_type == "arctan": - self.spike_fx, self.d_spike_fx = arctan_estimator() - elif surrogate_type == "triangular": - self.spike_fx, self.d_spike_fx = triangular_estimator() - else: ## default: straight_through - self.spike_fx, self.d_spike_fx = straight_through_estimator() + # if surrogate_type == "arctan": + # self.spike_fx, self.d_spike_fx = arctan_estimator() + # elif surrogate_type == "triangular": + # self.spike_fx, self.d_spike_fx = triangular_estimator() + # else: ## default: straight_through + # self.spike_fx, self.d_spike_fx = straight_through_estimator() ## Compartment setup @@ -138,47 +137,48 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., units="ms") ## time-of-last-spike self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") - @transition(output_compartments=["v", "s", "rfr", "tols", "key", "surrogate"]) - @staticmethod + @compilable def advance_state( - t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, thr, lower_clamp_voltage, intgFlag, d_spike_fx, key, - j, v, rfr, tols + self, dt, t ): ## run one integration step for neuronal dynamics - j = j * resist_m + j = self.j.get() * self.resist_m ### Runs integrator (or integrate-and-fire; IF) neuronal dynamics ## update voltage / membrane potential - v_params = (j, rfr, tau_m, refract_T) - if intgFlag == 1: - _, _v = step_rk2(0., v, _dfv, dt, v_params) + v_params = (j, self.rfr.get(), self.tau_m, self.refract_T) + if self.intgFlag == 1: + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) else: - _, _v = step_euler(0., v, _dfv, dt, v_params) + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) ## obtain action potentials/spikes - s = (_v > thr) * 1. + s = (_v > self.thr) * 1. ## update refractory variables - rfr = (rfr + dt) * (1. - s) + rfr = (self.rfr.get() + dt) * (1. - s) ## perform hyper-polarization of neuronal cells - v = _v * (1. - s) + s * v_reset + v = _v * (1. - s) + s * self.v_reset + + #surrogate = d_spike_fx(v, self.thr) - surrogate = d_spike_fx(v, thr) ## update tols - tols = (1. - s) * tols + (s * t) - if lower_clamp_voltage: ## ensure voltage never < v_rest - v = jnp.maximum(v, v_rest) - return v, s, rfr, tols, key, surrogate - - @transition(output_compartments=["j", "v", "s", "rfr", "tols", "surrogate"]) - @staticmethod - def reset(batch_size, n_units, v_rest, refract_T): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals #+ 0 - v = restVals + v_rest - s = restVals #+ 0 - rfr = restVals + refract_T - tols = restVals #+ 0 - surrogate = restVals + 1. - return j, v, s, rfr, tols, surrogate + self.tols.set((1. - s) * self.tols.get() + (s * t)) + if self.lower_clamp_voltage: ## ensure voltage never < v_rest + _v = jnp.maximum(v, self.v_rest) + + self.v.set(_v) + self.s.set(s) + self.rfr.set(rfr) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v_rest) + self.s.set(restVals) + self.rfr.set(restVals + self.refract_T) + self.tols.set(restVals) + #surrogate = restVals + 1. def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 24d8cb3f..d00fa171 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -19,7 +19,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper #@partial(jit, static_argnums=[3, 4]) -def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05): +def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05): ### Runs homeostatic threshold update dynamics one step (via Euler integration). #theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7) #theta_plus = 0.05 diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index c240f1a3..da1084c4 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -29,7 +29,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper return dv_dt #@partial(jit, static_argnums=[3, 4]) -def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array | float=0.05): +def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05): ### Runs homeostatic threshold update dynamics one step (via Euler integration). #theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7) #theta_plus = 0.05 @@ -138,7 +138,7 @@ def advance_state(self, dt, t): _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold - v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.v_c, self.a0) + v_params = (j, self.rfr.get(), self.tau_m, self.refract_T, self.v_rest, self.v_c, self.a0) if self.intgFlag == 1: _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) From aa3b52e6d87a1e992aea7b73ff624a87441a42fb Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Mon, 29 Sep 2025 15:56:46 -0400 Subject: [PATCH 006/121] Start util cleanup --- ngclearn/__init__.py | 4 +- ngclearn/utils/__init__.py | 13 +- ngclearn/utils/distribution_generator.py | 413 +++++++++++++++++++++++ ngclearn/utils/optim/sgd.py | 11 +- ngclearn/utils/patch.py | 101 ++++++ ngclearn/utils/viz/compartment_plot.py | 38 +++ ngclearn/utils/viz/compartment_raster.py | 49 +++ 7 files changed, 607 insertions(+), 22 deletions(-) mode change 100755 => 100644 ngclearn/utils/__init__.py create mode 100644 ngclearn/utils/distribution_generator.py create mode 100644 ngclearn/utils/patch.py create mode 100644 ngclearn/utils/viz/compartment_plot.py create mode 100755 ngclearn/utils/viz/compartment_raster.py diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index d404651e..8d0c7e10 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -28,11 +28,11 @@ import ngcsimlib -from ngcsimlib import Component, MethodProcess, JointProcess +from ngclearn.utils import JointProcess, MethodProcess from ngcsimlib.context import Context, ContextObjectTypes from ngcsimlib.compartment import Compartment -from ngclearn.utils.jaxProcess import JaxProcess +# from ngclearn.utils.jaxProcess import JaxProcess from ngcsimlib import logger # if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": diff --git a/ngclearn/utils/__init__.py b/ngclearn/utils/__init__.py old mode 100755 new mode 100644 index 9c9f984c..7bba010f --- a/ngclearn/utils/__init__.py +++ b/ngclearn/utils/__init__.py @@ -1,10 +1,3 @@ -from .model_utils import tensorstats -from .jaxProcess import JaxProcess -## forward imports from core ngc-learn utility sub-packages -from . import viz -from . import io_utils -from . import metric_utils -from . import model_utils -from . import patch_utils -from . import weight_distribution -from . import surrogate_fx +from .distribution_generator import DistributionGenerator +from .JaxProcessesMixin import JaxJointProcess as JointProcess, JaxMethodProcess as MethodProcess + diff --git a/ngclearn/utils/distribution_generator.py b/ngclearn/utils/distribution_generator.py new file mode 100644 index 00000000..f62c32b1 --- /dev/null +++ b/ngclearn/utils/distribution_generator.py @@ -0,0 +1,413 @@ +import time +from typing import TypedDict, List, Protocol, Sequence +from typing_extensions import Unpack +import jax +import numpy + +from ngcsimlib.logger import error + + +class DistributionParams(TypedDict, total=False): + """ + Extra parameters to be used when generating distributions. + + Attributes: + amin: sets the lower bound of the distribution + amax: sets the upper bound of the distribution + lower_triangle: keeps the lower triangle, sets the rest to zero + upper_triangle: keeps the upper triangle, sets the rest to zero + hollow: produces a hollow distribution (zeros along the diagonal) + eye: produces an eye distribution (zeros the off-diagonal) + col_mask: + single value, keeps n random columns + list values, keeps the provided column indices + row_mask: + single value, keeps n random rows + list values, keeps the provided row indices + use_numpy: use default numpy + """ + amin: float + amax: float + lower_triangle: bool + upper_triangle: bool + hollow: bool + eye: bool + col_mask: int | List[int] + row_mask: int | List[int] + use_numpy: bool + dtype: numpy.dtype + + +class DistributionInitializer(Protocol): + def __call__(self, shape: Sequence[int], + dkey: jax.dtypes.prng_key | int | None = None) -> jax.Array: ... + + + +class DistributionGenerator(object): + @staticmethod + def constant(value: float, **params: Unpack[ + DistributionParams]) -> DistributionInitializer: + """ + Produces a distribution initializer for a constant distribution. + Args: + value: the constant value to fill the array with + **params: the extra distribution parameters + + Returns: a distribution initializer + """ + using_np = params.get("use_numpy", False) + if using_np: + def constant_generator(shape: Sequence[int], + seed: int | None = None) -> numpy.ndarray: + matrix = numpy.ones(shape, + params.get("dtype", numpy.float32)) * value + matrix = DistributionGenerator._process_params_numpy(matrix, + params, + seed) + return matrix + else: + def constant_generator(shape: Sequence[int], + dKey: jax.dtypes.prng_key | None = None) -> jax.Array: + matrix = jax.numpy.ones(shape, params.get("dtype", + jax.numpy.float32)) * value + matrix = DistributionGenerator._process_params_jax(matrix, + params, dKey) + return matrix + return constant_generator + + @staticmethod + def uniform(low: float = 0.0, high: float = 1.0, **params: Unpack[ + DistributionParams]) -> DistributionInitializer: + """ + Produces a distribution initializer for a uniform distribution. + + Args: + low: lower bound of the uniform distribution (inclusive) + high: upper bound of the uniform distribution (exclusive) + **params: the extra distribution parameters + + Returns: a distribution initializer + """ + using_np = params.get("use_numpy", False) + + if using_np: + def uniform_generator(shape: Sequence[int], + seed: int | None = None) -> numpy.ndarray: + rng = numpy.random.default_rng(seed) + matrix = rng.uniform(low=low, high=high, size=shape).astype( + params.get("dtype", numpy.float32)) + matrix = DistributionGenerator._process_params_numpy(matrix, + params, + seed) + return matrix + else: + def uniform_generator(shape: Sequence[int], + dKey: jax.Array | None = None) -> jax.Array: + if dKey is None: + dKey = jax.random.PRNGKey(time.time_ns()) + dKey, subKey = jax.random.split(dKey, 2) + + matrix = jax.random.uniform( + dKey, + shape=shape, + minval=low, + maxval=high, + dtype=params.get("dtype", jax.numpy.float32) + ) + matrix = DistributionGenerator._process_params_jax(matrix, + params, + subKey) + return matrix + + return uniform_generator + + @staticmethod + def gaussian(mean: float = 0.0, std: float = 1.0, **params: Unpack[ + DistributionParams]) -> DistributionInitializer: + """ + Produces a distribution initializer for a Gaussian (normal) distribution. + + Args: + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + **params: the extra distribution parameters + + Returns: a distribution initializer + """ + using_numpy = params.get("use_numpy", False) + + if using_numpy: + def gaussian_generator(shape: Sequence[int], + seed: int | None = None) -> numpy.ndarray: + rng = numpy.random.default_rng(seed) + matrix = rng.normal(loc=mean, scale=std, size=shape).astype( + params.get("dtype", numpy.float32)) + matrix = DistributionGenerator._process_params_numpy(matrix, + params, + seed) + return matrix + else: + def gaussian_generator(shape: Sequence[int], + dKey: jax.Array | None = None) -> jax.Array: + if dKey is None: + dKey = jax.random.PRNGKey(time.time_ns()) + dKey, subKey = jax.random.split(dKey, 2) + matrix = jax.random.normal( + dKey, + shape=shape, + dtype=params.get("dtype", jax.numpy.float32) + ) + matrix = mean + std * matrix + matrix = DistributionGenerator._process_params_jax(matrix, + params, + subKey) + return matrix + + return gaussian_generator + + @staticmethod + def fan_in_uniform( + **params: Unpack[DistributionParams]) -> DistributionInitializer: + """ + Produces a distribution initializer using a fan-in uniform strategy. + The values are sampled from a uniform distribution in the range [-limit, limit], + where limit = sqrt(1 / fan_in), and fan_in is inferred from the shape. + + Args: + **params: extra distribution parameters + + Returns: a distribution initializer + """ + using_numpy = params.get("use_numpy", False) + + def compute_limit(fan_in: int) -> float: + return float(numpy.sqrt(1.0 / fan_in)) + + if using_numpy: + def fan_in_uniform_generator(shape: Sequence[int], + seed: int | None = None) -> numpy.ndarray: + if len(shape) < 2: + error("fan_in_uniform requires shape with at least 2 dimensions") + fan_in = shape[1] + limit = compute_limit(fan_in) + + rng = numpy.random.default_rng(seed) + matrix = rng.uniform(low=-limit, high=limit, size=shape).astype( + params.get("dtype", numpy.float32)) + matrix = DistributionGenerator._process_params_numpy(matrix, + params, + seed) + return matrix + else: + def fan_in_uniform_generator(shape: Sequence[int], + dKey: jax.Array | None = None) -> jax.Array: + if len(shape) < 2: + error("fan_in_uniform requires shape with at least 2 dimensions") + fan_in = shape[1] + limit = compute_limit(fan_in) + + if dKey is None: + dKey = jax.random.PRNGKey(time.time_ns()) + dKey, subKey = jax.random.split(dKey, 2) + + matrix = jax.random.uniform( + dKey, + shape=shape, + minval=-limit, + maxval=limit, + dtype=params.get("dtype", jax.numpy.float32) + ) + matrix = DistributionGenerator._process_params_jax(matrix, + params, + subKey) + return matrix + + return fan_in_uniform_generator + + @staticmethod + def fan_in_gaussian( + **params: Unpack[DistributionParams]) -> DistributionInitializer: + """ + Produces a distribution initializer using a fan-in Gaussian (normal) strategy. + The values are sampled from a normal distribution with mean 0 and stddev = sqrt(1 / fan_in), + where fan_in is inferred from the shape. + + Args: + **params: extra distribution parameters + + Returns: a distribution initializer + """ + using_numpy = params.get("use_numpy", False) + + def compute_std(fan_in: int) -> float: + return float(numpy.sqrt(1.0 / fan_in)) + + if using_numpy: + def fan_in_gaussian_generator(shape: Sequence[int], + seed: int | None) -> numpy.ndarray: + if len(shape) < 2: + error("fan_in_gaussian requires shape with at least 2 dimensions") + fan_in = shape[0] + std = compute_std(fan_in) + + rng = numpy.random.default_rng(seed) + matrix = rng.normal(loc=0.0, scale=std, size=shape).astype( + params.get("dtype", numpy.float32)) + matrix = DistributionGenerator._process_params_numpy(matrix, + params, + seed) + return matrix + else: + def fan_in_gaussian_generator(shape: Sequence[int], + dKey: jax.Array | None) -> jax.Array: + if len(shape) < 2: + error("fan_in_gaussian requires shape with at least 2 dimensions") + fan_in = shape[0] + std = compute_std(fan_in) + + if dKey is None: + dKey = jax.random.PRNGKey(time.time_ns()) + dKey, subKey = jax.random.split(dKey, 2) + + matrix = jax.random.normal( + dKey, + shape=shape, + dtype=params.get("dtype", jax.numpy.float32) + ) + matrix = matrix * std + matrix = DistributionGenerator._process_params_jax(matrix, + params, subKey) + return matrix + + return fan_in_gaussian_generator + + @staticmethod + def _process_params_jax(ary: jax.Array, params: DistributionParams, + dKey: jax.dtypes.prng_key | None) -> jax.Array: + if dKey is None: + dKey = jax.random.PRNGKey(time.time_ns()) + + amin = params.get("amin", None) + if amin is not None: + ary = jax.numpy.maximum(ary, amin) + + amax = params.get("amax", None) + if amax is not None: + ary = jax.numpy.minimum(ary, amax) + + lower_triangle = params.get("lower_triangle", False) + upper_triangle = params.get("upper_triangle", False) + if lower_triangle and upper_triangle: + error( + "lower_triangle and upper_triangle are mutually exclusive when initializing a distribution") + + if lower_triangle: + ary = jax.numpy.tril(ary) + if upper_triangle: + ary = jax.numpy.triu(ary) + + if params.get("hollow", False): + ary = (1.0 - jax.numpy.eye(*ary.shape)) * ary + + if params.get("eye", False): + ary = jax.numpy.eye(*ary.shape) * ary + + col_mask = params.get("col_mask", None) + if col_mask is not None: + if isinstance(col_mask, int): + dKey, subKey = jax.random.split(dKey, 2) + keep_indices = jax.random.choice(subKey, ary.shape[1], + shape=(col_mask,), + replace=False) + mask = jax.numpy.zeros(ary.shape[1], dtype=bool).at[ + keep_indices].set(True) + mask = jax.numpy.broadcast_to(mask, ary.shape) + ary = jax.numpy.where(mask, ary, 0) + elif isinstance(col_mask, Sequence): + mask = jax.numpy.zeros(ary.shape[1], dtype=bool).at[ + col_mask].set(True) + mask = jax.numpy.broadcast_to(mask, ary.shape) + ary = jax.numpy.where(mask, ary, 0) + + row_mask = params.get("row_mask", None) + if row_mask is not None: + if isinstance(row_mask, int): + dKey, subKey = jax.random.split(dKey, 2) + keep_indices = jax.random.choice(subKey, ary.shape[0], + shape=(row_mask,), + replace=False) + mask = jax.numpy.zeros(ary.shape[0], dtype=bool).at[ + keep_indices].set(True) + mask = jax.numpy.broadcast_to(mask, ary.shape) + ary = jax.numpy.where(mask, ary, 0) + elif isinstance(row_mask, Sequence): + mask = jax.numpy.zeros(ary.shape[0], dtype=bool).at[ + row_mask].set(True) + mask = jax.numpy.broadcast_to(mask, ary.shape) + ary = jax.numpy.where(mask, ary, 0) + + return ary.astype(params.get("dtype", jax.numpy.float32)) + + @staticmethod + def _process_params_numpy(ary: numpy.ndarray, params: DistributionParams, + seed: int | None) -> numpy.ndarray: + amin = params.get("amin", None) + if amin is not None: + ary = numpy.maximum(ary, amin) + + amax = params.get("amax", None) + if amax is not None: + ary = numpy.minimum(ary, amax) + + lower_triangle = params.get("lower_triangle", False) + upper_triangle = params.get("upper_triangle", False) + if lower_triangle and upper_triangle: + error( + "lower_triangle and upper_triangle are mutually exclusive when initializing a distribution") + + if lower_triangle: + ary = numpy.tril(ary) + if upper_triangle: + ary = numpy.triu(ary) + + if params.get("hollow", False): + ary = (1.0 - numpy.eye(*ary.shape)) * ary + + if params.get("eye", False): + ary = numpy.eye(*ary.shape) * ary + + col_mask = params.get("col_mask", None) + if col_mask is not None: + if isinstance(col_mask, int): + rng = numpy.random.default_rng(seed) + keep_indices = rng.choice(ary.shape[1], size=col_mask, + replace=False) + mask = numpy.zeros(ary.shape[1], dtype=bool) + mask[keep_indices] = True + mask = numpy.broadcast_to(mask, ary.shape) + ary = numpy.where(mask, ary, 0) + elif isinstance(col_mask, Sequence): + mask = numpy.zeros(ary.shape[1], dtype=bool) + mask[list(col_mask)] = True + mask = numpy.broadcast_to(mask, ary.shape) + ary = numpy.where(mask, ary, 0) + + row_mask = params.get("row_mask", None) + if row_mask is not None: + if isinstance(row_mask, int): + rng = numpy.random.default_rng(seed) + keep_indices = rng.choice(ary.shape[0], size=row_mask, + replace=False) + mask = numpy.zeros(ary.shape[0], dtype=bool) + mask[keep_indices] = True + mask = numpy.broadcast_to(mask, ary.shape) + ary = numpy.where(mask, ary, 0) + elif isinstance(row_mask, Sequence): + mask = numpy.zeros(ary.shape[0], dtype=bool) + mask[list(row_mask)] = True + mask = numpy.broadcast_to(mask, ary.shape) + ary = numpy.where(mask, ary, 0) + + return ary + diff --git a/ngclearn/utils/optim/sgd.py b/ngclearn/utils/optim/sgd.py index 68594d4a..e0c38e64 100755 --- a/ngclearn/utils/optim/sgd.py +++ b/ngclearn/utils/optim/sgd.py @@ -1,13 +1,4 @@ -# %% - -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.resolver import resolver - -import numpy as np -from jax import jit, numpy as jnp, random, nn, lax -from functools import partial -import time +from jax import jit, numpy as jnp def step_update(param, update, lr): """ diff --git a/ngclearn/utils/patch.py b/ngclearn/utils/patch.py new file mode 100644 index 00000000..c5dd14f2 --- /dev/null +++ b/ngclearn/utils/patch.py @@ -0,0 +1,101 @@ +from typing import Literal + +from jax import numpy as jnp + +from ngcsimlib.logger import error, warn + + +class PatchGenerator(object): + def __init__(self, + patch_height: int, + patch_width: int, + horizontal_alignment: Literal['left', 'right', 'center', 'fit']=None, + vertical_alignment: Literal['top', 'bottom', 'center', 'fit']=None, + horizontal_stride: int | None = None, + vertical_stride: int | None = None): + self.horizontal_alignment = horizontal_alignment or 'left' + self.horizontal_stride = horizontal_stride or 0 + self.patch_height = patch_height + + self.vertical_alignment = vertical_alignment or 'top' + self.vertical_stride = vertical_stride or 0 + self.patch_width = patch_width + + self.idx_cache = {} + + self._current_height = None + self._current_width = None + + self._max_patch = None + self._current_idx = -1 + self._current_img = None + + def __iter__(self): + if self._current_img is None: + error("Attempting to generate patches but no image has been provided") + + self._current_idx = 0 + return self + + def target(self, img: jnp.ndarray): + height, width = img.shape[:2] + if height == self._current_height and width == self._current_width: + self._current_img = img + return + + if self.patch_height > height or self.patch_width > width: + warn("Image to small for patches to be extracted, aborting") + return + + horizontal_idxs = [] + vertical_idxs = [] + + actual_patch_width = self.patch_width - self.horizontal_stride + if self.horizontal_alignment == 'left': + horizontal_idxs += range(0, width-self.patch_width, actual_patch_width) + elif self.horizontal_alignment == 'right': + horizontal_idxs += [i - self.patch_width for i in range(width, self.patch_width, -actual_patch_width)] + elif self.horizontal_alignment == 'center': + centerx = width // 2 + horizontal_idxs += range(centerx, width-self.patch_width, actual_patch_width) + horizontal_idxs += [i - self.patch_width for i in range(centerx, self.patch_width, -actual_patch_width)] + elif self.horizontal_alignment == 'fit': + extra = ((width - self.patch_width) % actual_patch_width) // 2 + horizontal_idxs += range(extra, width - self.patch_width + 1, + actual_patch_width) + else: + pass + + actual_patch_height = self.patch_height - self.vertical_stride + if self.vertical_alignment == 'left': + horizontal_idxs += range(0, height-self.patch_height, actual_patch_height) + elif self.vertical_alignment == 'right': + horizontal_idxs += [i - self.patch_height for i in range(height, self.patch_width, -actual_patch_height)] + elif self.vertical_alignment == 'center': + centery = height // 2 + horizontal_idxs += range(centery, height-self.patch_height, actual_patch_height) + horizontal_idxs += [i - self.patch_width for i in range(centery, self.patch_height, -actual_patch_height)] + elif self.vertical_alignment == 'fit': + extra = ((height - self.patch_height) % actual_patch_height) // 2 + horizontal_idxs += range(extra, height - self.patch_height + 1, + actual_patch_height) + + print(horizontal_idxs) + + img = jnp.zeros((len(horizontal_idxs), width)) + for row, idx in enumerate(horizontal_idxs): + img = img.at[row, idx:idx + self.patch_width].set( + img[row, idx:idx + self.patch_width] + 50) + + import matplotlib.pyplot as plt + + plt.imshow(img) + plt.show() + + + +gen = PatchGenerator(patch_width=5, patch_height=5, horizontal_alignment='center', horizontal_stride=1) + +test_img = jnp.zeros((32, 32)) + +gen.target(test_img) diff --git a/ngclearn/utils/viz/compartment_plot.py b/ngclearn/utils/viz/compartment_plot.py new file mode 100644 index 00000000..639d813a --- /dev/null +++ b/ngclearn/utils/viz/compartment_plot.py @@ -0,0 +1,38 @@ +""" +Raster visualization functions/utilities. +""" +import matplotlib.pyplot as plt +import jax +from typing import Sequence + +def create_plot(history: jax.Array, ax: plt.Axes | None = None, + indices: Sequence[int] | None = None): + """ + Generates a raster plot of a given (binary) spike train (row dimension + corresponds to the discrete time dimension). + + Args: + history: a numpy binary array of shape (T x number_of_neurons) + + ax: a hook/pointer to a currently external plot that this raster plot + should be made a sub-figure of + + indices: optional indices of neurons (row integer indices) to focus on + plotting + + s: size of the spike scatter points (Default = 0.5) + + c: color of the spike scatter points (Default = black) + + """ + n_count = history.shape[0] + if ax is None: + nc = n_count if indices is None else len(indices) + fig_size = 5 if nc < 25 else int(nc / 5) + plt.figure(figsize=(fig_size, fig_size)) + + _ax = ax if ax is not None else plt + + for k in range(history.shape[1]): + if indices is None or k in indices: + _ax.plot(history[:, k]) \ No newline at end of file diff --git a/ngclearn/utils/viz/compartment_raster.py b/ngclearn/utils/viz/compartment_raster.py new file mode 100755 index 00000000..d66a73eb --- /dev/null +++ b/ngclearn/utils/viz/compartment_raster.py @@ -0,0 +1,49 @@ +""" +Raster visualization functions/utilities. +""" +import matplotlib.pyplot as plt +import jax +from typing import Sequence + +def create_raster_plot(spike_train: jax.Array, ax: plt.Axes | None = None, + indices: Sequence[int] | None = None, s=0.5, c="black"): + """ + Generates a raster plot of a given (binary) spike train (row dimension + corresponds to the discrete time dimension). + + Args: + spike_train: a numpy binary array of shape (T x number_of_neurons) + + ax: a hook/pointer to a currently external plot that this raster plot + should be made a sub-figure of + + indices: optional indices of neurons (row integer indices) to focus on + plotting + + s: size of the spike scatter points (Default = 0.5) + + c: color of the spike scatter points (Default = black) + + """ + step_count = spike_train.shape[0] + n_count = spike_train.shape[1] + if ax is None: + nc = n_count if indices is None else len(indices) + fig_size = 5 if nc < 25 else int(nc / 5) + plt.figure(figsize=(fig_size, fig_size)) + + _ax = ax if ax is not None else plt + + events = [] + for t in range(n_count): + if indices is None or t in indices: + e = spike_train[:, t].nonzero() + events.append(e[0]) + _ax.eventplot(events, linelengths=s, colors=c) + if ax is None: + _ax.yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))], + labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)]) + _ax.xticks(ticks=[i for i in range(0, step_count+1, max(int(step_count / 5), 1))]) + else: + _ax.set_yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))], + labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)]) From 03305049d2166205d087660e0d6205d27ea3f0f6 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 30 Sep 2025 12:52:00 -0400 Subject: [PATCH 007/121] refactored/ported RAFCell to v3 --- .../components/neurons/spiking/RAFCell.py | 95 ++++++++++--------- 1 file changed, 50 insertions(+), 45 deletions(-) diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index df95de1d..8f41569f 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -2,15 +2,16 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment +######################################################################################################################## +## RAF dynamics (multi-dimensional ODEs) @jit def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics # dy/dt = omega x + b y @@ -34,6 +35,7 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper j, v, tau_w, omega, b = params dv_dt = _dfw_internal(j, v, w, tau_w, omega, b) return dv_dt +######################################################################################################################## class RAFCell(JaxComponent): """ @@ -60,8 +62,7 @@ class RAFCell(JaxComponent): | tols - time-of-last-spike | References: - | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks - | 14.6-7 (2001): 883-894. + | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks 14.6-7 (2001): 883-894. Args: name: the string name of this cell @@ -77,7 +78,7 @@ class RAFCell(JaxComponent): omega: angular frequency (Default: 10) - b: oscillation dampening factor (Default: -1) + dampen_factor: oscillation dampening factor (Default: -1) ("b" in Izhikevich 2001) v_reset: reset condition for membrane potential (Default: 1 mV) @@ -98,10 +99,10 @@ class RAFCell(JaxComponent): at an increase in computational cost (and simulation time) """ - @deprecate_args(resist_m="resist_v", tau_m="tau_v") + @deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor") def __init__( - self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., b=-1., v_reset=0., w_reset=0., v0=0., w0=0., - resist_v=1., integration_type="euler", batch_size=1, **kwargs + self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., dampen_factor=-1., v_reset=0., w_reset=0., + v0=0., w0=0., resist_v=1., integration_type="euler", batch_size=1, **kwargs ): #v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1. super().__init__(name, **kwargs) @@ -115,8 +116,8 @@ def __init__( self.resist_v = resist_v self.tau_w = tau_w self.omega = omega ## angular frequency - self.b = b ## dampening factor - ## note: the smaller b is, the faster the oscillation dampens to resting state values + self.dampen_factor = dampen_factor ## dampening factor (b) + ## Note: the smaller that dampen_factor "b" is, the faster the oscillation dampens to resting state values self.v_reset = v_reset self.w_reset = w_reset self.v0 = v0 @@ -137,42 +138,46 @@ def __init__( restVals, display_name="Time-of-Last-Spike", units="ms" ) ## time-of-last-spike - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b, - v_reset, w_reset, intgFlag, j, v, w, tols): + @compilable + def advance_state( + self, t, dt + ): ## continue with centered dynamics - j_ = j * resist_v - if intgFlag == 1: ## RK-2/midpoint + j_ = self.j.get() * self.resist_v + if self.intgFlag == 1: ## RK-2/midpoint ## Note: we integrate ODEs in order: first w, then v - w_params = (j_, v, tau_w, omega, b) - _, _w = step_rk2(0., w, _dfw, dt, w_params) - v_params = (j_, _w, tau_v, omega, b) - _, _v = step_rk2(0., v, _dfv, dt, v_params) + w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor) + _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) + v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor) + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) else: # integType == 0 (default -- Euler) ## Note: we integrate ODEs in order: first w, then v - w_params = (j_, v, tau_w, omega, b) - _, _w = step_euler(0., w, _dfw, dt, w_params) - v_params = (j_, _w, tau_v, omega, b) - _, _v = step_euler(0., v, _dfv, dt, v_params) - s = (_v > thr) * 1. ## emit spikes/pulses + w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor) + _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) + v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor) + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) + + s = (_v > self.thr) * 1. ## emit spikes/pulses ## hyperpolarize/reset/snap variables - w = _w * (1. - s) + s * w_reset - v = _v * (1. - s) + s * v_reset - - tols = (1. - s) * tols + (s * t) ## update times-of-last-spike(s) - return j, v, w, s, tols - - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def reset(batch_size, n_units, v0, w0): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals # None - v = restVals + v0 - w = restVals + w0 - s = restVals #+ 0 - tols = restVals #+ 0 - return j, v, w, s, tols + w = _w * (1. - s) + s * self.w_reset + v = _v * (1. - s) + s * self.v_reset + + self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s) + + #self.j.set(j_) + self.v.set(v) + self.w.set(w) + self.s.set(s) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v0) + self.w.set(restVals + self.w0) + self.s.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function @@ -198,7 +203,7 @@ def help(cls): ## component help function "tau_w": "Recovery variable time constant", "v_reset": "Reset membrane potential value", "w_reset": "Reset angular driver value", - "b": "Exponential dampening factor applied to oscillations", + "dampen_factor": "Exponential dampening factor applied to oscillations", "omega": "Angular frequency of neuronal progress per second (radians)", "v0": "Initial condition for membrane potential/voltage", "w0": "Initial condition for membrane angular driver variable", @@ -207,8 +212,8 @@ def help(cls): ## component help function } info = {cls.__name__: properties, "compartments": compartment_props, - "dynamics": "tau_v * dv/dt = omega * w + v * b; " - "tau_w * dw/dt = w * b - v * omega + j", + "dynamics": "tau_v * dv/dt = omega * w + v * dampen_factor; " + "tau_w * dw/dt = w * dampen_factor - v * omega + j", "hyperparameters": hyperparams} return info From 50f0db488b3decf35a374cd067e53659b013f474 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 1 Oct 2025 19:09:59 -0400 Subject: [PATCH 008/121] ported over/refactored WTASCell for v3 --- .../components/neurons/spiking/WTASCell.py | 75 ++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index c6f9edb6..52fd893b 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -3,11 +3,10 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment from ngclearn.utils.model_utils import softmax @@ -88,42 +87,50 @@ def __init__( self.rfr = Compartment(restVals + self.refract_T) self.tols = Compartment(restVals) ## time-of-last-spike - @transition(output_compartments=["v", "s", "thr", "rfr", "tols"]) - @staticmethod - def advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols): - mask = (rfr >= refract_T) * 1. ## check refractory period - v = (j * R_m) * mask + # @transition(output_compartments=["v", "s", "thr", "rfr", "tols"]) + # @staticmethod + @compilable + def advance_state( + self, t, dt #, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols + ): + mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period + v = (self.j.get() * self.R_m) * mask vp = softmax(v) # convert to Categorical (spike) probabilities # s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike - s = (vp > thr) * 1. ## calculate action potential + s = (vp > self.thr.get()) * 1. ## calculate action potential q = 1. ## Note: thr_gain ==> "rho_b" ## increment threshold upon spike(s) occurrence dthr = jnp.sum(s, axis=1, keepdims=True) - q - thr = jnp.maximum(thr + dthr * thr_gain, 0.025) ## calc new threshold - rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt - - tols = (1. - s) * tols + (s * t) ## update tols - return v, s, thr, rfr, tols - - @transition(output_compartments=["j", "v", "s", "rfr", "tols"]) - @staticmethod - def reset(batch_size, n_units, refract_T): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals #+ 0 - v = restVals #+ 0 - s = restVals #+ 0 - rfr = restVals + refract_T - tols = restVals #+ 0 - return j, v, s, rfr, tols - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, threshold=self.thr.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.thr.set( data['threshold'] ) + thr = jnp.maximum(self.thr.get() + dthr * self.thr_gain, 0.025) ## calc new threshold + rfr = (self.rfr.get() + dt) * (1. - s) + s * dt # set refract to dt + + self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s) + + self.v.set(v) + self.s.set(s) + self.thr.set(thr) + self.rfr.set(rfr) + + # @transition(output_compartments=["j", "v", "s", "rfr", "tols"]) + # @staticmethod + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals) + self.s.set(restVals) + self.rfr.set(restVals + self.refract_T) + self.tols.set(restVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # jnp.savez(file_name, threshold=self.thr.value) + # + # def load(self, directory, seeded=False, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.thr.set( data['threshold'] ) @classmethod def help(cls): ## component help function From c7908708e8c2db02accad3cea6b62e5aafe0626a Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 6 Oct 2025 12:33:44 -0400 Subject: [PATCH 009/121] wrote successful unit-test of WTASCell --- .../components/neurons/spiking/WTASCell.py | 6 +-- .../neurons/spiking/test_WTASCell.py | 48 ++++++++----------- 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 52fd893b..525d078b 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -73,7 +73,7 @@ def __init__( ## base threshold setup ## according to eqn 26 of the source paper, the initial condition for the ## threshold should technically be between: 1/n_units < threshold0 << 0.5, e.g., 0.15 - key, subkey = random.split(self.key.value) + key, subkey = random.split(self.key.get()) self.threshold0 = thr_base + random.uniform(subkey, (1, n_units), minval=-thr_jitter, maxval=thr_jitter, dtype=jnp.float32) @@ -125,7 +125,7 @@ def reset(self): # def save(self, directory, **kwargs): # file_name = directory + "/" + self.name + ".npz" - # jnp.savez(file_name, threshold=self.thr.value) + # jnp.savez(file_name, threshold=self.thr.get()) # # def load(self, directory, seeded=False, **kwargs): # file_name = directory + "/" + self.name + ".npz" @@ -170,7 +170,7 @@ def __repr__(self): maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/tests/components/neurons/spiking/test_WTASCell.py b/tests/components/neurons/spiking/test_WTASCell.py index b56b87e5..cf936359 100644 --- a/tests/components/neurons/spiking/test_WTASCell.py +++ b/tests/components/neurons/spiking/test_WTASCell.py @@ -1,17 +1,12 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import WTASCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.WTASCell import WTASCell +from numpy.testing import assert_array_equal from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch def test_WTASCell1(): @@ -27,27 +22,22 @@ def test_WTASCell1(): ) #""" - advance_process = (Process("advance_proc") + advance_process = (MethodProcess(name="advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + #ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess(name="reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + #ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") #""" - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ + # ## set up non-compiled utility commands + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[0., 1.], [0., 1.], [1., 0.], [1., 0.]], dtype=jnp.float32) @@ -55,14 +45,16 @@ def clamp(x): y_seq = x_seq outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[0]): x_t = x_seq[ts:ts+1, :] ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + #ctx.clamp(x_t) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=0) - #print(outs) + # print(outs) + # print(y_seq) #exit() ## output should equal input assert_array_equal(outs, y_seq) From 3df6341c594381385e8a63ab30470278308b1c33 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 8 Oct 2025 14:34:24 -0400 Subject: [PATCH 010/121] put back in init-structure/pointers --- ngclearn/components/__init__.py | 130 +++++++++--------- ngclearn/components/neurons/__init__.py | 34 ++--- .../components/neurons/spiking/__init__.py | 22 +-- ngclearn/components/other/__init__.py | 6 +- ngclearn/components/synapses/__init__.py | 76 +++++----- .../components/synapses/hebbian/__init__.py | 12 +- 6 files changed, 140 insertions(+), 140 deletions(-) diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index f69f30a5..96f8a2cf 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -1,65 +1,65 @@ -# from .jaxComponent import JaxComponent -# -# ## point to rate-coded cell component types -# from .neurons.graded.rateCell import RateCell -# from .neurons.graded.gaussianErrorCell import GaussianErrorCell -# from .neurons.graded.laplacianErrorCell import LaplacianErrorCell -# from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell -# from .neurons.graded.rewardErrorCell import RewardErrorCell -# -# ## point to standard spiking cell component types -# from .neurons.spiking.sLIFCell import SLIFCell -# from .neurons.spiking.IFCell import IFCell -# from .neurons.spiking.LIFCell import LIFCell -# from .neurons.spiking.WTASCell import WTASCell -# from .neurons.spiking.quadLIFCell import QuadLIFCell -# from .neurons.spiking.adExCell import AdExCell -# from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell -# from .neurons.spiking.izhikevichCell import IzhikevichCell -# from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell -# from .neurons.spiking.RAFCell import RAFCell -# -# ## point to transformer/operator component types -# from .other.varTrace import VarTrace -# from .other.expKernel import ExpKernel -# -# ## point to input encoder component types -# from .input_encoders.bernoulliCell import BernoulliCell -# from .input_encoders.poissonCell import PoissonCell -# from .input_encoders.latencyCell import LatencyCell -# from .input_encoders.phasorCell import PhasorCell -# -# ## point to synapse component types -# from .synapses.denseSynapse import DenseSynapse -# from .synapses.staticSynapse import StaticSynapse -# from .synapses.hebbian.hebbianSynapse import HebbianSynapse -# from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse -# from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse -# from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse -# from .synapses.hebbian.BCMSynapse import BCMSynapse -# from .synapses.STPDenseSynapse import STPDenseSynapse -# from .synapses.exponentialSynapse import ExponentialSynapse -# from .synapses.doubleExpSynapse import DoupleExpSynapse -# from .synapses.alphaSynapse import AlphaSynapse -# -# ## point to convolutional component types -# from .synapses.convolution.convSynapse import ConvSynapse -# from .synapses.convolution.staticConvSynapse import StaticConvSynapse -# from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse -# from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -# from .synapses.convolution.deconvSynapse import DeconvSynapse -# from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse -# from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -# from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse -# ## point to modulated component types -# from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse -# from .synapses.modulated.REINFORCESynapse import REINFORCESynapse -# -# ## point to monitors -# from .monitor import Monitor -# -# ## point to patched component types -# from .synapses.patched.patchedSynapse import PatchedSynapse -# from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse -# from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse -# +from .jaxComponent import JaxComponent + +## point to rate-coded cell component types +from .neurons.graded.rateCell import RateCell +from .neurons.graded.gaussianErrorCell import GaussianErrorCell +from .neurons.graded.laplacianErrorCell import LaplacianErrorCell +from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell +from .neurons.graded.rewardErrorCell import RewardErrorCell + +## point to standard spiking cell component types +from .neurons.spiking.sLIFCell import SLIFCell +from .neurons.spiking.IFCell import IFCell +from .neurons.spiking.LIFCell import LIFCell +from .neurons.spiking.WTASCell import WTASCell +from .neurons.spiking.quadLIFCell import QuadLIFCell +from .neurons.spiking.adExCell import AdExCell +from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell +from .neurons.spiking.izhikevichCell import IzhikevichCell +from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell +from .neurons.spiking.RAFCell import RAFCell + +## point to transformer/operator component types +from .other.varTrace import VarTrace +from .other.expKernel import ExpKernel + +## point to input encoder component types +from .input_encoders.bernoulliCell import BernoulliCell +from .input_encoders.poissonCell import PoissonCell +from .input_encoders.latencyCell import LatencyCell +from .input_encoders.phasorCell import PhasorCell + +## point to synapse component types +from .synapses.denseSynapse import DenseSynapse +from .synapses.staticSynapse import StaticSynapse +from .synapses.hebbian.hebbianSynapse import HebbianSynapse +from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse +from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse +from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse +from .synapses.hebbian.BCMSynapse import BCMSynapse +from .synapses.STPDenseSynapse import STPDenseSynapse +from .synapses.exponentialSynapse import ExponentialSynapse +from .synapses.doubleExpSynapse import DoupleExpSynapse +from .synapses.alphaSynapse import AlphaSynapse + +## point to convolutional component types +from .synapses.convolution.convSynapse import ConvSynapse +from .synapses.convolution.staticConvSynapse import StaticConvSynapse +from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse +from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +from .synapses.convolution.deconvSynapse import DeconvSynapse +from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse +from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse +## point to modulated component types +from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse +from .synapses.modulated.REINFORCESynapse import REINFORCESynapse + +## point to monitors +from .monitor import Monitor + +## point to patched component types +from .synapses.patched.patchedSynapse import PatchedSynapse +from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse +from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse + diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index 2398f011..e7165d7e 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -1,17 +1,17 @@ -# ## point to rate-coded cell componet types -# from .graded.rateCell import RateCell -# from .graded.gaussianErrorCell import GaussianErrorCell -# from .graded.laplacianErrorCell import LaplacianErrorCell -# from .graded.bernoulliErrorCell import BernoulliErrorCell -# from .graded.rewardErrorCell import RewardErrorCell -# ## point to standard spiking cell component types -# from .spiking.sLIFCell import SLIFCell -# from .spiking.IFCell import IFCell -# from .spiking.LIFCell import LIFCell -# from .spiking.WTASCell import WTASCell -# from .spiking.quadLIFCell import QuadLIFCell -# from .spiking.adExCell import AdExCell -# from .spiking.fitzhughNagumoCell import FitzhughNagumoCell -# from .spiking.izhikevichCell import IzhikevichCell -# from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell -# from .spiking.RAFCell import RAFCell +## point to rate-coded cell componet types +from .graded.rateCell import RateCell +from .graded.gaussianErrorCell import GaussianErrorCell +from .graded.laplacianErrorCell import LaplacianErrorCell +from .graded.bernoulliErrorCell import BernoulliErrorCell +from .graded.rewardErrorCell import RewardErrorCell +## point to standard spiking cell component types +from .spiking.sLIFCell import SLIFCell +from .spiking.IFCell import IFCell +from .spiking.LIFCell import LIFCell +from .spiking.WTASCell import WTASCell +from .spiking.quadLIFCell import QuadLIFCell +from .spiking.adExCell import AdExCell +from .spiking.fitzhughNagumoCell import FitzhughNagumoCell +from .spiking.izhikevichCell import IzhikevichCell +from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell +from .spiking.RAFCell import RAFCell diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py index 6687c56a..690087b7 100644 --- a/ngclearn/components/neurons/spiking/__init__.py +++ b/ngclearn/components/neurons/spiking/__init__.py @@ -1,11 +1,11 @@ -# ## point to standard spiking cell component types -# from .sLIFCell import SLIFCell -# from .LIFCell import LIFCell -# from .IFCell import IFCell -# from .WTASCell import WTASCell -# from .quadLIFCell import QuadLIFCell -# from .adExCell import AdExCell -# from .fitzhughNagumoCell import FitzhughNagumoCell -# from .izhikevichCell import IzhikevichCell -# from .RAFCell import RAFCell -# from .hodgkinHuxleyCell import HodgkinHuxleyCell +## point to standard spiking cell component types +from .sLIFCell import SLIFCell +from .LIFCell import LIFCell +from .IFCell import IFCell +from .WTASCell import WTASCell +from .quadLIFCell import QuadLIFCell +from .adExCell import AdExCell +from .fitzhughNagumoCell import FitzhughNagumoCell +from .izhikevichCell import IzhikevichCell +from .RAFCell import RAFCell +from .hodgkinHuxleyCell import HodgkinHuxleyCell diff --git a/ngclearn/components/other/__init__.py b/ngclearn/components/other/__init__.py index 14d46a49..cff092d9 100644 --- a/ngclearn/components/other/__init__.py +++ b/ngclearn/components/other/__init__.py @@ -1,3 +1,3 @@ -# from .varTrace import VarTrace -# from .expKernel import ExpKernel -# +from .varTrace import VarTrace +from .expKernel import ExpKernel + diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index fd701c25..2c21c231 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -1,38 +1,38 @@ -# from .denseSynapse import DenseSynapse -# from .staticSynapse import StaticSynapse -# -# -# ## short-term plasticity components -# from .STPDenseSynapse import STPDenseSynapse -# from .exponentialSynapse import ExponentialSynapse -# from .doubleExpSynapse import DoupleExpSynapse -# from .alphaSynapse import AlphaSynapse -# -# ## dense synaptic components -# from .hebbian.hebbianSynapse import HebbianSynapse -# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse -# from .hebbian.expSTDPSynapse import ExpSTDPSynapse -# from .hebbian.eventSTDPSynapse import EventSTDPSynapse -# from .hebbian.BCMSynapse import BCMSynapse -# -# -# ## conv/deconv synaptic components -# from .convolution.convSynapse import ConvSynapse -# from .convolution.staticConvSynapse import StaticConvSynapse -# from .convolution.hebbianConvSynapse import HebbianConvSynapse -# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -# from .convolution.deconvSynapse import DeconvSynapse -# from .convolution.staticDeconvSynapse import StaticDeconvSynapse -# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse -# -# -# ## modulated synaptic components -# from .modulated.MSTDPETSynapse import MSTDPETSynapse -# from .modulated.REINFORCESynapse import REINFORCESynapse -# -# ## patched synaptic components -# from .patched.patchedSynapse import PatchedSynapse -# from .patched.staticPatchedSynapse import StaticPatchedSynapse -# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse -# +from .denseSynapse import DenseSynapse +from .staticSynapse import StaticSynapse + + +## short-term plasticity components +from .STPDenseSynapse import STPDenseSynapse +from .exponentialSynapse import ExponentialSynapse +from .doubleExpSynapse import DoupleExpSynapse +from .alphaSynapse import AlphaSynapse + +## dense synaptic components +from .hebbian.hebbianSynapse import HebbianSynapse +from .hebbian.traceSTDPSynapse import TraceSTDPSynapse +from .hebbian.expSTDPSynapse import ExpSTDPSynapse +from .hebbian.eventSTDPSynapse import EventSTDPSynapse +from .hebbian.BCMSynapse import BCMSynapse + + +## conv/deconv synaptic components +from .convolution.convSynapse import ConvSynapse +from .convolution.staticConvSynapse import StaticConvSynapse +from .convolution.hebbianConvSynapse import HebbianConvSynapse +from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +from .convolution.deconvSynapse import DeconvSynapse +from .convolution.staticDeconvSynapse import StaticDeconvSynapse +from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse + + +## modulated synaptic components +from .modulated.MSTDPETSynapse import MSTDPETSynapse +from .modulated.REINFORCESynapse import REINFORCESynapse + +## patched synaptic components +from .patched.patchedSynapse import PatchedSynapse +from .patched.staticPatchedSynapse import StaticPatchedSynapse +from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse + diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index 61b33f17..f39d556f 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,6 +1,6 @@ -# from .hebbianSynapse import HebbianSynapse -# from .traceSTDPSynapse import TraceSTDPSynapse -# from .expSTDPSynapse import ExpSTDPSynapse -# from .eventSTDPSynapse import EventSTDPSynapse -# from .BCMSynapse import BCMSynapse -# +from .hebbianSynapse import HebbianSynapse +from .traceSTDPSynapse import TraceSTDPSynapse +from .expSTDPSynapse import ExpSTDPSynapse +from .eventSTDPSynapse import EventSTDPSynapse +from .BCMSynapse import BCMSynapse + From a38eb1b0caa97c8367d4d9e112352687f3fa0b18 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 8 Oct 2025 14:51:13 -0400 Subject: [PATCH 011/121] fixed minor error in LIFCell, got unit-test for LIFCell to run --- .../components/neurons/spiking/LIFCell.py | 2 +- .../neurons/spiking/test_LIFCell.py | 50 ++++++++----------- .../neurons/spiking/test_WTASCell.py | 4 +- 3 files changed, 22 insertions(+), 34 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index d00fa171..b35a01cc 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -1,5 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, random, nn +from jax import numpy as jnp, random, nn, Array from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, diff --git a/tests/components/neurons/spiking/test_LIFCell.py b/tests/components/neurons/spiking/test_LIFCell.py index 6f5f7c1a..b918d9a1 100644 --- a/tests/components/neurons/spiking/test_LIFCell.py +++ b/tests/components/neurons/spiking/test_LIFCell.py @@ -2,15 +2,10 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import LIFCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.LIFCell import LIFCell +from numpy.testing import assert_array_equal def test_LIFCell1(): name = "lif_ctx" @@ -26,27 +21,21 @@ def test_LIFCell1(): ) #""" - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - #ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + #ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + #ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32) @@ -54,15 +43,16 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() #ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) #ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) - #print(outs) - + # print(outs) + # print(y_seq) + ## output should equal input assert_array_equal(outs, y_seq) diff --git a/tests/components/neurons/spiking/test_WTASCell.py b/tests/components/neurons/spiking/test_WTASCell.py index cf936359..82384701 100644 --- a/tests/components/neurons/spiking/test_WTASCell.py +++ b/tests/components/neurons/spiking/test_WTASCell.py @@ -6,7 +6,6 @@ from ngclearn import Context, MethodProcess from ngclearn.components.neurons.spiking.WTASCell import WTASCell from numpy.testing import assert_array_equal -from ngcsimlib.context import Context def test_WTASCell1(): @@ -48,8 +47,7 @@ def clamp(x): reset_process.run() for ts in range(x_seq.shape[0]): x_t = x_seq[ts:ts+1, :] ## get data at time t - #ctx.clamp(x_t) - clamp(x_t) + clamp(x_t) #ctx.clamp(x_t) advance_process.run(t=ts * 1., dt=dt) outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=0) From 5ab856454333595e97024ad84779d55acfbf91ed Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 3 Nov 2025 16:01:48 -0500 Subject: [PATCH 012/121] quad-lif test sketched --- .../components/neurons/spiking/quadLIFCell.py | 1 + .../neurons/spiking/test_quadLIFCell.py | 52 +++++++------------ 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index da1084c4..638c91d7 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -128,6 +128,7 @@ def __init__( name, n_units, tau_m, resist_m, thr, v_rest, v_reset, 1., tau_theta, theta_plus, refract_time, one_spike, integration_type, surrogate_type, lower_clamp_voltage, **kwargs ) + ## only two distinct additional constants distinguish the Quad-LIF cell self.v_c = v_scale self.a0 = critical_v diff --git a/tests/components/neurons/spiking/test_quadLIFCell.py b/tests/components/neurons/spiking/test_quadLIFCell.py index d79418ff..13cfe916 100644 --- a/tests/components/neurons/spiking/test_quadLIFCell.py +++ b/tests/components/neurons/spiking/test_quadLIFCell.py @@ -2,15 +2,10 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import QuadLIFCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.quadLIFCell import QuadLIFCell +from numpy.testing import assert_array_equal def test_quadLIFCell1(): name = "quadlif_ctx" @@ -25,28 +20,22 @@ def test_quadLIFCell1(): name="a", n_units=1, tau_m=30., resist_m=1., key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - #ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0.]], dtype=jnp.float32) @@ -54,16 +43,15 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) - #print(outs) ## output should equal input assert_array_equal(outs, y_seq) -#test_quadLIFCell1() +test_quadLIFCell1() From 195e3dba6317d29e0bc52fbc374744ec7024c77e Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 3 Nov 2025 17:25:01 -0500 Subject: [PATCH 013/121] sketch of ifcell test --- ngclearn/components/neurons/spiking/IFCell.py | 12 ++-- .../components/neurons/spiking/test_IFCell.py | 57 ++++++++----------- 2 files changed, 28 insertions(+), 41 deletions(-) diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index 42814a3d..86b8f2a7 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -1,14 +1,12 @@ from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, random, jit, nn -from functools import partial +from jax import numpy as jnp, random, nn, Array, jit from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args -from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, -# triangular_estimator, -# straight_through_estimator) +from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, + triangular_estimator, + straight_through_estimator) from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @@ -135,7 +133,7 @@ def __init__( display_name="Refractory Time Period", units="ms") self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike - self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") + #self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") @compilable def advance_state( diff --git a/tests/components/neurons/spiking/test_IFCell.py b/tests/components/neurons/spiking/test_IFCell.py index 28f3d8c0..623a0118 100644 --- a/tests/components/neurons/spiking/test_IFCell.py +++ b/tests/components/neurons/spiking/test_IFCell.py @@ -2,15 +2,10 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import IFCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.IFCell import IFCell +from numpy.testing import assert_array_equal def test_IFCell1(): name = "if_ctx" @@ -18,35 +13,28 @@ def test_IFCell1(): dkey = random.PRNGKey(1234) dkey, *subkeys = random.split(dkey, 6) dt = 1. # ms - trace_increment = 0.1 # ---- build a simple Poisson cell system ---- with Context(name) as ctx: a = IFCell( name="a", n_units=1, tau_m=5., resist_m=10., key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - #ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32) @@ -54,16 +42,17 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) - print(outs) - + # print(outs) + # print(y_seq) + ## output should equal input assert_array_equal(outs, y_seq) -#test_IFCell1() +test_IFCell1() From d9aa8f833584e4f11dd37d7131dff4c2306cb00e Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 3 Nov 2025 18:08:16 -0500 Subject: [PATCH 014/121] fixed minor bugs and tests locally pass for if, quad-lif, and lif-cells now, with minor patches to help fun and doc-strings --- ngclearn/components/neurons/spiking/IFCell.py | 29 +-------- .../components/neurons/spiking/LIFCell.py | 25 +++---- .../components/neurons/spiking/quadLIFCell.py | 65 +++---------------- .../components/neurons/spiking/test_IFCell.py | 2 +- .../neurons/spiking/test_quadLIFCell.py | 3 +- 5 files changed, 20 insertions(+), 104 deletions(-) diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index 86b8f2a7..1e7f26ba 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -87,7 +87,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell the value of `v_rest` (default: True) """ - @deprecate_args(thr_jitter=None) + #@deprecate_args(thr_jitter=None) def __init__( self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0., integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs @@ -178,33 +178,6 @@ def reset(self): self.tols.set(restVals) #surrogate = restVals + 1. - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.asarray([[self.tau_m * 1.]])) - thr = (self.thr if isinstance(self.thr, float) - else jnp.asarray([[self.thr * 1.]])) - v_rest = (self.v_rest if isinstance(self.v_rest, float) - else jnp.asarray([[self.v_rest * 1.]])) - v_reset = (self.v_reset if isinstance(self.v_reset, float) - else jnp.asarray([[self.v_reset * 1.]])) - v_decay = (self.v_decay if isinstance(self.v_decay, float) - else jnp.asarray([[self.v_decay * 1.]])) - resist_m = (self.resist_m if isinstance(self.resist_m, float) - else jnp.asarray([[self.resist_m * 1.]])) - tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) - else jnp.asarray([[self.tau_theta * 1.]])) - theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) - else jnp.asarray([[self.theta_plus * 1.]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - tau_m=tau_m, thr=thr, v_rest=v_rest, - v_reset=v_reset, v_decay=v_decay, - resist_m=resist_m, tau_theta=tau_theta, - theta_plus=theta_plus, - key=self.key.value) - def load(self, directory, seeded=False, **kwargs): file_name = directory + "/" + self.name + ".npz" data = jnp.load(file_name) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index b35a01cc..44fd2d0f 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -151,20 +151,15 @@ def __init__( # else: ## default: straight_through # spike_fx, d_spike_fx = straight_through_estimator() - ## Compartment setup restVals = jnp.zeros((self.batch_size, self.n_units)) self.j = Compartment(restVals, display_name="Current", units="mA") - self.v = Compartment(restVals + self.v_rest, - display_name="Voltage", units="mV") + self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV") self.s = Compartment(restVals, display_name="Spikes") self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses") - self.rfr = Compartment(restVals + self.refract_T, - display_name="Refractory Time Period", units="ms") - self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", - units="mV") - self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", - units="ms") ## time-of-last-spike + self.rfr = Compartment(restVals + self.refract_T, display_name="Refractory Time Period", units="ms") + self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", units="mV") + self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike # self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value") @compilable @@ -258,17 +253,13 @@ def help(cls): ## component help function "v_reset": "Reset membrane potential value", "conduct_leak": "Conductance leak / voltage decay factor", "tau_theta": "Threshold/homoestatic increment time constant", - "theta_plus": "Amount to increment threshold by upon occurrence " - "of spike", + "theta_plus": "Amount to increment threshold by upon occurrence of a spike", "refract_time": "Length of relative refractory period (ms)", - "one_spike": "Should only one spike be sampled/allowed to emit at " - "any given time step?", - "integration_type": "Type of numerical integration to use for the " - "cell dynamics", + "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?", + "integration_type": "Type of numerical integration to use for the cell dynamics", "surrgoate_type": "Type of surrogate function to use approximate " "derivative of spike w.r.t. voltage/current", - "lower_bound_clamp": "Should voltage be lower bounded to be never " - "be below `v_rest`" + "v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped" } info = {cls.__name__: properties, "compartments": compartment_props, diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index 638c91d7..7d68ffa3 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -116,17 +116,19 @@ class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell (straight-through estimator), "triangular" (triangular estimator), "arctan" (arc-tangent estimator), and "secant_lif" (the LIF-specialized secant estimator) + + v_min: minimum voltage to clamp dynamics to (Default: None) """ ## batch_size arg? - @deprecate_args(thr_jitter=None, critical_v="critical_V") + #@deprecate_args(thr_jitter=None, critical_v="critical_V") def __init__( self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1., tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", - surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs + surrogate_type="straight_through", v_min=None, **kwargs ): super().__init__( name, n_units, tau_m, resist_m, thr, v_rest, v_reset, 1., tau_theta, theta_plus, refract_time, - one_spike, integration_type, surrogate_type, lower_clamp_voltage, **kwargs + one_spike, integration_type, surrogate_type, v_min=v_min, **kwargs ) ## only two distinct additional constants distinguish the Quad-LIF cell @@ -197,51 +199,6 @@ def reset(self): self.tols.set(restVals) #self.surrogate.set(restVals) - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.asarray([[self.tau_m * 1.]])) - thr = (self.thr if isinstance(self.thr, float) - else jnp.asarray([[self.thr * 1.]])) - v_rest = (self.v_rest if isinstance(self.v_rest, float) - else jnp.asarray([[self.v_rest * 1.]])) - v_reset = (self.v_reset if isinstance(self.v_reset, float) - else jnp.asarray([[self.v_reset * 1.]])) - v_decay = (self.v_decay if isinstance(self.v_decay, float) - else jnp.asarray([[self.v_decay * 1.]])) - resist_m = (self.resist_m if isinstance(self.resist_m, float) - else jnp.asarray([[self.resist_m * 1.]])) - tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) - else jnp.asarray([[self.tau_theta * 1.]])) - theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) - else jnp.asarray([[self.theta_plus * 1.]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - threshold_theta=self.thr_theta.value, - tau_m=tau_m, thr=thr, v_rest=v_rest, - v_reset=v_reset, v_decay=v_decay, - resist_m=resist_m, tau_theta=tau_theta, - theta_plus=theta_plus, - key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.thr_theta.set(data['threshold_theta']) - ## constants loaded in - self.tau_m = data['tau_m'] - self.thr = data['thr'] - self.v_rest = data['v_rest'] - self.v_reset = data['v_reset'] - self.v_decay = data['v_decay'] - self.resist_m = data['resist_m'] - self.tau_theta = data['tau_theta'] - self.theta_plus = data['theta_plus'] - - if seeded: - self.key.set(data['key']) - @classmethod def help(cls): ## component help function properties = { @@ -270,17 +227,13 @@ def help(cls): ## component help function "v_reset": "Reset membrane potential value", "v_decay": "Voltage leak/decay factor", "tau_theta": "Threshold/homoestatic increment time constant", - "theta_plus": "Amount to increment threshold by upon occurrence " - "of spike", + "theta_plus": "Amount to increment threshold by upon occurrence of a spike", "refract_time": "Length of relative refractory period (ms)", - "one_spike": "Should only one spike be sampled/allowed to emit at " - "any given time step?", - "integration_type": "Type of numerical integration to use for the " - "cell dynamics", + "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?", + "integration_type": "Type of numerical integration to use for the cell dynamics", "surrgoate_type": "Type of surrogate function to use approximate " "derivative of spike w.r.t. voltage/current", - "lower_bound_clamp": "Should voltage be lower bounded to be never " - "be below `v_rest`" + "v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped" } info = {cls.__name__: properties, "compartments": compartment_props, diff --git a/tests/components/neurons/spiking/test_IFCell.py b/tests/components/neurons/spiking/test_IFCell.py index 623a0118..3db38d72 100644 --- a/tests/components/neurons/spiking/test_IFCell.py +++ b/tests/components/neurons/spiking/test_IFCell.py @@ -55,4 +55,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, y_seq) -test_IFCell1() +#test_IFCell1() diff --git a/tests/components/neurons/spiking/test_quadLIFCell.py b/tests/components/neurons/spiking/test_quadLIFCell.py index 13cfe916..81414b9c 100644 --- a/tests/components/neurons/spiking/test_quadLIFCell.py +++ b/tests/components/neurons/spiking/test_quadLIFCell.py @@ -13,7 +13,6 @@ def test_quadLIFCell1(): dkey = random.PRNGKey(1234) dkey, *subkeys = random.split(dkey, 6) dt = 1. # ms - trace_increment = 0.1 # ---- build a simple Poisson cell system ---- with Context(name) as ctx: a = QuadLIFCell( @@ -54,4 +53,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, y_seq) -test_quadLIFCell1() +#test_quadLIFCell1() From ba3fb6dfd4e2ebb2bb236195dc39e6fe0630293e Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 3 Nov 2025 18:24:14 -0500 Subject: [PATCH 015/121] refactored raf-cell and test passed --- .../components/neurons/spiking/RAFCell.py | 2 +- .../neurons/spiking/test_RAFCell.py | 51 +++++++------------ 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 8f41569f..10d8d3b9 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -99,7 +99,7 @@ class RAFCell(JaxComponent): at an increase in computational cost (and simulation time) """ - @deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor") + #@deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor") def __init__( self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., dampen_factor=-1., v_reset=0., w_reset=0., v0=0., w0=0., resist_v=1., integration_type="euler", batch_size=1, **kwargs diff --git a/tests/components/neurons/spiking/test_RAFCell.py b/tests/components/neurons/spiking/test_RAFCell.py index a8a7fbfc..3a076ba6 100644 --- a/tests/components/neurons/spiking/test_RAFCell.py +++ b/tests/components/neurons/spiking/test_RAFCell.py @@ -1,17 +1,11 @@ from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np - np.random.seed(42) -from ngclearn.components import RAFCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.RAFCell import RAFCell +from numpy.testing import assert_array_equal def test_RAFCell1(): @@ -26,28 +20,22 @@ def test_RAFCell1(): name="a", n_units=1, tau_v=20., resist_v=1., key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[0., 1., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32) @@ -55,14 +43,13 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 1., 0., 0., 0., 0., 1.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) - #print(outs) ## output should equal input assert_array_equal(outs, y_seq) From 4013bc0f380d9591cd224517d5ad23ac76ededb8 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 4 Nov 2025 18:18:59 -0500 Subject: [PATCH 016/121] refactored adex/test passed; minor cleanup in lif, raf, and wtas cells --- .../components/neurons/spiking/LIFCell.py | 2 +- .../components/neurons/spiking/RAFCell.py | 4 +- .../components/neurons/spiking/WTASCell.py | 8 +-- .../components/neurons/spiking/adExCell.py | 71 ++++++++++--------- .../neurons/spiking/test_adExCell.py | 56 ++++++--------- 5 files changed, 62 insertions(+), 79 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 44fd2d0f..c205b7ad 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -200,7 +200,7 @@ def advance_state(self, dt, t): thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get()) self.thr_theta.set(thr_theta) - ## update tols + ## update time-of-last spike variable(s) self.tols.set((1. - s) * self.tols.get() + (s * t)) if self.v_min is not None: ## ensures voltage never < v_rest diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 10d8d3b9..5a807c2c 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -139,9 +139,7 @@ def __init__( ) ## time-of-last-spike @compilable - def advance_state( - self, t, dt - ): + def advance_state(self, t, dt): ## continue with centered dynamics j_ = self.j.get() * self.resist_v if self.intgFlag == 1: ## RK-2/midpoint diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 525d078b..8b4f368c 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -87,12 +87,8 @@ def __init__( self.rfr = Compartment(restVals + self.refract_T) self.tols = Compartment(restVals) ## time-of-last-spike - # @transition(output_compartments=["v", "s", "thr", "rfr", "tols"]) - # @staticmethod @compilable - def advance_state( - self, t, dt #, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols - ): + def advance_state(self, t, dt): mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period v = (self.j.get() * self.R_m) * mask vp = softmax(v) # convert to Categorical (spike) probabilities @@ -111,8 +107,6 @@ def advance_state( self.thr.set(thr) self.rfr.set(rfr) - # @transition(output_compartments=["j", "v", "s", "rfr", "tols"]) - # @staticmethod @compilable def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index fdff5f4c..00c9fc8a 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -2,12 +2,12 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component + +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @jit @@ -97,7 +97,7 @@ class AdExCell(JaxComponent): at an increase in computational cost (and simulation time) """ - @deprecate_args(v_thr="thr") + #@deprecate_args(v_thr="thr") def __init__( self, name, n_units, tau_m=15., resist_m=1., tau_w=400., v_sharpness=2., intrinsic_mem_thr=-55., thr=5., v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0., integration_type="euler", batch_size=1, **kwargs @@ -136,39 +136,40 @@ def __init__( self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def advance_state( - t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT, v_rest, v_reset, intgFlag, j, v, w, tols - ): - if intgFlag == 1: ## RK-2/midpoint - v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m) - _, _v = step_rk2(0., v, _dfv, dt, v_params) - w_params = (j, v, a, tau_w, v_rest) - _, _w = step_rk2(0., w, _dfw, dt, w_params) + @compilable + def advance_state(self, t, dt): + if self.intgFlag == 1: ## RK-2/midpoint + v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m) + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) + w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest) + _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) else: # intgFlag == 0 (default -- Euler) - v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m) - _, _v = step_euler(0., v, _dfv, dt, v_params) - w_params = (j, v, a, tau_w, v_rest) - _, _w = step_euler(0., w, _dfw, dt, w_params) - s = (_v > thr) * 1. ## emit spikes/pulses + v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m) + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) + w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest) + _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) + s = (_v > self.thr) * 1. ## emit spikes/pulses ## hyperpolarize/reset/snap variables - v = _v * (1. - s) + s * v_reset - w = _w * (1. - s) + s * (_w + b) - - tols = (1. - s) * tols + (s * t) ## update time-of-last spike variable(s) - return j, v, w, s, tols - - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def reset(batch_size, n_units, v0, w0): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals # None - v = restVals + v0 - w = restVals + w0 - s = restVals #+ 0 - tols = restVals #+ 0 - return j, v, w, s, tols + v = _v * (1. - s) + s * self.v_reset + w = _w * (1. - s) + s * (_w + self.b) + + ## update time-of-last spike variable(s) + self.tols.set((1. - s) * self.tols.get() + (s * t)) + + #self.j.set(j) ## j is not getting modified in these dynamics + self.v.set(v) + self.w.set(w) + self.s.set(s) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v0) + self.w.set(restVals + self.w0) + self.s.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function diff --git a/tests/components/neurons/spiking/test_adExCell.py b/tests/components/neurons/spiking/test_adExCell.py index 2c0b9338..cb1dd528 100644 --- a/tests/components/neurons/spiking/test_adExCell.py +++ b/tests/components/neurons/spiking/test_adExCell.py @@ -1,17 +1,11 @@ from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np - np.random.seed(42) -from ngclearn.components import AdExCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.adExCell import AdExCell +from numpy.testing import assert_array_equal def test_adExCell1(): @@ -26,28 +20,22 @@ def test_adExCell1(): name="a", n_units=1, tau_m=50., resist_m=30., thr=-66., key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.ones((1, 10)) @@ -55,16 +43,18 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) + outs = jnp.concatenate(outs, axis=1) - #print(outs) + # print(outs) + # print(y_seq) ## output should equal input assert_array_equal(outs, y_seq) -#test_adExCell1() +test_adExCell1() From 3c0ed36a198343736d123a203d285445e53e9fc8 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 4 Nov 2025 18:31:34 -0500 Subject: [PATCH 017/121] refactored fn-cell and test passed --- .../components/neurons/spiking/adExCell.py | 2 +- .../neurons/spiking/fitzhughNagumoCell.py | 85 ++++++++++--------- .../spiking/test_fitzhughNagumoCell.py | 56 +++++------- 3 files changed, 67 insertions(+), 76 deletions(-) diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index 00c9fc8a..8c7575dd 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -32,7 +32,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper dv_dt = _dfw_internal(j, v, w, a, tau_m, v_rest) return dv_dt -class AdExCell(JaxComponent): +class AdExCell(JaxComponent): ## adaptive exponential integrate-and-fire cell """ The AdEx (adaptive exponential leaky integrate-and-fire) neuronal cell model; a two-variable model. This cell model iteratively evolves diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py index 2cab7f56..e3549da7 100755 --- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py +++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py @@ -2,13 +2,12 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @@ -34,7 +33,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper dv_dt = _dfw_internal(j, v, w, a, b, g, tau_m) return dv_dt -class FitzhughNagumoCell(JaxComponent): +class FitzhughNagumoCell(JaxComponent): ## F-H cell """ The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification of the Hodgkin-Huxley (squid axon) model. This cell model iteratively evolves @@ -103,10 +102,10 @@ class FitzhughNagumoCell(JaxComponent): at an increase in computational cost (and simulation time) """ - # Define Functions - def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, - beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False, - integration_type="euler", **kwargs): + def __init__( + self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, beta=0.8, gamma=3., v0=0., w0=0., + v_thr=1.07, spike_reset=False, integration_type="euler", **kwargs + ): super().__init__(name, **kwargs) ## Integration properties @@ -115,7 +114,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, ## Cell properties self.tau_m = tau_m - self.R_m = resist_m + self.resist_m = resist_m ## resistance R_m self.tau_w = tau_w self.alpha = alpha self.beta = beta @@ -138,41 +137,44 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, self.s = Compartment(restVals) self.tols = Compartment(restVals) ## time-of-last-spike - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha, - beta, gamma, intgFlag, j, v, w, tols): - j_mod = j * R_m - if intgFlag == 1: - v_params = (j_mod, w, alpha, beta, gamma, tau_m) - _, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt) - w_params = (j_mod, v, alpha, beta, gamma, tau_w) - _, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt) + @compilable + def advance_state(self, t, dt): + j_mod = self.j.get() * self.resist_m + if self.intgFlag == 1: + v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m) + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt) + w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w) + _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt) else: # integType == 0 (default -- Euler) - v_params = (j_mod, w, alpha, beta, gamma, tau_m) - _, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt) - w_params = (j_mod, v, alpha, beta, gamma, tau_w) - _, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt) - s = (_v > v_thr) * 1. + v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m) + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt) + w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w) + _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt) + s = (_v > self.v_thr) * 1. v = _v w = _w - if spike_reset: ## if spike-reset used, variables snapped back to initial conditions - v = v * (1. - s) + s * v0 - w = w * (1. - s) + s * w0 - tols = (1. - s) * tols + (s * t) ## update tols - return j, v, w, s, tols - - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def reset(batch_size, n_units, v0, w0): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals # None - v = restVals + v0 - w = restVals + w0 - s = restVals #+ 0 - tols = restVals #+ 0 - return j, v, w, s, tols + if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions + v = v * (1. - s) + s * self.v0 + w = w * (1. - s) + s * self.w0 + + ## update time-of-last spike variable(s) + self.tols.set((1. - s) * self.tols.get() + (s * t)) + + # self.j.set(j) ## j is not getting modified in these dynamics + self.v.set(v) + self.w.set(w) + self.s.set(s) + + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v0) + self.w.set(restVals + self.w0) + self.s.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function @@ -197,8 +199,7 @@ def help(cls): ## component help function "resist_m": "Membrane resistance value", "tau_w": "Recovery variable time constant", "v_thr": "Base voltage threshold value", - "spike_reset": "Should voltage/recover be snapped to initial " - "condition(s) if spike emitted?", + "spike_reset": "Should voltage/recover be snapped to initial condition(s) if spike emitted?", "alpha": "Dimensionless recovery variable shift factor `a", "beta": "Dimensionless recovery variable scale factor `b`", "gamma": "Power-term divisor constant", diff --git a/tests/components/neurons/spiking/test_fitzhughNagumoCell.py b/tests/components/neurons/spiking/test_fitzhughNagumoCell.py index eecc28e5..5ca0f489 100644 --- a/tests/components/neurons/spiking/test_fitzhughNagumoCell.py +++ b/tests/components/neurons/spiking/test_fitzhughNagumoCell.py @@ -1,17 +1,11 @@ from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np - np.random.seed(42) -from ngclearn.components import FitzhughNagumoCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell +from numpy.testing import assert_array_equal def test_fitzhughNagumoCell1(): @@ -26,28 +20,22 @@ def test_fitzhughNagumoCell1(): name="a", n_units=1, tau_m=1., resist_m=5., v_thr=2.1, key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=jnp.float32) @@ -55,14 +43,16 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): - x_t = x_seq[:, ts:ts+1] ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) + outs = jnp.concatenate(outs, axis=1) - #print(outs) + # print(outs) + # print(y_seq) ## output should equal input assert_array_equal(outs, y_seq) From 37ecdcdd3993ac90e78d1bf2d0b31d32ac0b0de0 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 4 Nov 2025 18:52:24 -0500 Subject: [PATCH 018/121] cleaned up lif, raf, wtas, fn, and quad-lif cells repr method --- ngclearn/components/neurons/spiking/LIFCell.py | 14 ++++++++++++++ ngclearn/components/neurons/spiking/RAFCell.py | 2 +- ngclearn/components/neurons/spiking/WTASCell.py | 4 ++-- ngclearn/components/neurons/spiking/adExCell.py | 2 +- .../neurons/spiking/fitzhughNagumoCell.py | 2 +- ngclearn/components/neurons/spiking/quadLIFCell.py | 2 +- 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index c205b7ad..30435271 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -1,5 +1,6 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, nn, Array +from ngclearn.utils import tensorstats from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, @@ -267,6 +268,19 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info + def __repr__(self): + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines if __name__ == '__main__': from ngcsimlib.context import Context diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 5a807c2c..b478c2dc 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -216,7 +216,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 8b4f368c..835e3d03 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -160,11 +160,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).get()) + stats = tensorstats(getattr(self, c).value) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index 8c7575dd..32ff02df 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -213,7 +213,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py index e3549da7..64cb9f1d 100755 --- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py +++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py @@ -215,7 +215,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index 7d68ffa3..f724bcc0 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -242,7 +242,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: From 947d2cffe9e1cfd5ea503bf05008c2257f6ffe3d Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 4 Nov 2025 19:17:32 -0500 Subject: [PATCH 019/121] refactored and tests passed for izh and h-h cells --- .../components/neurons/spiking/adExCell.py | 3 +- .../neurons/spiking/hodgkinHuxleyCell.py | 120 ++++++++++-------- .../neurons/spiking/izhikevichCell.py | 79 ++++++------ .../neurons/spiking/test_hodgkinHuxleyCell.py | 49 +++---- .../neurons/spiking/test_izhikevichCell.py | 58 ++++----- 5 files changed, 150 insertions(+), 159 deletions(-) diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index 32ff02df..af34161b 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -4,8 +4,7 @@ from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2 +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment diff --git a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py index 29ab648e..76e2e5ef 100644 --- a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py +++ b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py @@ -2,13 +2,11 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4 -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @@ -113,7 +111,6 @@ class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell at an increase in computational cost (and simulation time) """ - # Define Functions def __init__( self, name, n_units, tau_v, resist_m=1., v_Na=115., v_K=-35., v_L=10.6, g_Na=100., g_K=5., g_L=0.3, thr=4., spike_reset=False, v_reset=0., integration_type="euler", **kwargs @@ -126,7 +123,7 @@ def __init__( ## cell properties / biophysical parameter setup (affects ODE integration) self.tau_v = tau_v ## membrane time constant - self.R_m = resist_m ## resistance value + self.resist_m = resist_m ## resistance value R_m self.spike_reset = spike_reset self.thr = thr # mV ## base value for threshold self.v_reset = v_reset ## base value to reset voltage to (if spike_reset = True) @@ -151,38 +148,49 @@ def __init__( self.s = Compartment(restVals, display_name="Spike pulse") self.tols = Compartment(restVals, display_name="Time-of-last-spike") ## time-of-last-spike - @transition(output_compartments=["v", "m", "n", "h", "s", "tols"]) - @staticmethod - def advance_state( - t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag - ): - _j = j * R_m - alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v) + #@transition(output_compartments=["v", "m", "n", "h", "s", "tols"]) + #@staticmethod + @compilable + def advance_state(self, t, dt): #t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag + _j = self.j.get() * self.resist_m + alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(self.v.get()) ## integrate voltage / membrane potential - if intgFlag == 1: ## midpoint method - _, _v = step_rk2(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L)) + if self.intgFlag == 1: ## midpoint method + _, _v = step_rk2( + 0., self.v.get(), dv_dt, dt, + (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K, + self.g_L, self.v_Na, self.v_K, self.v_L) + ) ## next, integrate different channels - _, _n = step_rk2(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) - _, _m = step_rk2(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) - _, _h = step_rk2(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) - elif intgFlag == 4: ## Runge-Kutta 4th order - _, _v = step_rk4(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L)) + _, _n = step_rk2(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) + _, _m = step_rk2(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) + _, _h = step_rk2(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) + elif self.intgFlag == 4: ## Runge-Kutta 4th order + _, _v = step_rk4( + 0., self.v.get(), dv_dt, dt, + (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K, + self.g_L, self.v_Na, self.v_K, self.v_L) + ) ## next, integrate different channels - _, _n = step_rk4(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) - _, _m = step_rk4(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) - _, _h = step_rk4(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) + _, _n = step_rk4(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) + _, _m = step_rk4(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) + _, _h = step_rk4(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) else: # integType == 0 (default -- Euler) - _, _v = step_euler(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L)) + _, _v = step_euler( + 0., self.v.get(), dv_dt, dt, + (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K, + self.g_L, self.v_Na, self.v_K, self.v_L) + ) ## next, integrate different channels - _, _n = step_euler(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) - _, _m = step_euler(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) - _, _h = step_euler(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) + _, _n = step_euler(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v)) + _, _m = step_euler(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v)) + _, _h = step_euler(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v)) ## obtain action potentials/spikes/pulses - s = (_v > thr) * 1. - if spike_reset: ## if spike-reset used, variables snapped back to initial conditions + s = (_v > self.thr) * 1. + if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = ( - _calc_biophysical_constants(v * 0 + v_reset)) - _v = _v * (1. - s) + s * v_reset + _calc_biophysical_constants(self.v.get() * 0 + self.v_reset)) + _v = _v * (1. - s) + s * self.v_reset _n = _n * (1. - s) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v)) _m = _m * (1. - s) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v)) _h = _h * (1. - s) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v)) @@ -191,32 +199,40 @@ def advance_state( m = _m n = _n h = _h - tols = (1. - s) * tols + (s * t) ## update tols + ## update time-of-last spike variable(s) + self.tols.set((1. - s) * self.tols.get() + (s * t)) - return v, m, n, h, s, tols + self.v.set(v) + self.m.set(m) + self.n.set(n) + self.h.set(h) + self.s.set(s) - @transition(output_compartments=["j", "v", "m", "n", "h", "s", "tols"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) v = restVals # + 0 alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v) - j = restVals #+ 0 + if not self.j.targeted: + self.j.set(restVals) n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v) m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v) h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v) - s = restVals #+ 0 - tols = restVals #+ 0 - return j, v, m, n, h, s, tols - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - #jnp.savez(file_name, threshold=self.thr.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - #self.thr.set( data['threshold'] ) + self.v.set(v) + self.n.set(n) + self.m.set(m) + self.h.set(h) + self.s.set(restVals) + self.tols.set(restVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # #jnp.savez(file_name, threshold=self.thr.value) + # + # def load(self, directory, seeded=False, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # #self.thr.set( data['threshold'] ) @classmethod def help(cls): ## component help function @@ -258,7 +274,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: diff --git a/ngclearn/components/neurons/spiking/izhikevichCell.py b/ngclearn/components/neurons/spiking/izhikevichCell.py index 0027f314..38fbab4c 100755 --- a/ngclearn/components/neurons/spiking/izhikevichCell.py +++ b/ngclearn/components/neurons/spiking/izhikevichCell.py @@ -2,16 +2,13 @@ from jax import numpy as jnp, random, jit, nn from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2 +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment - @jit def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics ## (v^2 * 0.04 + v * 5 + 140 - u + j) * a, where a = (1./tau_m) (w = u) @@ -119,17 +116,16 @@ class IzhikevichCell(JaxComponent): ## Izhikevich neuronal cell at an increase in computational cost (and simulation time) """ - # Define Functions def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65., tau_w=50., w_reset=8., coupling_factor=0.2, v0=-65., w0=-14., integration_type="euler", **kwargs): super().__init__(name, **kwargs) ## Cell properties - self.R_m = resist_m + self.resist_m = resist_m ## resistance R_m self.tau_m = tau_m self.tau_w = tau_w - self.coupling = coupling_factor + self.coupling_factor = coupling_factor self.v_reset = v_reset self.w_reset = w_reset @@ -153,45 +149,47 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65. self.s = Compartment(restVals) self.tols = Compartment(restVals) ## time-of-last-spike - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def advance_state(t, dt, tau_m, tau_w, v_thr, coupling, v_reset, w_reset, R_m, - intgFlag, j, v, w, s, tols): + @compilable + def advance_state(self, t, dt): ## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes - a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w) - _j = j * R_m + a = 1. / self.tau_w ## we map time constant to variable "a" (a = 1/tau_w) + _j = self.j.get() * self.resist_m # _j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current ## check for spikes - s = (v > v_thr) * 1. + s = (self.v.get() > self.v_thr) * 1. ## for non-spikes, evolve according to dynamics - if intgFlag == 1: - v_params = (_j, w, coupling, tau_m) - _, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt) - w_params = (_j, v, coupling, tau_w) - _, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt) + if self.intgFlag == 1: + v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m) + _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt) + w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w) + _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt) else: # integType == 0 (default -- Euler) - v_params = (_j, w, coupling, tau_m) - _, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt) - w_params = (_j, v, coupling, tau_w) - _, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt) + v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m) + _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt) + w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w) + _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt) ## for spikes, snap to particular states - _v, _w = _post_process(s, _v, _w, v, w, v_reset, w_reset) + _v, _w = _post_process(s, _v, _w, self.v.get(), self.w.get(), self.v_reset, self.w_reset) v = _v w = _w - tols = (1. - s) * tols + (s * t) ## update tols - return j, v, w, s, tols + ## update time-of-last spike variable(s) + self.tols.set((1. - s) * self.tols.get() + (s * t)) + + # self.j.set(j) ## j is not getting modified in these dynamics + self.v.set(v) + self.w.set(w) + self.s.set(s) - @transition(output_compartments=["j", "v", "w", "s", "tols"]) - @staticmethod - def reset(batch_size, n_units, v0, w0): - restVals = jnp.zeros((batch_size, n_units)) - j = restVals # None - v = restVals + v0 - w = restVals + w0 - s = restVals #+ 0 - tols = restVals #+ 0 - return j, v, w, s, tols + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) + if not self.j.targeted: + self.j.set(restVals) + self.v.set(restVals + self.v0) + self.w.set(restVals + self.w0) + self.s.set(restVals) + self.tols.set(restVals) @classmethod def help(cls): ## component help function @@ -219,8 +217,7 @@ def help(cls): ## component help function "v_rest": "Resting membrane potential value", "v_reset": "Reset membrane potential value", "w_reset": "Reset recover variable value", - "coupling_factor": "Degree to which recovery variable is sensitive to " - "subthreshold voltage fluctuations", + "coupling_factor": "Degree to which recovery variable is sensitive to subthreshold voltage fluctuations", "v0": "Initial condition for membrane potential/voltage", "w0": "Initial condition for recovery variable", "integration_type": "Type of numerical integration to use for the cell dynamics" @@ -233,7 +230,7 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: diff --git a/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py b/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py index d86c3fd0..aeb80c48 100644 --- a/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py +++ b/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py @@ -1,17 +1,11 @@ from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np - np.random.seed(42) -from ngclearn.components import HodgkinHuxleyCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_almost_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell +from numpy.testing import assert_array_almost_equal import matplotlib.pyplot as plt @@ -30,27 +24,21 @@ def test_hodgkinHuxleyCell1(): ) # """ - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") # """ - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.zeros((1, 20)) @@ -61,12 +49,15 @@ def clamp(x): 0.40085957, 0.42394499, 0.44698984, 0.46999594]], dtype=jnp.float32) v = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - v.append(a.v.value[0, 0]) + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + v.append(a.v.get()[0, 0]) + # print(outs) + # print(y_seq) + outs = jnp.array(v) diff = np.abs(outs - y_seq) ## delta/error should be approximately zero diff --git a/tests/components/neurons/spiking/test_izhikevichCell.py b/tests/components/neurons/spiking/test_izhikevichCell.py index 165752d9..04ec6bcb 100644 --- a/tests/components/neurons/spiking/test_izhikevichCell.py +++ b/tests/components/neurons/spiking/test_izhikevichCell.py @@ -1,17 +1,11 @@ from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np - np.random.seed(42) -from ngclearn.components import IzhikevichCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess +from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell +from numpy.testing import assert_array_equal def test_izhikevichCell1(): @@ -26,28 +20,22 @@ def test_izhikevichCell1(): name="a", n_units=1, tau_m=1., resist_m=4., v_thr=30., key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + # """ + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") + # ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ - + # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + # """ ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) + # @Context.dynamicCommand + # def clamp(x): + # a.j.set(x) + + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]], dtype=jnp.float32) @@ -55,16 +43,16 @@ def clamp(x): y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for ts in range(x_seq.shape[1]): - x_t = x_seq[:, ts:ts+1] ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) - print(a.v.value) + x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t + clamp(x_t) # ctx.clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) - print(outs) - #exit() + # print(outs) + # print(y_seq) + ## output should equal input assert_array_equal(outs, y_seq) From 058bf90a2705a186ab4fc12e01ac618fb08e3d02 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Thu, 6 Nov 2025 09:22:02 -0500 Subject: [PATCH 020/121] JaxProcess update --- ngclearn/utils/JaxProcessesMixin.py | 41 +++++++ ngclearn/utils/jaxProcess.py | 171 ---------------------------- 2 files changed, 41 insertions(+), 171 deletions(-) create mode 100644 ngclearn/utils/JaxProcessesMixin.py delete mode 100644 ngclearn/utils/jaxProcess.py diff --git a/ngclearn/utils/JaxProcessesMixin.py b/ngclearn/utils/JaxProcessesMixin.py new file mode 100644 index 00000000..ae1a655c --- /dev/null +++ b/ngclearn/utils/JaxProcessesMixin.py @@ -0,0 +1,41 @@ +from ngcsimlib import JointProcess, MethodProcess +from ngcsimlib.global_state import stateManager +import jax +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ngcsimlib._src.process.baseProcess import BaseProcess + +class JaxProcessesMixin: + def __init__(self: "BaseProcess"): + self._previous_result = None + self._previous_state = None + + @property + def previous_result(self): + return self._previous_result + + @property + def previous_state(self): + return self._previous_state + + def clear(self): + self._previous_result = None + self._previous_state = None + + + def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True): + state = current_state or stateManager.state + final_state, result = jax.lax.scan(self.run.compiled, state, inputs) + if save_state: + self._previous_state = final_state + if store_results: + self._previous_result = result + return final_state, result + + + +class JaxJointProcess(JointProcess, JaxProcessesMixin): + pass + +class JaxMethodProcess(MethodProcess, JaxProcessesMixin): + pass diff --git a/ngclearn/utils/jaxProcess.py b/ngclearn/utils/jaxProcess.py deleted file mode 100644 index 8c3de576..00000000 --- a/ngclearn/utils/jaxProcess.py +++ /dev/null @@ -1,171 +0,0 @@ -from ngcsimlib.compartment import Compartment -from ngcsimlib import MethodProcess -from jax.lax import scan as _scan -from ngcsimlib.logger import warn -from jax import numpy as jnp - - -class JaxProcess(MethodProcess): - """ - The JaxProcess is a subclass of the ngcsimlib Process class. The - functionality added by this subclass is the use of the jax scanner to run a - process quickly through the use of jax's JIT compiler. - """ - - def __init__(self, name): - super().__init__(name) - self._process_scan_method = None - self._monitoring = [] - - def _make_scanner(self): - arg_order = self.get_required_args() - - def _pure(current_state, x): - v = self.pure(current_state, - **{key: value for key, value in zip(arg_order, x)}) - return v, [v[m] for m in self._monitoring] - - return _pure - - def watch(self, compartment): - """ - Adds a compartment to the process to watch during a scan - - Args: - compartment: the compartment to watch - """ - if not isinstance(compartment, Compartment): - warn( - "Jax Process trying to watch a value that is not a compartment") - - self._monitoring.append(compartment.path) - self._process_scan_method = self._make_scanner() - - def clear_watch_list(self): - """ - Clears the watch list so no values are watched - """ - self._monitoring = [] - self._process_scan_method = self._make_scanner() - - def transition(self, transition_call): - """ - Appends to the base transition call to create pure method for use by its - scanner - - Args: - transition_call: the transition being passed into the default process - - Returns: - this JaxProcess instance for chaining - """ - super().transition(transition_call) - self._process_scan_method = self._make_scanner() - return self - - def scan(self, save_state=True, scan_length=None, **kwargs): - """ - There a quite a few ways to initialize the scan method for the - jaxProcess. To start the straight forward arguments is "save_state". - The save_state flag is simply there to note if the state - of the model should reflect the final state of the model after the scan - is complete. - - This scan method can also watch and report intermediate compartment - values defined through calling the JaxProcess.watch() method watching a - compartment means at the end of each process cycle record the value of - the compartment and then at the end a tuple of concatenated values will - be returned that correspond to each compartment the process is watching. - - Where there are options for the arguments is when defining the keyword - arguments for the process. The process will do its best to broadcast all - the inputs to the largest size, so they can be scanned over. This means - that is one is a (2, 3) and the other is a constant, it will broadcast - constant to a (2, 3). This does mean that every keyword value that is - passed to a method in the process will be the same size. This is a - limitation of the jax scanner as all the values have to be concatenated - into a single jax array to be passed into the scanner. The accepted - types for arguments, are lists, tuples, numpy arrays, jax arrays, ints, - and floats. If all the keyword arguments are passed as ints or floats - the scan_length flag must be set so the scanner knows how many - iterations to run. If any of the arguments are iterable it will - automatically assume that the leading axis is the number of iterations - to run. - - - Args: - save_state: A boolean flag to indicate if the model state should be saved - - scan_length: a value to be used to denote the number of iterations of the scanner if all keyword - arguments are passed as ints or floats - - **kwargs: the required keyword arguments for the process to run - - Returns: the final state of the model, the stacked output of the scan method - - """ - arg_order = list(self.get_required_args()) - - args = [] - max_axis = 1 - max_next_axis = 0 - - for kwarg in arg_order: - if kwarg not in kwargs.keys(): - warn("Missing kwarg in Process", self.name) - return - - kval = kwargs.get(kwarg, None) - if isinstance(kval, (float, int, list, tuple)): - val = jnp.array(kval) - else: - val = kval - - max_axis = max(max_axis, len(val.shape)) - if max_axis == len(val.shape): - max_next_axis = max(max_next_axis, val.shape[0]) - args.append(val) - - # Check axis && get max_next_axis - - if max_next_axis == 0: - if scan_length is None: - warn("scan_length must be defined if all keyword arguments are " - "constants") - return - elif scan_length > 0: - max_next_axis = scan_length - else: - warn("scan_length must be greater than 0") - return - - for axis in range(max_axis): - current_axis = max_next_axis - max_next_axis = 0 - new_args = [] - for a in args: - if len(a.shape) >= axis + 1: - if a.shape[axis] == current_axis: - new_args.append(a) - else: - warn("Keyword arguments must all be able to be " - "broadcasted to the largest shape") - return - else: - new_args.append(jnp.zeros(list(a.shape) + [current_axis], - dtype=a.dtype) + a.reshape( - *a.shape, 1)) - - if len(a.shape) > axis + 1: - max_next_axis = max(max_next_axis, a.shape[axis + 1]) - - args = new_args - - args = jnp.array(args).transpose( - [1, 0] + [i for i in range(2, max_axis + 1)]) - state, stacked = _scan(self._process_scan_method, - init=self.get_required_state( - include_special_compartments=True), xs=args) - if save_state: - self.updated_modified_state(state) - return state, stacked From b86ae3da95421aac5d2805a7dfb43515bcd49572 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 12:10:35 -0500 Subject: [PATCH 021/121] cleaned up dunder repr method, moved to JaxComponent parent; fixed __init__ pointer to tensorstats --- ngclearn/components/jaxComponent.py | 14 ++++++++++++++ ngclearn/components/neurons/spiking/IFCell.py | 15 --------------- ngclearn/components/neurons/spiking/LIFCell.py | 15 --------------- ngclearn/components/neurons/spiking/RAFCell.py | 16 ---------------- ngclearn/components/neurons/spiking/WTASCell.py | 16 ---------------- ngclearn/components/neurons/spiking/adExCell.py | 16 ---------------- .../neurons/spiking/fitzhughNagumoCell.py | 16 ---------------- .../neurons/spiking/hodgkinHuxleyCell.py | 16 ---------------- .../components/neurons/spiking/izhikevichCell.py | 16 ---------------- .../components/neurons/spiking/quadLIFCell.py | 16 ---------------- ngclearn/utils/__init__.py | 1 + 11 files changed, 15 insertions(+), 142 deletions(-) diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 858a09c3..56247900 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -6,6 +6,7 @@ from jax import random from ngcsimlib.compartment import Compartment from ngcsimlib import Component +from ngclearn.utils import tensorstats class JaxComponent(Component): @@ -57,3 +58,16 @@ def load(self, directory: str): if d is not None: comp.set(d) + def __repr__(self): + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines \ No newline at end of file diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index 1e7f26ba..cb94827c 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -1,6 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, nn, Array, jit -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 @@ -231,17 +230,3 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 30435271..c0d049cf 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -1,6 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, nn, Array -from ngclearn.utils import tensorstats from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator, @@ -268,20 +267,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index b478c2dc..4b9f25dd 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ @@ -214,17 +212,3 @@ def help(cls): ## component help function "tau_w * dw/dt = w * dampen_factor - v * omega + j", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 835e3d03..6ae97097 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -1,8 +1,6 @@ from jax import numpy as jnp, random, jit, nn from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn @@ -159,20 +157,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index af34161b..0b7b6792 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 @@ -211,20 +209,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py index 64cb9f1d..d666a2bf 100755 --- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py +++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ @@ -214,20 +212,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py index 76e2e5ef..3ee00ca5 100644 --- a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py +++ b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4 @@ -273,20 +271,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/izhikevichCell.py b/ngclearn/components/neurons/spiking/izhikevichCell.py index 38fbab4c..07d89fc0 100755 --- a/ngclearn/components/neurons/spiking/izhikevichCell.py +++ b/ngclearn/components/neurons/spiking/izhikevichCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 @@ -229,20 +227,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index f724bcc0..af39434b 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -1,7 +1,5 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn, Array -from functools import partial -from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ @@ -241,20 +239,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/utils/__init__.py b/ngclearn/utils/__init__.py index 7bba010f..1d8c114e 100644 --- a/ngclearn/utils/__init__.py +++ b/ngclearn/utils/__init__.py @@ -1,3 +1,4 @@ from .distribution_generator import DistributionGenerator from .JaxProcessesMixin import JaxJointProcess as JointProcess, JaxMethodProcess as MethodProcess +from .model_utils import tensorstats From 55756d0508bcd76edc48865d1fa892a19c463cac Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 13:56:19 -0500 Subject: [PATCH 022/121] refactored alpha and exp-synapses, tests passed; minor edit to __init__ for synapses --- ngclearn/components/synapses/__init__.py | 62 +++++----- ngclearn/components/synapses/alphaSynapse.py | 116 ++++++++---------- ngclearn/components/synapses/denseSynapse.py | 3 +- .../components/synapses/exponentialSynapse.py | 111 ++++++++--------- .../synapses/test_exponentialSynapse.py | 25 ++-- 5 files changed, 144 insertions(+), 173 deletions(-) diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 2c21c231..954fedf8 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -3,36 +3,36 @@ ## short-term plasticity components -from .STPDenseSynapse import STPDenseSynapse +# from .STPDenseSynapse import STPDenseSynapse from .exponentialSynapse import ExponentialSynapse -from .doubleExpSynapse import DoupleExpSynapse +# from .doubleExpSynapse import DoupleExpSynapse from .alphaSynapse import AlphaSynapse - -## dense synaptic components -from .hebbian.hebbianSynapse import HebbianSynapse -from .hebbian.traceSTDPSynapse import TraceSTDPSynapse -from .hebbian.expSTDPSynapse import ExpSTDPSynapse -from .hebbian.eventSTDPSynapse import EventSTDPSynapse -from .hebbian.BCMSynapse import BCMSynapse - - -## conv/deconv synaptic components -from .convolution.convSynapse import ConvSynapse -from .convolution.staticConvSynapse import StaticConvSynapse -from .convolution.hebbianConvSynapse import HebbianConvSynapse -from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -from .convolution.deconvSynapse import DeconvSynapse -from .convolution.staticDeconvSynapse import StaticDeconvSynapse -from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse - - -## modulated synaptic components -from .modulated.MSTDPETSynapse import MSTDPETSynapse -from .modulated.REINFORCESynapse import REINFORCESynapse - -## patched synaptic components -from .patched.patchedSynapse import PatchedSynapse -from .patched.staticPatchedSynapse import StaticPatchedSynapse -from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse - +# +# ## dense synaptic components +# from .hebbian.hebbianSynapse import HebbianSynapse +# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse +# from .hebbian.expSTDPSynapse import ExpSTDPSynapse +# from .hebbian.eventSTDPSynapse import EventSTDPSynapse +# from .hebbian.BCMSynapse import BCMSynapse +# +# +# ## conv/deconv synaptic components +# from .convolution.convSynapse import ConvSynapse +# from .convolution.staticConvSynapse import StaticConvSynapse +# from .convolution.hebbianConvSynapse import HebbianConvSynapse +# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +# from .convolution.deconvSynapse import DeconvSynapse +# from .convolution.staticDeconvSynapse import StaticDeconvSynapse +# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse +# +# +# ## modulated synaptic components +# from .modulated.MSTDPETSynapse import MSTDPETSynapse +# from .modulated.REINFORCESynapse import REINFORCESynapse +# +# ## patched synaptic components +# from .patched.patchedSynapse import PatchedSynapse +# from .patched.staticPatchedSynapse import StaticPatchedSynapse +# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse +# diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py index cf5f9543..d4a2e3c6 100644 --- a/ngclearn/components/synapses/alphaSynapse.py +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -1,12 +1,12 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment - +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils import tensorstats from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info + from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable """ @@ -64,8 +64,8 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable # Define Functions def __init__( - self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., - is_nonplastic=True, **kwargs + self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., + p_conn=1., is_nonplastic=True, **kwargs ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) ## dynamic synapse meta-parameters @@ -82,55 +82,55 @@ def __init__( self.g_syn = Compartment(postVals) ## conductance variable self.h_syn = Compartment(postVals) ## intermediate conductance variable if is_nonplastic: - self.weights.set(self.weights.value * 0 + 1.) + self.weights.set(self.weights.get() * 0 + 1.) - @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"]) - @staticmethod - def advance_state( - dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v - ): - s = inputs + @compilable + def advance_state(self, t, dt): + s = self.inputs.get() ## advance conductance variable(s) - _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) - dhsyn_dt = -h_syn/tau_decay + (_out * g_syn_bar) * (1./dt) - h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h + _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron) + dhsyn_dt = -self.h_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt) + h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h - dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay - g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g + dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay + g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g ## compute derive electrical current variable - i_syn = -g_syn * Rscale - if syn_rest is not None: - i_syn = -(g_syn * Rscale) * (v - syn_rest) - outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases - return outputs, i_syn, g_syn, h_syn - - @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - i_syn = postVals - g_syn = postVals - h_syn = postVals - v = postVals - return inputs, outputs, i_syn, g_syn, h_syn, v - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + i_syn = -g_syn * self.resist_scale + if self.syn_rest is not None: + i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest) + outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases + + self.outputs.set(outputs) + self.i_syn.set(i_syn) + self.g_syn.set(g_syn) + self.h_syn.set(h_syn) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.i_syn.set(postVals) + self.g_syn.set(postVals) + self.h_syn.set(postVals) + self.v.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + # else: + # jnp.savez(file_name, weights=self.weights.value) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -170,17 +170,3 @@ def help(cls): ## component help function "dgsyn_dt = -g_syn/tau_decay + h_syn", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 91a4fda3..76a0778f 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -41,8 +41,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable # Define Functions def __init__( - self, name, shape, weight_init=None, bias_init=None, resist_scale=1., - p_conn=1., batch_size=1, **kwargs + self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs ): super().__init__(name, **kwargs) diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py index a873baf9..7c7e4a5d 100644 --- a/ngclearn/components/synapses/exponentialSynapse.py +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -1,12 +1,12 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment - +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils import tensorstats from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info + from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable """ @@ -63,8 +63,8 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable # Define Functions def __init__( - self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., - is_nonplastic=True, **kwargs + self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., + p_conn=1., is_nonplastic=True, **kwargs ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) ## dynamic synapse meta-parameters @@ -80,50 +80,51 @@ def __init__( self.i_syn = Compartment(postVals) ## electrical current output self.g_syn = Compartment(postVals) ## conductance variable if is_nonplastic: - self.weights.set(self.weights.value * 0 + 1.) + self.weights.set(self.weights.get() * 0 + 1.) - @transition(output_compartments=["outputs", "i_syn", "g_syn"]) - @staticmethod - def advance_state( - dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v - ): - s = inputs + # @transition(output_compartments=["outputs", "i_syn", "g_syn"]) + # @staticmethod + @compilable + def advance_state(self, t, dt): #dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v + s = self.inputs.get() ## advance conductance variable - _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) - dgsyn_dt = -g_syn/tau_decay + (_out * g_syn_bar) * (1./dt) - g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance + _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron) + dgsyn_dt = -self.g_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt) + g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance ## compute derive electrical current variable - i_syn = -g_syn * Rscale - if syn_rest is not None: - i_syn = -(g_syn * Rscale) * (v - syn_rest) - outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases - return outputs, i_syn, g_syn - - @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "v"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - i_syn = postVals - g_syn = postVals - v = postVals - return inputs, outputs, i_syn, g_syn, v - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + i_syn = -g_syn * self.resist_scale + if self.syn_rest is not None: + i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest) + outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases + + self.outputs.set(outputs) + self.i_syn.set(i_syn) + self.g_syn.set(g_syn) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.i_syn.set(postVals) + self.g_syn.set(postVals) + self.v.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + # else: + # jnp.savez(file_name, weights=self.weights.value) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -162,17 +163,3 @@ def help(cls): ## component help function "dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_decay ", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/tests/components/synapses/test_exponentialSynapse.py b/tests/components/synapses/test_exponentialSynapse.py index 83ad19ee..8fdcd732 100644 --- a/tests/components/synapses/test_exponentialSynapse.py +++ b/tests/components/synapses/test_exponentialSynapse.py @@ -1,11 +1,12 @@ from jax import numpy as jnp, random, jit +from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import ExponentialSynapse -from ngcsimlib.compilers.process import Process -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.exponentialSynapse import ExponentialSynapse +from numpy.testing import assert_array_equal def test_exponentialSynapse1(): name = "expsyn_ctx" @@ -23,14 +24,11 @@ def test_exponentialSynapse1(): key=subkeys[0] ) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") sp_train = jnp.array([1., 0., 1.], dtype=jnp.float32) post_syn_neuron_volt = jnp.ones((1, 1)) * -65. ## post-syn neuron is at rest @@ -38,15 +36,16 @@ def test_exponentialSynapse1(): outs_truth = jnp.array([[156., 78., 195.]]) outs = [] - ctx.reset() + reset_process.run() # ctx.reset() for t in range(3): in_pulse = jnp.expand_dims(sp_train[t], axis=0) a.inputs.set(in_pulse) a.v.set(post_syn_neuron_volt) - ctx.run(t=t * dt, dt=dt) - #print("g: ",a.g_syn.value) - #print("i: ", a.i_syn.value) - outs.append(a.outputs.value) + advance_process.run(t=t * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + # print("in: ", a.inputs.get()) + # print("g: ",a.g_syn.get()) + # print("i: ", a.i_syn.get()) + outs.append(a.outputs.get()) outs = jnp.concatenate(outs, axis=1) #print(outs) From 0d1a35fd90fe9f824326e6738527a3751ac3b56b Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 14:33:36 -0500 Subject: [PATCH 023/121] refactored short-term syn, tests passed - including stp-dense-syn and minor cleanup/edit to synapse __init__ --- .../components/synapses/STPDenseSynapse.py | 130 ++++++++---------- ngclearn/components/synapses/__init__.py | 4 +- ngclearn/components/synapses/alphaSynapse.py | 2 - .../components/synapses/doubleExpSynapse.py | 115 +++++++--------- .../components/synapses/exponentialSynapse.py | 6 +- .../synapses/test_STPDenseSynapse.py | 35 ++--- .../synapses/test_exponentialSynapse.py | 1 - 7 files changed, 122 insertions(+), 171 deletions(-) diff --git a/ngclearn/components/synapses/STPDenseSynapse.py b/ngclearn/components/synapses/STPDenseSynapse.py index 4fc1a81b..c523ec74 100755 --- a/ngclearn/components/synapses/STPDenseSynapse.py +++ b/ngclearn/components/synapses/STPDenseSynapse.py @@ -1,12 +1,10 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment - from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info + from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable """ @@ -56,10 +54,10 @@ class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable resources_int: initialization kernel for synaptic resources matrix """ - # Define Functions - def __init__(self, name, shape, weight_init=None, bias_init=None, - resist_scale=1., p_conn=1., tau_f=750., tau_d=50., - resources_init=None, **kwargs): + def __init__( + self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., tau_f=750., tau_d=50., + resources_init=None, **kwargs + ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) ## STP meta-parameters self.resources_init = resources_init @@ -67,11 +65,11 @@ def __init__(self, name, shape, weight_init=None, bias_init=None, self.tau_d = tau_d ## Set up short-term plasticity / dynamic synapse compartment values - tmp_key, *subkeys = random.split(self.key.value, 4) + tmp_key, *subkeys = random.split(self.key.get(), 4) preVals = jnp.zeros((self.batch_size, shape[0])) self.u = Compartment(preVals) ## release prob variables self.x = Compartment(preVals + 1) ## resource availability variables - self.Wdyn = Compartment(self.weights.value * 0) ## dynamic synapse values + self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values if self.resources_init is None: info(self.name, "is using default resources value initializer!") self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15 @@ -79,57 +77,59 @@ def __init__(self, name, shape, weight_init=None, bias_init=None, initialize_params(subkeys[2], self.resources_init, shape) ) ## matrix U - synaptic resources matrix - @transition(output_compartments=["outputs", "u", "x", "Wdyn"]) - @staticmethod - def advance_state( - tau_f, tau_d, Rscale, inputs, weights, biases, resources, u, x, Wdyn - ): - s = inputs + @compilable + def advance_state(self, t, dt): + s = self.inputs.get() ## compute short-term facilitation #u = u - u * (1./tau_f) + (resources * (1. - u)) * s - if tau_f > 0.: ## compute short-term facilitation - u = u - u * (1./tau_f) + (resources * (1. - u)) * s + if self.tau_f > 0.: ## compute short-term facilitation + u = self.u.get() - self.u.get() * (1./self.tau_f) + (self.resources.get() * (1. - self.u.get())) * s else: - u = resources ## disabling STF yields fixed resource u variables + u = self.resources.get() ## disabling STF yields fixed resource u variables ## compute dynamic synaptic values/conductances - Wdyn = (weights * u * x) * s + Wdyn * (1. - s) ## OR: -W/tau_w + W * u * x - if tau_d > 0.: - ## compute short-term depression - x = x + (1. - x) * (1./tau_d) - u * x * s - outputs = jnp.matmul(inputs, Wdyn * Rscale) + biases - return outputs, u, x, Wdyn - - @transition(output_compartments=["inputs", "outputs", "u", "x", "Wdyn"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - u = preVals - x = preVals + 1 - Wdyn = jnp.zeros(shape) - return inputs, outputs, u, x, Wdyn - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, - weights=self.weights.value, - biases=self.biases.value, - resources=self.resources.value) - else: - jnp.savez(file_name, - weights=self.weights.value, - resources=self.resources.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - self.resources.set(data['resources']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + Wdyn = (self.weights.get() * u * self.x.get()) * s + self.Wdyn.get() * (1. - s) ## OR: -W/tau_w + W * u * x + ## compute short-term depression + x = self.x.get() + if self.tau_d > 0.: + x = x + (1. - x) * (1./self.tau_d) - u * x * s + ## else, do nothing with x (keep it pointing to current x compartment) + outputs = jnp.matmul(self.inputs.get(), Wdyn * self.resist_scale) + self.biases.get() + + self.outputs.set(outputs) + self.u.set(u) + self.x.set(x) + self.Wdyn.set(Wdyn) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.u.set(preVals) + self.x.set(preVals + 1) + self.Wdyn.set(jnp.zeros(self.shape.get())) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, + # weights=self.weights.value, + # biases=self.biases.value, + # resources=self.resources.value) + # else: + # jnp.savez(file_name, + # weights=self.weights.value, + # resources=self.resources.value) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # self.resources.set(data['resources']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -166,17 +166,3 @@ def help(cls): ## component help function "dW/dt = W_full * u * x * inputs", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 954fedf8..d646001e 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -3,9 +3,9 @@ ## short-term plasticity components -# from .STPDenseSynapse import STPDenseSynapse +from .STPDenseSynapse import STPDenseSynapse from .exponentialSynapse import ExponentialSynapse -# from .doubleExpSynapse import DoupleExpSynapse +from .doubleExpSynapse import DoupleExpSynapse from .alphaSynapse import AlphaSynapse # # ## dense synaptic components diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py index d4a2e3c6..8d639b4b 100644 --- a/ngclearn/components/synapses/alphaSynapse.py +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -1,6 +1,4 @@ from jax import random, numpy as jnp, jit -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils import tensorstats from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py index 86225a68..03135f8c 100644 --- a/ngclearn/components/synapses/doubleExpSynapse.py +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -1,12 +1,10 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment - from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info + from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable """ @@ -66,8 +64,8 @@ class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cabl # Define Functions def __init__( - self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., - is_nonplastic=True, **kwargs + self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, + resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs) ## dynamic synapse meta-parameters @@ -85,57 +83,58 @@ def __init__( self.g_syn = Compartment(postVals) ## conductance variable self.h_syn = Compartment(postVals) ## intermediate conductance variable if is_nonplastic: - self.weights.set(self.weights.value * 0 + 1.) + self.weights.set(self.weights.get() * 0 + 1.) - @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"]) - @staticmethod - def advance_state( - dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v - ): - s = inputs + @compilable + def advance_state(self, t, dt): #dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v + s = self.inputs.get() #A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay)) - A = 1. + A = 1. ## FIXME: scale factor to use? ## advance conductance variable(s) - _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron) - dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt) - h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h + _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron) + dhsyn_dt = (-self.h_syn.get()/self.tau_rise + + ((_out * self.g_syn_bar) * (1. / self.tau_rise - 1. / self.tau_decay) * A) * (1./dt)) + h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h - dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) - g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g + dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt) + g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g ## compute derive electrical current variable - i_syn = -g_syn * Rscale - if syn_rest is not None: - i_syn = -(g_syn * Rscale) * (v - syn_rest) + i_syn = -g_syn * self.resist_scale + if self.syn_rest is not None: + i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest) outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases - return outputs, i_syn, g_syn, h_syn - - @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - i_syn = postVals - g_syn = postVals - h_syn = postVals - v = postVals - return inputs, outputs, i_syn, g_syn, h_syn, v - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + + self.outputs.set(outputs) + self.i_syn.set(i_syn) + self.g_syn.set(g_syn) + self.h_syn.set(h_syn) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.i_syn.set(postVals) + self.g_syn.set(postVals) + self.h_syn.set(postVals) + self.v.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) + # else: + # jnp.savez(file_name, weights=self.weights.value) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -176,17 +175,3 @@ def help(cls): ## component help function "dgsyn_dt = -g_syn/tau_decay + h_syn", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py index 7c7e4a5d..d29ec9da 100644 --- a/ngclearn/components/synapses/exponentialSynapse.py +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -1,6 +1,4 @@ from jax import random, numpy as jnp, jit -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils import tensorstats from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info @@ -82,10 +80,8 @@ def __init__( if is_nonplastic: self.weights.set(self.weights.get() * 0 + 1.) - # @transition(output_compartments=["outputs", "i_syn", "g_syn"]) - # @staticmethod @compilable - def advance_state(self, t, dt): #dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v + def advance_state(self, t, dt): s = self.inputs.get() ## advance conductance variable _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into post-neuron) diff --git a/tests/components/synapses/test_STPDenseSynapse.py b/tests/components/synapses/test_STPDenseSynapse.py index 78ac2e12..d7780c02 100644 --- a/tests/components/synapses/test_STPDenseSynapse.py +++ b/tests/components/synapses/test_STPDenseSynapse.py @@ -2,15 +2,10 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import STPDenseSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.STPDenseSynapse import STPDenseSynapse def test_STPDenseSynapse1(): name = "stp_ctx" @@ -24,23 +19,12 @@ def test_STPDenseSynapse1(): name="a", shape=(1,1), resources_init=dist.constant(value=1.),key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ a.weights.set(jnp.ones((1, 1))) in_pulse = jnp.ones((1, 1)) * 0.425 @@ -49,16 +33,19 @@ def test_STPDenseSynapse1(): outs = [] Wdyn = [] - ctx.reset() + reset_process.run() # ctx.reset() for t in range(3): a.inputs.set(in_pulse) - ctx.run(t=t * dt, dt=dt) - outs.append(a.outputs.value) - Wdyn.append(a.Wdyn.value) + advance_process.run(t=t * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt) + outs.append(a.outputs.get()) + Wdyn.append(a.Wdyn.get()) outs = jnp.concatenate(outs, axis=1) Wdyn = jnp.concatenate(Wdyn, axis=1) # print(outs) + # print(outs_truth) + # print("...") # print(Wdyn) + # print(Wdyn_truth) np.testing.assert_allclose(outs, outs_truth, atol=1e-8) np.testing.assert_allclose(Wdyn, Wdyn_truth, atol=1e-8) diff --git a/tests/components/synapses/test_exponentialSynapse.py b/tests/components/synapses/test_exponentialSynapse.py index 8fdcd732..bcaf7a01 100644 --- a/tests/components/synapses/test_exponentialSynapse.py +++ b/tests/components/synapses/test_exponentialSynapse.py @@ -6,7 +6,6 @@ from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist from ngclearn.components.synapses.exponentialSynapse import ExponentialSynapse -from numpy.testing import assert_array_equal def test_exponentialSynapse1(): name = "expsyn_ctx" From d26b417b36996c125820f9f09b261b404fda0f53 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 15:02:18 -0500 Subject: [PATCH 024/121] refactored bcm-syn and test passed --- .../components/synapses/hebbian/BCMSynapse.py | 103 ++++++++---------- .../components/synapses/hebbian/__init__.py | 6 +- .../synapses/hebbian/test_BCMSynapse.py | 47 +++----- 3 files changed, 61 insertions(+), 95 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/BCMSynapse.py b/ngclearn/components/synapses/hebbian/BCMSynapse.py index 6b391335..e4f8ddc4 100755 --- a/ngclearn/components/synapses/hebbian/BCMSynapse.py +++ b/ngclearn/components/synapses/hebbian/BCMSynapse.py @@ -1,10 +1,8 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngclearn.components.synapses.denseSynapse import DenseSynapse class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable """ @@ -71,8 +69,7 @@ def __init__( self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0., weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs ): - super().__init__(name, shape, weight_init, None, resist_scale, p_conn, - batch_size=batch_size, **kwargs) + super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) ## Synapse and BCM hyper-parameters self.shape = shape ## shape of synaptic efficacy matrix @@ -90,48 +87,51 @@ def __init__( self.post = Compartment(postVals) ## post-synaptic statistic self.post_term = Compartment(postVals) self.theta = Compartment(postVals + self.theta0) ## synaptic modification thresholds - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) - @transition(output_compartments=["weights", "theta", "dWeights", "post_term"]) - @staticmethod - def evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights): + @compilable + def evolve(self, t, dt): #t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights): eps = 1e-7 - post_term = post * (post - theta) # post - theta - post_term = post_term * (1. / (theta + eps)) - dWeights = jnp.matmul(pre.T, post_term) - if w_bound > 0.: - dWeights = dWeights * (w_bound - jnp.abs(weights)) + post_term = self.post.get() * (self.post.get() - self.theta.get()) # post - theta + post_term = post_term * (1. / (self.theta.get() + eps)) + dWeights = jnp.matmul(self.pre.get().T, post_term) + if self.w_bound > 0.: + dWeights = dWeights * (self.w_bound - jnp.abs(self.weights.get())) ## update synaptic efficacies according to a leaky ODE - dWeights = -weights * w_decay + dWeights - _W = weights + dWeights * dt / tau_w + dWeights = -self.weights.get() * self.w_decay + dWeights + _W = self.weights.get() + dWeights * dt / self.tau_w ## update synaptic modification threshold as a leaky ODE - dtheta = jnp.mean(jnp.square(post), axis=0, keepdims=True) ## batch avg - theta = theta + (-theta + dtheta) * dt / tau_theta - return weights, theta, dWeights, post_term - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "post_term"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - pre = preVals - post = postVals - dWeights = jnp.zeros(shape) - post_term = postVals - return inputs, outputs, pre, post, dWeights, post_term - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - weights=self.weights.value, theta=self.theta.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - self.theta.set(data['theta']) + dtheta = jnp.mean(jnp.square(self.post.get()), axis=0, keepdims=True) ## batch avg + theta = self.theta.get() + (-self.theta.get() + dtheta) * dt / self.tau_theta + + #self.weights.set(weights) + self.theta.set(theta) + self.dWeights.set(dWeights) + self.post_term.set(post_term) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre.set(preVals) + self.post.set(postVals) + self.dWeights.set(jnp.zeros(self.shape.get())) + self.post_term.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # jnp.savez(file_name, + # weights=self.weights.value, theta=self.theta.value) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # self.theta.set(data['theta']) @classmethod def help(cls): ## component help function @@ -175,21 +175,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index f39d556f..572bd9e6 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,6 +1,6 @@ -from .hebbianSynapse import HebbianSynapse +#from .hebbianSynapse import HebbianSynapse from .traceSTDPSynapse import TraceSTDPSynapse -from .expSTDPSynapse import ExpSTDPSynapse -from .eventSTDPSynapse import EventSTDPSynapse +#from .expSTDPSynapse import ExpSTDPSynapse +#from .eventSTDPSynapse import EventSTDPSynapse from .BCMSynapse import BCMSynapse diff --git a/tests/components/synapses/hebbian/test_BCMSynapse.py b/tests/components/synapses/hebbian/test_BCMSynapse.py index 7597f549..12d28329 100644 --- a/tests/components/synapses/hebbian/test_BCMSynapse.py +++ b/tests/components/synapses/hebbian/test_BCMSynapse.py @@ -2,14 +2,11 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import BCMSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess +import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.hebbian.BCMSynapse import BCMSynapse +from numpy.testing import assert_array_equal def test_BCMSynapse1(): name = "bcm_stdp_ctx" @@ -23,42 +20,26 @@ def test_BCMSynapse1(): name="a", shape=(1,1), tau_w=40., tau_theta=20., key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") + evolve_process = (MethodProcess("evolve_process") >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - """ pre_value = jnp.ones((1, 1)) * 0.425 post_value = jnp.ones((1, 1)) * 1.55 truth = jnp.array([[-1.6798127]]) - ctx.reset() + reset_process.run() # ctx.reset() a.pre.set(pre_value) a.post.set(post_value) - ctx.run(t=1., dt=dt) - ctx.adapt(t=1., dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) - + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) -#test_BCMSynapse1() +test_BCMSynapse1() From 99b3c436215322dea64273761159cc83fe1c7703 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 18:54:51 -0500 Subject: [PATCH 025/121] refactored exp-stdp-syn and passed tests for exp-stdp-syn and trace-stdp-syn --- .../components/synapses/hebbian/__init__.py | 2 +- .../synapses/hebbian/expSTDPSynapse.py | 122 ++++++++---------- .../synapses/hebbian/traceSTDPSynapse.py | 12 +- .../synapses/hebbian/test_expSTDPSynapse.py | 61 ++++----- .../synapses/hebbian/test_traceSTDPSynapse.py | 60 ++++----- 5 files changed, 112 insertions(+), 145 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index 572bd9e6..a2247bc8 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,6 +1,6 @@ #from .hebbianSynapse import HebbianSynapse from .traceSTDPSynapse import TraceSTDPSynapse -#from .expSTDPSynapse import ExpSTDPSynapse +from .expSTDPSynapse import ExpSTDPSynapse #from .eventSTDPSynapse import EventSTDPSynapse from .BCMSynapse import BCMSynapse diff --git a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py index ff184b9c..2a44c3f9 100644 --- a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py @@ -1,10 +1,8 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngclearn.components.synapses.denseSynapse import DenseSynapse class ExpSTDPSynapse(DenseSynapse): """ @@ -61,16 +59,20 @@ class ExpSTDPSynapse(DenseSynapse): this to < 1. will result in a sparser synaptic structure w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1) + + tau_w: synaptic weight decay coefficient to apply to STDP update + + weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None) """ # Define Functions def __init__( self, name, shape, A_plus, A_minus, exp_beta, eta=1., pretrace_target=0., weight_init=None, resist_scale=1., - p_conn=1., w_bound=1., batch_size=1, **kwargs + p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs ): - super().__init__(name, shape, weight_init, None, resist_scale, - p_conn, batch_size=batch_size, **kwargs) + super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) + self.tau_w = tau_w ## Exp-STDP meta-parameters self.shape = shape ## shape of synaptic efficacy matrix self.eta = eta ## global learning rate governing plasticity @@ -81,6 +83,12 @@ def __init__( self.Rscale = resist_scale ## post-transformation scale factor self.w_bound = w_bound #1. ## soft weight constraint + if weight_mask is None: + self.weight_mask = jnp.ones((1, 1)) + else: + self.weight_mask = weight_mask + self.weights.set(self.weights.get() * self.weight_mask) + ## Compartment setup preVals = jnp.zeros((self.batch_size, shape[0])) postVals = jnp.zeros((self.batch_size, shape[1])) @@ -88,64 +96,61 @@ def __init__( self.postSpike = Compartment(postVals) self.preTrace = Compartment(preVals) self.postTrace = Compartment(postVals) - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity - @staticmethod - def _compute_update( - dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights - ): - pre = preSpike - x_pre = preTrace - post = postSpike - x_post = postTrace - W = weights - x_tar = preTrace_target + def _compute_update(self): # dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights + pre = self.preSpike.get() + x_pre = self.preTrace.get() + post = self.postSpike.get() + x_post = self.postTrace.get() + W = self.weights.get() + x_tar = self.preTrace_target ## equations 4 from Diehl and Cook - full exponential weight-dependent STDP ## calculate post-synaptic term - post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post) + post_term1 = jnp.exp(-self.exp_beta * W) * jnp.matmul(x_pre.T, post) x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form - post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T, - post) - dWpost = (post_term1 - post_term2) * Aplus + post_term2 = jnp.exp(-self.exp_beta * (self.w_bound - W)) * jnp.matmul(x_tar_vec.T, post) + dWpost = (post_term1 - post_term2) * self.Aplus ## calculate pre-synaptic term dWpre = 0. - if Aminus > 0.: - dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus + if self.Aminus > 0.: + dWpre = -jnp.exp(-self.exp_beta * W) * jnp.matmul(pre.T, x_post) * self.Aminus ## calc final weighted adjustment dW = (dWpost + dWpre) return dW - @transition(output_compartments=["weights", "dWeights"]) - @staticmethod - def evolve( - dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, - weights, eta - ): - dW = ExpSTDPSynapse._compute_update( - dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, - preSpike, postSpike, preTrace, postTrace, weights - ) + @compilable + def evolve(self): + dWeights = self._compute_update() + if self.tau_w > 0.: + decayTerm = self.weights.get() / self.tau_w + else: + decayTerm = 0. + ## do a gradient ascent update/shift - _W = weights + dW * eta + _W = self.weights.get() + (dWeights * self.eta) #- decayTerm ## enforce non-negativity eps = 0.01 - _W = jnp.clip(_W, eps, w_bound - eps) - return weights, dW - - @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - preSpike = preVals - postSpike = postVals - preTrace = preVals - postTrace = postVals - dWeights = jnp.zeros(shape) - return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights + _W = jnp.clip(_W, eps, self.w_bound - eps) + _W = jnp.where(self.weight_mask != 0., _W, 0.) + + self.weights.set(_W) + self.dWeights.set(dWeights) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.preSpike.set(preVals) + self.postSpike.set(postVals) + self.preTrace.set(preVals) + self.postTrace.set(postVals) + self.dWeights.set(jnp.zeros(self.shape.get())) @classmethod def help(cls): ## component help function @@ -183,6 +188,7 @@ def help(cls): ## component help function "exp_beta": "Controls effect of exponential Hebbian shift / dependency (B)", "eta": "Global learning rate initial condition", "pretrace_target": "Pre-synaptic disconnecting/decay factor (x_tar)", + "weight_mask" : "Binary synaptic weight mask to apply to enforce a sparsity structure" } info = {cls.__name__: properties, "compartments": compartment_props, @@ -192,20 +198,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index 66d3137c..dd51ecf5 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -3,7 +3,6 @@ from ngcsimlib.parser import compilable from ngclearn.components.synapses.denseSynapse import DenseSynapse -from ngclearn.utils import tensorstats class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP @@ -56,12 +55,16 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP initialization to use resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + (Default: 1.), i.e., yields: out = ((W * resistance) * in) p_conn: probability of a connection existing (default: 1); setting this to < 1. will result in a sparser synaptic structure w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1) + + tau_w: synaptic weight decay coefficient to apply to STDP update + + weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None) """ # Define Functions @@ -69,8 +72,7 @@ def __init__( self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs ): - super().__init__(name, shape, weight_init, None, resist_scale, - p_conn, batch_size=batch_size, **kwargs) + super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) self.tau_w = tau_w self.mu = mu ## controls power-scaling of STDP rule @@ -84,7 +86,6 @@ def __init__( self.weight_mask = jnp.ones((1, 1)) else: self.weight_mask = weight_mask - self.weights.set(self.weights.get() * self.weight_mask) ## Compartment setup @@ -184,6 +185,7 @@ def help(cls): ## component help function "eta": "Global learning rate initial condition", "mu": "Power factor for STDP adjustment", "pretrace_target": "Pre-synaptic disconnecting/decay factor (x_tar)", + "weight_mask" : "Binary synaptic weight mask to apply to enforce a sparsity structure" } info = {cls.__name__: properties, "compartments": compartment_props, diff --git a/tests/components/synapses/hebbian/test_expSTDPSynapse.py b/tests/components/synapses/hebbian/test_expSTDPSynapse.py index 9765315d..679443a7 100644 --- a/tests/components/synapses/hebbian/test_expSTDPSynapse.py +++ b/tests/components/synapses/hebbian/test_expSTDPSynapse.py @@ -1,15 +1,13 @@ + from jax import numpy as jnp, random, jit from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import ExpSTDPSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess +import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse +from numpy.testing import assert_array_equal def test_expSTDPSynapse1(): name = "exp_stdp_ctx" @@ -20,33 +18,18 @@ def test_expSTDPSynapse1(): # ---- build a simple Poisson cell system ---- with Context(name) as ctx: a = ExpSTDPSynapse( - name="a", shape=(1,1), A_plus=1., A_minus=1., exp_beta=1.25, key=subkeys[0] + name="a", shape=(1,1), A_plus=1., A_minus=1., exp_beta=1.25, eta=0., key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - """ a.weights.set(jnp.ones((1, 1)) * 0.1) in_spike = jnp.ones((1, 1)) @@ -56,26 +39,30 @@ def test_expSTDPSynapse1(): ## check pre-synaptic STDP only truth = jnp.array([[1.1031212]]) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike * 0) a.preTrace.set(in_trace) a.postSpike.set(out_spike) a.postTrace.set(out_trace) - ctx.run(t=1., dt=dt) - ctx.adapt(t=1., dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + # print("W: ",a.weights.get()) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) truth = jnp.array([[-0.57362294]]) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike) a.preTrace.set(in_trace) a.postSpike.set(out_spike * 0) a.postTrace.set(out_trace) - ctx.run(t=1., dt=dt) - ctx.adapt(t=1., dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + # print("W: ", a.weights.get()) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) #test_expSTDPSynapse1() diff --git a/tests/components/synapses/hebbian/test_traceSTDPSynapse.py b/tests/components/synapses/hebbian/test_traceSTDPSynapse.py index 4e1e42de..41452f61 100644 --- a/tests/components/synapses/hebbian/test_traceSTDPSynapse.py +++ b/tests/components/synapses/hebbian/test_traceSTDPSynapse.py @@ -2,14 +2,11 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import TraceSTDPSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess +import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse +from numpy.testing import assert_array_equal def test_traceSTDPSynapse1(): name = "trace_stdp_ctx" @@ -20,33 +17,18 @@ def test_traceSTDPSynapse1(): # ---- build a simple Poisson cell system ---- with Context(name) as ctx: a = TraceSTDPSynapse( - name="a", shape=(1,1), A_plus=1., A_minus=1., key=subkeys[0] + name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0., key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - """ a.weights.set(jnp.ones((1, 1)) * 0.1) in_spike = jnp.ones((1, 1)) @@ -56,25 +38,29 @@ def test_traceSTDPSynapse1(): ## check pre-synaptic STDP only truth = jnp.array([[1.25]]) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike * 0) a.preTrace.set(in_trace) a.postSpike.set(out_spike) a.postTrace.set(out_trace) - ctx.run(t=1., dt=dt) - ctx.adapt(t=1., dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + # print("W: ", a.weights.get()) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) truth = jnp.array([[-0.65]]) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike) a.preTrace.set(in_trace) a.postSpike.set(out_spike * 0) a.postTrace.set(out_trace) - ctx.run(t=1., dt=dt) - ctx.adapt(t=1., dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + # print("W: ", a.weights.get()) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) #test_traceSTDPSynapse1() From ebdea3efa803c65320cd7c2cd1f170c3d54db63c Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 6 Nov 2025 19:37:36 -0500 Subject: [PATCH 026/121] refactored event-stdp-syn and test passed --- .../components/synapses/hebbian/__init__.py | 2 +- .../synapses/hebbian/eventSTDPSynapse.py | 100 +++++++----------- .../synapses/hebbian/test_eventSTDPSynapse.py | 58 ++++------ 3 files changed, 61 insertions(+), 99 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index a2247bc8..99ebec99 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,6 +1,6 @@ #from .hebbianSynapse import HebbianSynapse from .traceSTDPSynapse import TraceSTDPSynapse from .expSTDPSynapse import ExpSTDPSynapse -#from .eventSTDPSynapse import EventSTDPSynapse +from .eventSTDPSynapse import EventSTDPSynapse from .BCMSynapse import BCMSynapse diff --git a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py index fde8758a..b92522fe 100755 --- a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py @@ -1,10 +1,8 @@ -from jax import numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component +from jax import random, numpy as jnp, jit from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable -from ngclearn.components.synapses import DenseSynapse -from ngclearn.utils import tensorstats +from ngclearn.components.synapses.denseSynapse import DenseSynapse class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP """ @@ -57,11 +55,11 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP """ # Define Functions - def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., - presyn_win_len=2., w_bound=1., weight_init=None, resist_scale=1., - p_conn=1., batch_size=1, **kwargs): - super().__init__(name, shape, weight_init, None, resist_scale, p_conn, - batch_size=batch_size, **kwargs) + def __init__( + self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., presyn_win_len=2., w_bound=1., weight_init=None, + resist_scale=1., p_conn=1., batch_size=1, **kwargs + ): + super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) ## Synaptic hyper-parameters self.eta = eta ## global learning rate governing plasticity @@ -78,53 +76,47 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., postVals = jnp.zeros((self.batch_size, shape[1])) self.pre_tols = Compartment(preVals) self.postSpike = Compartment(postVals) - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity - @staticmethod - def _compute_update( - t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights - ): ## synaptic adjustment calculation co-routine + def _compute_update(self, t, dt): ## synaptic adjustment calculation co-routine ## check if a spike occurred in window of (t - presyn_win_len, t] - m = (pre_tols > 0.) * 1. ## ignore default value of tols = 0 ms - if presyn_win_len > 0.: - lbound = ((t - presyn_win_len) < pre_tols) * 1. + m = (self.pre_tols.get() > 0.) * 1. ## ignore default value of tols = 0 ms + if self.presyn_win_len > 0.: + lbound = ((t - self.presyn_win_len) < self.pre_tols.get()) * 1. preSpike = lbound * m else: - check_spike = (pre_tols == t) * 1. + check_spike = (self.pre_tols.get() == t) * 1. preSpike = check_spike * m ## this implements a generalization of the rule in eqn 18 of the paper - pos_shift = w_bound - (weights * (1. + lmbda)) - pos_shift = pos_shift * Aplus - neg_shift = -weights * (1. + lmbda) - neg_shift = neg_shift * Aminus + pos_shift = self.w_bound - (self.weights.get() * (1. + self.lmbda)) + pos_shift = pos_shift * self.Aplus + neg_shift = -self.weights.get() * (1. + self.lmbda) + neg_shift = neg_shift * self.Aminus dW = jnp.where(preSpike.T, pos_shift, neg_shift) # at pre-spikes => LTP, else decay - dW = (dW * postSpike) ## gate to make sure only post-spikes trigger updates + dW = (dW * self.postSpike.get()) ## gate to make sure only post-spikes trigger updates return dW - @transition(output_compartments=["weights", "dWeights"]) - @staticmethod - def evolve( - t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights, eta - ): - dWeights = EventSTDPSynapse._compute_update( - t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights - ) - weights = weights + dWeights * eta # * (1. - w) * eta - weights = jnp.clip(weights, 0.01, w_bound) ## Note: this step not in source paper - return weights, dWeights - - @transition(output_compartments=["inputs", "outputs", "pre_tols", "postSpike", "dWeights"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record - postSpike = postVals - dWeights = jnp.zeros(shape) - return inputs, outputs, pre_tols, postSpike, dWeights + @compilable + def evolve(self, t, dt): + dWeights = self._compute_update(t, dt) + weights = self.weights.get() + dWeights * self.eta # * (1. - w) * eta + weights = jnp.clip(weights, 0.01, self.w_bound) ## Note: this step not in source paper + + self.weights.set(weights) + self.dWeights.set(dWeights) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre_tols.set(preVals) ## pre-synaptic time-of-last-spike(s) record + self.postSpike.set(postVals) + self.dWeights.set(jnp.zeros(self.shape.get())) @classmethod def help(cls): ## component help function @@ -166,20 +158,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/tests/components/synapses/hebbian/test_eventSTDPSynapse.py b/tests/components/synapses/hebbian/test_eventSTDPSynapse.py index b51c16de..48e3dca5 100644 --- a/tests/components/synapses/hebbian/test_eventSTDPSynapse.py +++ b/tests/components/synapses/hebbian/test_eventSTDPSynapse.py @@ -2,14 +2,11 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import EventSTDPSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess +import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse +from numpy.testing import assert_array_equal def test_eventSTDPSynapse1(): name = "event_stdp_ctx" @@ -24,46 +21,32 @@ def test_eventSTDPSynapse1(): name="a", shape=(1,1), eta=0., presyn_win_len=2., key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - """ a.weights.set(jnp.ones((1, 1)) * 0.1) t = 12. ## fake out current time - ## Case 1: outside of pre-syn time window + ## Case 1: outside pre-syn time window input_tols = jnp.ones((1, 1,)) * 9. out_spike = jnp.ones((1, 1)) ## check pre-synaptic STDP only truth = jnp.array([[-0.101]]) - ctx.reset() + reset_process.run() # ctx.reset() a.pre_tols.set(input_tols) a.postSpike.set(out_spike) - ctx.run(t=t, dt=dt) - ctx.adapt(t=t, dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt) + evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) ## Case 2: within pre-syn time window input_tols = jnp.ones((1, 1,)) * 11. @@ -71,13 +54,14 @@ def test_eventSTDPSynapse1(): ## check pre-synaptic STDP only truth = jnp.array([[0.899]]) - ctx.reset() + reset_process.run() # ctx.reset() a.pre_tols.set(input_tols) a.postSpike.set(out_spike) - ctx.run(t=t, dt=dt) - ctx.adapt(t=t, dt=dt) - #print(a.dWeights.value) - assert_array_equal(a.dWeights.value, truth) + advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt) + evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt) + # print(a.dWeights.get()) + # print(truth) + assert_array_equal(a.dWeights.get(), truth) #test_eventSTDPSynapse1() From 94477b8f3391a110c75e089524bd227b9418ac66 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 7 Nov 2025 13:02:11 -0500 Subject: [PATCH 027/121] refactored mstdpet-syn and test passed --- .../synapses/modulated/MSTDPETSynapse.py | 104 +++++++----------- .../synapses/modulated/test_MSTDPETSynapse.py | 60 ++++------ 2 files changed, 59 insertions(+), 105 deletions(-) diff --git a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py index 6e5dd8c4..45c99a7b 100755 --- a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py +++ b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py @@ -1,12 +1,9 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info + from ngclearn.components.synapses.hebbian import TraceSTDPSynapse -from ngclearn.utils import tensorstats class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligility traces """ @@ -72,78 +69,69 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit p_conn: probability of a connection existing (default: 1.); setting this to < 1. will result in a sparser synaptic structure + + w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1) """ - # Define Functions def __init__( self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs ): - super().__init__( + super().__init__( # call to parent trace-stdp component name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init, resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs ) self.w_eps = 0. self.tau_w = tau_w ## MSTDP/MSTDP-ET meta-parameters - self.tau_elg = tau_elg - self.elg_decay = elg_decay + self.tau_elg = tau_elg ## time constant for eligibility trace + self.elg_decay = elg_decay ## decay factor eligibility trace ## MSTDP/MSTDP-ET compartments self.modulator = Compartment(jnp.zeros((self.batch_size, 1))) self.eligibility = Compartment(jnp.zeros(shape)) self.outmask = Compartment(jnp.zeros((1, shape[1]))) - @transition(output_compartments=["weights", "dWeights", "eligibility"]) - @staticmethod - def evolve( - dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike, - preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask - ): - # dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update - # dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights - # ) + @compilable + def evolve(self, dt, t): + # dW_dt = self._compute_update() # dWeights = dW_dt ## can think of this as eligibility at time t - if tau_elg > 0.: ## perform dynamics of M-STDP-ET - eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dWeights/tau_elg + if self.tau_elg > 0.: ## perform dynamics of M-STDP-ET + eligibility = self.eligibility.get() * jnp.exp(-dt / self.tau_elg) * self.elg_decay + self.dWeights.get()/self.tau_elg else: ## otherwise, just do M-STDP - eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing + eligibility = self.dWeights.get() ## dynamics of M-STDP had no eligibility tracing ## do a gradient ascent update/shift decayTerm = 0. - if tau_w > 0.: - decayTerm = weights * (1. / tau_w) - weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update + if self.tau_w > 0.: + decayTerm = self.weights.get() * (1. / self.tau_w) + ## do modulated update + weights = self.weights.get() + (eligibility * self.modulator.get() * self.eta) * self.outmask.get() - decayTerm - dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update - dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights - ) + dW_dt = self._compute_update() ## apply a Hebbian/STDP rule to obtain a non-modulated update dWeights = dW_dt ## can think of this as eligibility at time t #w_eps = 0.01 - weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound)) - - return weights, dWeights, eligibility - - @transition( - output_compartments=[ - "inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask" - ] - ) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - synVals = jnp.zeros(shape) - inputs = preVals - outputs = postVals - preSpike = preVals - postSpike = postVals - preTrace = preVals - postTrace = postVals - dWeights = synVals - eligibility = synVals - outmask = postVals + 1. - return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask + weights = jnp.clip(weights, self.w_eps, self.w_bound - self.w_eps) # jnp.abs(w_bound)) + self.weights.set(weights) + self.dWeights.set(dWeights) + self.eligibility.set(eligibility) + + @compilable + def reset(self): + preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0])) + postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1])) + synVals = jnp.zeros(self.shape.get()) + + if not self.inputs.targeted: + self.inputs.set(preVals) + self.outputs.set(postVals) + self.preSpike.set(preVals) + self.postSpike.set(postVals) + self.preTrace.set(preVals) + self.postTrace.set(postVals) + self.dWeights.set(synVals) + self.eligibility.set(synVals) + self.outmask.set(postVals + 1.) @classmethod def help(cls): ## component help function @@ -195,17 +183,3 @@ def help(cls): ## component help function "dW^{stdp}_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/tests/components/synapses/modulated/test_MSTDPETSynapse.py b/tests/components/synapses/modulated/test_MSTDPETSynapse.py index e1c7ce36..748f5848 100644 --- a/tests/components/synapses/modulated/test_MSTDPETSynapse.py +++ b/tests/components/synapses/modulated/test_MSTDPETSynapse.py @@ -2,15 +2,11 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import MSTDPETSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist +from ngclearn.components.synapses.modulated.MSTDPETSynapse import MSTDPETSynapse +from numpy.testing import assert_array_equal def test_MSTDPETSynapse1(): name = "mstdpet_ctx" @@ -24,30 +20,14 @@ def test_MSTDPETSynapse1(): name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0.1, key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) + + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") - - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - """ a.weights.set(jnp.ones((1, 1)) * 0.75) @@ -59,28 +39,28 @@ def test_MSTDPETSynapse1(): r_pos = jnp.ones((1, 1)) #print(a.weights.value) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike * 0) a.preTrace.set(in_trace) a.postSpike.set(out_spike) a.postTrace.set(out_trace) a.modulator.set(r_pos) - ctx.run(t=1. * dt, dt=dt) - ctx.adapt(t=1. * dt, dt=dt) - ctx.adapt(t=1. * dt, dt=dt) - #print(a.weights.value) - assert_array_equal(a.weights.value, jnp.array([[0.875]])) + advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt) + #print(a.weights.get()) + assert_array_equal(a.weights.get(), jnp.array([[0.875]])) - ctx.reset() + reset_process.run() # ctx.reset() a.preSpike.set(in_spike * 0) a.preTrace.set(in_trace) a.postSpike.set(out_spike) a.postTrace.set(out_trace) a.modulator.set(r_neg) - ctx.run(t=1. * dt, dt=dt) - ctx.adapt(t=1. * dt, dt=dt) - ctx.adapt(t=1. * dt, dt=dt) - #print(a.weights.value) - assert_array_equal(a.weights.value, jnp.array([[0.75]])) + advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt) + #print(a.weights.get()) + assert_array_equal(a.weights.get(), jnp.array([[0.75]])) #test_MSTDPETSynapse1() From f72db76e1f4894c2b35d94d7cf1f0904ce3ac547 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 7 Nov 2025 14:00:25 -0500 Subject: [PATCH 028/121] refactored stdp-conv-syn/conv-syn and test passed --- .../synapses/convolution/convSynapse.py | 90 ++++++-------- .../convolution/traceSTDPConvSynapse.py | 115 +++++++++--------- .../convolution/test_traceSTDPConvSynapse.py | 63 ++++------ 3 files changed, 120 insertions(+), 148 deletions(-) diff --git a/ngclearn/components/synapses/convolution/convSynapse.py b/ngclearn/components/synapses/convolution/convSynapse.py index 12c5e674..9da1412f 100755 --- a/ngclearn/components/synapses/convolution/convSynapse.py +++ b/ngclearn/components/synapses/convolution/convSynapse.py @@ -1,15 +1,13 @@ from jax import random, numpy as jnp, jit -from ngclearn.components.jaxComponent import JaxComponent -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist from ngclearn.components.synapses.convolution.ngcconv import conv2d +from ngclearn.components.jaxComponent import JaxComponent + class ConvSynapse(JaxComponent): ## base-level convolutional cable """ A base convolutional synaptic cable. @@ -61,7 +59,7 @@ def __init__( self.shape = shape ## shape of synaptic filter tensor x_size, x_size = x_shape self.x_size = x_size - self.Rscale = resist_scale ## post-transformation scale factor + self.resist_scale = resist_scale ## post-transformation scale factor self.padding = padding self.stride = stride @@ -69,7 +67,7 @@ def __init__( k_size, k_size, n_in_chan, n_out_chan = shape self.pad_args = None if self.padding is not None and self.padding == "SAME": - if (x_size % stride == 0): + if x_size % stride == 0: pad_along_height = max(k_size - stride, 0) else: pad_along_height = max(k_size - (x_size % stride), 0) @@ -83,7 +81,7 @@ def __init__( self.pad_args = ((0, 0), (0, 0)) ######################### set up compartments ########################## - tmp_key, *subkeys = random.split(self.key.value, 4) + tmp_key, *subkeys = random.split(self.key.get(), 4) weights = dist.initialize_params(subkeys[0], filter_init, shape) ## filter tensor self.batch_size = batch_size # 1 ## Compartment setup and shape computation @@ -101,36 +99,38 @@ def __init__( dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0 ) - @transition(output_compartments=["outputs"]) - @staticmethod - def advance_state(Rscale, padding, stride, weights, biases, inputs): - _x = inputs - outputs = conv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases - return outputs - - @transition(output_compartments=["inputs", "outputs"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - return inputs, outputs - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, - biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + # @transition(output_compartments=["outputs"]) + # @staticmethod + @compilable + def advance_state(self): #Rscale, padding, stride, weights, biases, inputs): + _x = self.inputs.get() + ## FIXME: does resist_scale affect update rules? + outputs = conv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * self.resist_scale + self.biases.get() + self.outputs.set(outputs) + + # @transition(output_compartments=["inputs", "outputs"]) + # @staticmethod + @compilable + def reset(self): #in_shape, out_shape): + preVals = jnp.zeros(self.in_shape) + postVals = jnp.zeros(self.out_shape) + self.inputs.set(preVals) + self.outputs.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, weights=self.weights.get(), + # biases=self.biases.get()) + # else: + # jnp.savez(file_name, weights=self.weights.get()) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -163,17 +163,3 @@ def help(cls): ## component help function "dynamics": "outputs = [K @ inputs] * R + b", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py index 7fbb5021..0ee67009 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py @@ -1,13 +1,11 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - -from .convSynapse import ConvSynapse +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist + +from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse + from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding, _conv_valid_transpose_padding) from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv, @@ -93,7 +91,7 @@ def __init__( ######################### set up compartments ########################## ## Compartment setup and shape computation - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.dInputs = Compartment(jnp.zeros(self.in_shape)) self.preSpike = Compartment(jnp.zeros(self.in_shape)) self.preTrace = Compartment(jnp.zeros(self.in_shape)) @@ -108,72 +106,76 @@ def __init__( k_size, k_size, n_in_chan, n_out_chan = self.shape if padding == "SAME": self.antiPad = _conv_same_transpose_padding( - self.postSpike.value.shape[1], + self.postSpike.get().shape[1], self.x_size, k_size, stride) elif padding == "VALID": self.antiPad = _conv_valid_transpose_padding( - self.postSpike.value.shape[1], + self.postSpike.get().shape[1], self.x_size, k_size, stride) ######################################################################## def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): k_size, k_size, n_in_chan, n_out_chan = shape _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan)) - _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0 + _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0 _dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args) ## get filter update correction - dx = _dK.shape[0] - weights.value.shape[0] - dy = _dK.shape[1] - weights.value.shape[1] + dx = _dK.shape[0] - weights.get().shape[0] + dy = _dK.shape[1] - weights.get().shape[1] #self.delta_shape = (dx, dy) self.delta_shape = (max(dx, 0), max(dy, 0)) ## get input update correction - _dx = _calc_dX_conv(weights.value, _d, stride_size=stride, + _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride, anti_padding=pad_args) dx = (_dx.shape[1] - _x.shape[1]) dy = (_dx.shape[2] - _x.shape[2]) self.x_delta_shape = (dx, dy) - @staticmethod - def _compute_update( - pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace - ): + #@staticmethod + def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace ## Compute long-term potentiation to filters dW_ltp = calc_dK_conv( - preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args + self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, delta_shape=self.delta_shape, + stride_size=self.stride, padding=self.pad_args ) ## Compute long-term depression to filters dW_ltd = -calc_dK_conv( - preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args + self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape, + stride_size=self.stride, padding=self.pad_args ) dWeights = (dW_ltp + dW_ltd) return dWeights - @transition(output_compartments=["weights", "dWeights"]) - @staticmethod - def evolve( - pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace, - postSpike, postTrace, weights, eta - ): - dWeights = TraceSTDPConvSynapse._compute_update( - pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace - ) - if w_decay > 0.: ## apply synaptic decay - weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent + # @transition(output_compartments=["weights", "dWeights"]) + # @staticmethod + @compilable + def evolve(self): + # pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace, + # postSpike, postTrace, weights, eta + + dWeights = self._compute_update() + # dWeights = TraceSTDPConvSynapse._compute_update( + # pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace + # ) + if self.w_decay > 0.: ## apply synaptic decay + weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent else: - weights = weights + dWeights * eta ## conduct STDP-ascent + weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent ## Apply any enforced filter constraints - if w_bound > 0.: ## enforce non-negativity + if self.w_bound > 0.: ## enforce non-negativity eps = 0.01 # 0.001 - weights = jnp.clip(weights, eps, w_bound - eps) - return weights, dWeights - - @transition(output_compartments=["dInputs"]) - @staticmethod - def backtransmit( - x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights - ): ## action-backpropagating routine + weights = jnp.clip(weights, eps, self.w_bound - eps) + + self.weights.set(weights) + self.dWeights.set(dWeights) + + # @transition(output_compartments=["dInputs"]) + # @staticmethod + @compilable + def backtransmit(self): # x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights + ## action-backpropagating routine ## calc dInputs - adjustment w.r.t. input signal - k_size, k_size, n_in_chan, n_out_chan = shape + k_size, k_size, n_in_chan, n_out_chan = self.shape # antiPad = None # if padding == "SAME": # antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size, @@ -181,21 +183,22 @@ def backtransmit( # elif padding == "VALID": # antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size, # k_size, stride) - dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad) - return dInputs - - @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - preSpike = preVals - postSpike = postVals - preTrace = preVals - postTrace = postVals - return inputs, outputs, preSpike, postSpike, preTrace, postTrace + dInputs = calc_dX_conv( + self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, + anti_padding=self.antiPad + ) + self.dInputs.set(dInputs) + + @compilable + def reset(self): # in_shape, out_shape): + preVals = jnp.zeros(self.in_shape.get()) + postVals = jnp.zeros(self.out_shape.get()) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.preSpike.set(preVals) + self.postSpike.set(postVals) + self.preTrace.set(preVals) + self.postTrace.set(postVals) @classmethod def help(cls): ## component help function diff --git a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py index bf113760..8b179704 100644 --- a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py +++ b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py @@ -2,16 +2,12 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import TraceSTDPConvSynapse + +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist -from ngcsimlib.compilers import compile_command, wrap_command +from ngclearn.components.synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context - def test_TraceSTDPConvSynapse1(): name = "stdp_conv_ctx" ## create seeding keys @@ -36,34 +32,17 @@ def test_TraceSTDPConvSynapse1(): stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") + evolve_process = (MethodProcess("evolve_process") >> a.evolve) - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") - backtransmit_process = (Process("btransmit_proc") + backtransmit_process = (MethodProcess("backtransmit_process") >> a.backtransmit) - ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit") - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit") - ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit") - """ ## fake out a mix of pre-synaptic spikes/no-spikes x = np.ones(x_shape) @@ -75,25 +54,25 @@ def test_TraceSTDPConvSynapse1(): [[1.], [0.]]]] ) - ctx.reset() + reset_process.run() # a.inputs.set(x) - ctx.run(t=1., dt=dt) - y = (a.outputs.value > 0.) * 1. ## fake out post-syn spikes + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes assert_array_equal(y, y_truth) - #print(y) - #print("======") + print(y) + print("y.Tr:\n", y_truth) + print("======") - # print("NGC-Learn.shape = ", node.outputs.value.shape) + # print("NGC-Learn.shape = ", node.outputs.get().shape) a.preSpike.set(x) a.postSpike.set(y) a.preTrace.set(x * 0.4) ## fake out pre-syn trace values a.postTrace.set(y * 1.3) ## fake out post-syn trace values - ctx.adapt(t=1., dt=dt) - dK = a.dWeights.value - #print(dK) - ctx.backtransmit(t=1., dt=dt) - dx = a.dInputs.value - #print(dx) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + dK = a.dWeights.get() + + backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt) + dx = a.dInputs.get() dK_truth = jnp.array( [[[[-1.8]], [[-0.9]]], @@ -106,6 +85,10 @@ def test_TraceSTDPConvSynapse1(): [[2.], [3.]]]] ) + # print(dK) + # print("dK.Tr:\n", dK_truth) + # print(dx) + # print("dx.Tr:\n", dx_truth) assert_array_equal(dK, dK_truth) assert_array_equal(dx, dx_truth) From 6cb319a94289ca46a36203db76551f070b776449 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 7 Nov 2025 14:17:51 -0500 Subject: [PATCH 029/121] refactored and passed test for deconv/stdp-deconv-syn and other minor cleanup for conv/deconv support --- .../synapses/convolution/__init__.py | 4 +- .../synapses/convolution/convSynapse.py | 8 +- .../synapses/convolution/deconvSynapse.py | 86 ++++++-------- .../convolution/traceSTDPConvSynapse.py | 12 +- .../convolution/traceSTDPDeconvSynapse.py | 110 +++++++++--------- .../convolution/test_traceSTDPConvSynapse.py | 8 +- .../test_traceSTDPDeconvSynapse.py | 66 ++++------- 7 files changed, 120 insertions(+), 174 deletions(-) diff --git a/ngclearn/components/synapses/convolution/__init__.py b/ngclearn/components/synapses/convolution/__init__.py index ed305c38..01f3bced 100755 --- a/ngclearn/components/synapses/convolution/__init__.py +++ b/ngclearn/components/synapses/convolution/__init__.py @@ -2,7 +2,7 @@ from .staticConvSynapse import StaticConvSynapse from .deconvSynapse import DeconvSynapse from .staticDeconvSynapse import StaticDeconvSynapse -from .hebbianConvSynapse import HebbianConvSynapse -from .hebbianDeconvSynapse import HebbianDeconvSynapse +#from .hebbianConvSynapse import HebbianConvSynapse +# from .hebbianDeconvSynapse import HebbianDeconvSynapse from .traceSTDPConvSynapse import TraceSTDPConvSynapse from .traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse diff --git a/ngclearn/components/synapses/convolution/convSynapse.py b/ngclearn/components/synapses/convolution/convSynapse.py index 9da1412f..ed6b83de 100755 --- a/ngclearn/components/synapses/convolution/convSynapse.py +++ b/ngclearn/components/synapses/convolution/convSynapse.py @@ -99,17 +99,15 @@ def __init__( dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0 ) - # @transition(output_compartments=["outputs"]) - # @staticmethod @compilable def advance_state(self): #Rscale, padding, stride, weights, biases, inputs): _x = self.inputs.get() ## FIXME: does resist_scale affect update rules? - outputs = conv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * self.resist_scale + self.biases.get() + outputs = conv2d( + _x, self.weights.get(), stride_size=self.stride, padding=self.padding + ) * self.resist_scale + self.biases.get() self.outputs.set(outputs) - # @transition(output_compartments=["inputs", "outputs"]) - # @staticmethod @compilable def reset(self): #in_shape, out_shape): preVals = jnp.zeros(self.in_shape) diff --git a/ngclearn/components/synapses/convolution/deconvSynapse.py b/ngclearn/components/synapses/convolution/deconvSynapse.py index 13d78c6b..a81563b1 100755 --- a/ngclearn/components/synapses/convolution/deconvSynapse.py +++ b/ngclearn/components/synapses/convolution/deconvSynapse.py @@ -1,15 +1,14 @@ from jax import random, numpy as jnp, jit -from ngclearn.components.jaxComponent import JaxComponent -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist from ngclearn.components.synapses.convolution.ngcconv import deconv2d +from ngclearn.components.jaxComponent import JaxComponent + + class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable """ A base deconvolutional (transposed convolutional) synaptic cable. @@ -61,7 +60,7 @@ def __init__( self.shape = shape ## shape of synaptic filter tensor x_size, x_size = x_shape self.x_size = x_size - self.Rscale = resist_scale ## post-transformation scale factor + self.resist_scale = resist_scale ## post-transformation scale factor self.padding = padding self.stride = stride @@ -70,7 +69,7 @@ def __init__( self.pad_args = None ######################### set up compartments ########################## - tmp_key, *subkeys = random.split(self.key.value, 4) + tmp_key, *subkeys = random.split(self.key.get(), 4) weights = dist.initialize_params(subkeys[0], filter_init, shape) ## filter tensor self.batch_size = batch_size # 1 @@ -89,36 +88,35 @@ def __init__( (1, shape[1])) if bias_init else 0.0) - @transition(output_compartments=["outputs"]) - @staticmethod - def advance_state(Rscale, padding, stride, weights, biases, inputs): - _x = inputs - out = deconv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases - return out - - @transition(output_compartments=["inputs", "outputs"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - return inputs, outputs - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, - biases=self.biases.value) - else: - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.weights.set(data['weights']) - if "biases" in data.keys(): - self.biases.set(data['biases']) + @compilable + def advance_state(self): + _x = self.inputs.get() + out = deconv2d( + _x, self.weights.get(), stride_size=self.stride, padding=self.padding + ) * self.resist_scale + self.biases.get() + self.outputs.set(out) + + @compilable + def reset(self): #in_shape, out_shape): + preVals = jnp.zeros(self.in_shape) + postVals = jnp.zeros(self.out_shape) + self.inputs.set(preVals) + self.outputs.set(postVals) + + # def save(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # if self.bias_init != None: + # jnp.savez(file_name, weights=self.weights.get(), + # biases=self.biases.get()) + # else: + # jnp.savez(file_name, weights=self.weights.get()) + # + # def load(self, directory, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # self.weights.set(data['weights']) + # if "biases" in data.keys(): + # self.biases.set(data['biases']) @classmethod def help(cls): ## component help function @@ -151,17 +149,3 @@ def help(cls): ## component help function "dynamics": "outputs = [K @.T inputs] * R + b", "hyperparameters": hyperparams} return info - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py index 0ee67009..3f62b119 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py @@ -146,17 +146,9 @@ def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, de dWeights = (dW_ltp + dW_ltd) return dWeights - # @transition(output_compartments=["weights", "dWeights"]) - # @staticmethod @compilable def evolve(self): - # pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace, - # postSpike, postTrace, weights, eta - dWeights = self._compute_update() - # dWeights = TraceSTDPConvSynapse._compute_update( - # pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace - # ) if self.w_decay > 0.: ## apply synaptic decay weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent else: @@ -169,10 +161,8 @@ def evolve(self): self.weights.set(weights) self.dWeights.set(dWeights) - # @transition(output_compartments=["dInputs"]) - # @staticmethod @compilable - def backtransmit(self): # x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights + def backtransmit(self): ## action-backpropagating routine ## calc dInputs - adjustment w.r.t. input signal k_size, k_size, n_in_chan, n_out_chan = self.shape diff --git a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py index 0e5d76b4..f0c6cedf 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py @@ -1,17 +1,14 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - -from .deconvSynapse import DeconvSynapse +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist + +from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse + from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv, _calc_dK_deconv, calc_dX_deconv, calc_dK_deconv) -from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional cable """ @@ -92,7 +89,7 @@ def __init__( ######################### set up compartments ########################## ## Compartment setup and shape computation - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.dInputs = Compartment(jnp.zeros(self.in_shape)) self.preSpike = Compartment(jnp.zeros(self.in_shape)) self.preTrace = Compartment(jnp.zeros(self.in_shape)) @@ -108,76 +105,73 @@ def __init__( def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): k_size, k_size, n_in_chan, n_out_chan = shape _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan)) - _d = deconv2d(_x, self.weights.value, stride_size=self.stride, + _d = deconv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * 0 _dK = _calc_dK_deconv(_x, _d, stride_size=self.stride, out_size=k_size) ## get filter update correction - dx = _dK.shape[0] - self.weights.value.shape[0] - dy = _dK.shape[1] - self.weights.value.shape[1] + dx = _dK.shape[0] - self.weights.get().shape[0] + dy = _dK.shape[1] - self.weights.get().shape[1] self.delta_shape = (abs(dx), abs(dy)) ## get input update correction - _dx = _calc_dX_deconv(self.weights.value, _d, stride_size=self.stride, + _dx = _calc_dX_deconv(self.weights.get(), _d, stride_size=self.stride, padding=self.padding) dx = (_dx.shape[1] - _x.shape[1]) # abs() dy = (_dx.shape[2] - _x.shape[2]) self.x_delta_shape = (dx, dy) - @staticmethod - def _compute_update( - pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, preSpike, preTrace, postSpike, postTrace - ): - k_size, k_size, n_in_chan, n_out_chan = shape + def _compute_update(self): + k_size, k_size, n_in_chan, n_out_chan = self.shape ## calc dFilters - dW_ltp = calc_dK_deconv(preTrace - pretrace_target, postSpike * Aplus, - delta_shape=delta_shape, stride_size=stride, - out_size=k_size, padding=padding) - dW_ltd = -calc_dK_deconv(preSpike, postTrace * Aminus, - delta_shape=delta_shape, stride_size=stride, - out_size=k_size, padding=padding) + dW_ltp = calc_dK_deconv( + self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, + delta_shape=self.delta_shape, stride_size=self.stride, out_size=k_size, padding=self.padding + ) + dW_ltd = -calc_dK_deconv( + self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape, + stride_size=self.stride, out_size=k_size, padding=self.padding + ) dWeights = (dW_ltp + dW_ltd) return dWeights - @transition(output_compartments=["weights", "dWeights"]) - @staticmethod - def evolve( - pretrace_target, Aplus, Aminus, w_decay, w_bound, shape, stride, padding, delta_shape, preSpike, preTrace, - postSpike, postTrace, weights, eta - ): - dWeights = TraceSTDPDeconvSynapse._compute_update( - pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, - preSpike, preTrace, postSpike, postTrace - ) - if w_decay > 0.: ## apply synaptic decay - weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent + @compilable + def evolve(self): + dWeights = self._compute_update() + # dWeights = TraceSTDPDeconvSynapse._compute_update( + # pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, + # preSpike, preTrace, postSpike, postTrace + # ) + if self.w_decay > 0.: ## apply synaptic decay and conduct decayed STDP-ascent + weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay else: - weights = weights + dWeights * eta ## conduct STDP-ascent + weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent ## Apply any enforced filter constraints - if w_bound > 0.: ## enforce non-negativity + if self.w_bound > 0.: ## enforce non-negativity eps = 0.01 # 0.001 - weights = jnp.clip(weights, eps, w_bound - eps) - return weights, dWeights + weights = jnp.clip(weights, eps, self.w_bound - eps) - @transition(output_compartments=["dInputs"]) - @staticmethod - def backtransmit(stride, padding, x_delta_shape, preSpike, postSpike, weights): ## action-backpropagating routine + self.weights.set(weights) + self.dWeights.set(dWeights) + + @compilable + def backtransmit(self): ## action-backpropagating routine ## calc dInputs - dInputs = calc_dX_deconv(weights, postSpike, delta_shape=x_delta_shape, - stride_size=stride, padding=padding) - return dInputs - - @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - preSpike = preVals - postSpike = postVals - preTrace = preVals - postTrace = postVals - return inputs, outputs, preSpike, postSpike, preTrace, postTrace + dInputs = calc_dX_deconv( + self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, + padding=self.padding + ) + self.dInputs.set(dInputs) + + @compilable + def reset(self): # in_shape, out_shape): + preVals = jnp.zeros(self.in_shape.get()) + postVals = jnp.zeros(self.out_shape.get()) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.preSpike.set(preVals) + self.postSpike.set(postVals) + self.preTrace.set(preVals) + self.postTrace.set(postVals) @classmethod def help(cls): ## component help function diff --git a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py index 8b179704..ea02daf7 100644 --- a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py +++ b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py @@ -54,14 +54,14 @@ def test_TraceSTDPConvSynapse1(): [[1.], [0.]]]] ) - reset_process.run() # + reset_process.run() # ctx.reset() a.inputs.set(x) advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes assert_array_equal(y, y_truth) - print(y) - print("y.Tr:\n", y_truth) - print("======") + # print(y) + # print("y.Tr:\n", y_truth) + # print("======") # print("NGC-Learn.shape = ", node.outputs.get().shape) a.preSpike.set(x) diff --git a/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py b/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py index 76be1c2a..a414ac08 100644 --- a/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py +++ b/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py @@ -2,16 +2,12 @@ from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import TraceSTDPDeconvSynapse + +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist -from ngcsimlib.compilers import compile_command, wrap_command +from ngclearn.components.synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context - def test_TraceSTDPDeconvSynapse1(): name = "stdp_deconv_ctx" ## create seeding keys @@ -37,36 +33,17 @@ def test_TraceSTDPDeconvSynapse1(): stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) - backtransmit_process = (Process("btransmit_proc") + backtransmit_process = (MethodProcess("backtransmit_process") >> a.backtransmit) - ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit") - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit") - ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit") - """ ## fake out a mix of pre-synaptic spikes/no-spikes x = np.ones(x_shape) @@ -78,25 +55,24 @@ def test_TraceSTDPDeconvSynapse1(): [[1.], [1.]]]] ) - ctx.reset() + reset_process.run() #ctx.reset() a.inputs.set(x) - ctx.run(t=1., dt=dt) - y = (a.outputs.value > 0.) * 1. ## fake out post-syn spikes + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes assert_array_equal(y, y_truth) - #print(y) - #print("======") + # print(y) + # print("y.Tr:\n", y_truth) + # print("======") - # print("NGC-Learn.shape = ", node.outputs.value.shape) + # print("NGC-Learn.shape = ", node.outputs.get().shape) a.preSpike.set(x) a.postSpike.set(y) a.preTrace.set(x * 0.4) ## fake out pre-syn trace values a.postTrace.set(y * 1.3) ## fake out post-syn trace values - ctx.adapt(t=1., dt=dt) - dK = a.dWeights.value - #print(dK) - ctx.backtransmit(t=1., dt=dt) - dx = a.dInputs.value - #print(dx) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + dK = a.dWeights.get() + backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt) + dx = a.dInputs.get() dK_truth = jnp.array( [[[[0.]], [[-0.9]]], @@ -109,6 +85,10 @@ def test_TraceSTDPDeconvSynapse1(): [[2.], [1.]]]] ) + # print(dK) + # print("dK.Tr:\n", dK_truth) + # print(dx) + # print("dx.Tr:\n", dx_truth) assert_array_equal(dK, dK_truth) assert_array_equal(dx, dx_truth) From 10ef0e05439f683bb1362a749541f65d50971dc3 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen <60036798+rxng8@users.noreply.github.com> Date: Mon, 10 Nov 2025 14:33:18 -0500 Subject: [PATCH 030/121] Refactoring neuronal and synaptic components (#123) - merge from fork to v3 * refactoring graded cells * update refactored models * update sLIF cell --------- Co-authored-by: Alex Ororbia --- .../neurons/graded/bernoulliErrorCell.py | 62 +++++++++---- .../neurons/graded/gaussianErrorCell.py | 47 +++++----- .../neurons/graded/laplacianErrorCell.py | 48 ++++++----- .../components/neurons/graded/rateCell.py | 57 ++++++------ .../neurons/graded/rewardErrorCell.py | 69 +++++++++------ .../components/neurons/spiking/sLIFCell.py | 86 ++++++++++++------- ngclearn/components/other/expKernel.py | 30 +++---- .../synapses/hebbian/hebbianSynapse.py | 52 +++++------ .../synapses/patched/hebbianPatchedSynapse.py | 77 ++++++++--------- .../synapses/patched/patchedSynapse.py | 53 +++++------- ngclearn/utils/optim/adam.py | 6 +- 11 files changed, 316 insertions(+), 271 deletions(-) diff --git a/ngclearn/components/neurons/graded/bernoulliErrorCell.py b/ngclearn/components/neurons/graded/bernoulliErrorCell.py index 6bf0ebe6..376fa41f 100755 --- a/ngclearn/components/neurons/graded/bernoulliErrorCell.py +++ b/ngclearn/components/neurons/graded/bernoulliErrorCell.py @@ -1,9 +1,13 @@ -from ngclearn import resolver, Component, Compartment +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit from ngclearn.utils import tensorstats from ngclearn.utils.model_utils import sigmoid, d_sigmoid -from ngcsimlib.compilers.process import transition + +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ @@ -59,14 +63,20 @@ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None, self.modulator = Compartment(restVals + 1.0) # to be set/consumed self.mask = Compartment(restVals + 1.0) - @transition(output_compartments=["dp", "dtarget", "L", "mask"]) - @staticmethod - def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output + # @transition(output_compartments=["dp", "dtarget", "L", "mask"]) + @compilable + def advance_state(self, dt): ## compute Bernoulli error cell output + # Get the variables + p = self.p.get() + target = self.target.get() + modulator = self.modulator.get() + mask = self.mask.get() + # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit # behavior of the local cost functional eps = 0.0001 _p = p - if input_logits: ## convert from "logits" to probs via sigmoidal link function + if self.input_logits: ## convert from "logits" to probs via sigmoidal link function _p = sigmoid(p) _p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent div by 0 x = target @@ -78,7 +88,7 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern log_p = jnp.log(_p) ## ln(p) log_one_min_p = jnp.log(one_min_p) ## ln(1 - p) L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bern LL - if input_logits: + if self.input_logits: dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p) else: dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp @@ -89,14 +99,21 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli? dtarget = dL_dx * modulator * mask mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t - return dp, dtarget, jnp.squeeze(L), mask - - @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"]) - @staticmethod - def reset(batch_size, shape): ## reset core components/statistics - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) + + # Set state + # dp, dtarget, jnp.squeeze(L), mask + self.dp.set(dp) + self.dtarget.set(dtarget) + self.L.set(jnp.squeeze(L)) + self.mask.set(mask) + + + # @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"]) + @compilable + def reset(self, batch_size): ## reset core components/statistics + _shape = (batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) ## "rest"/reset values dp = restVals dtarget = restVals @@ -105,7 +122,16 @@ def reset(batch_size, shape): ## reset core components/statistics modulator = restVals + 1. ## reset modulator signal L = 0. #jnp.zeros((1, 1)) ## rest loss mask = jnp.ones(_shape) ## reset mask - return dp, dtarget, target, p, modulator, L, mask + + # Set compartment + self.dp.set(dp) + self.dtarget.set(dtarget) + self.target.set(target) + self.p.set(p) + self.modulator.set(modulator) + self.L.set(L) + self.mask.set(mask) + @classmethod def help(cls): ## component help function @@ -136,11 +162,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 29b5f267..63e10a65 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -1,8 +1,12 @@ -from ngclearn import resolver, Component, Compartment +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit from ngclearn.utils import tensorstats -from ngcsimlib.compilers.process import transition + +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ @@ -71,9 +75,15 @@ def eval_log_density(target, mu, Sigma): log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) return log_density - @transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"]) - @staticmethod - def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output + @compilable + def advance_state(self, dt): ## compute Gaussian error cell output + # Get the variables + mu = self.mu.get() + target = self.target.get() + Sigma = self.Sigma.get() + modulator = self.modulator.get() + mask = self.mask.get() + # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit # behavior of the local cost functional: # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2 @@ -90,24 +100,13 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... dtarget = dtarget * modulator * mask mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t - return dmu, dtarget, dSigma, jnp.squeeze(L), mask - @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"]) - @staticmethod - def reset(batch_size, shape, sigma_shape): ## reset core components/statistics - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) - restVals = jnp.zeros(_shape) - dmu = restVals - dtarget = restVals - dSigma = jnp.zeros(sigma_shape) - target = restVals - mu = restVals - modulator = mu + 1. - L = 0. #jnp.zeros((1, 1)) - mask = jnp.ones(_shape) - return dmu, dtarget, dSigma, target, mu, modulator, L, mask + # Update compartments + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.L.set(jnp.squeeze(L)) + self.mask.set(mask) @classmethod def help(cls): ## component help function @@ -140,11 +139,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py index 6d825fe0..e3717d1c 100755 --- a/ngclearn/components/neurons/graded/laplacianErrorCell.py +++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py @@ -1,8 +1,12 @@ -from ngclearn import resolver, Component, Compartment +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit from ngclearn.utils import tensorstats -from ngcsimlib.compilers.process import transition + +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ @@ -44,7 +48,7 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs): else: _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor scale_shape = (1, 1) - if not isinstance(scale, float) and not isinstance(sigma, int): + if not isinstance(scale, float) and not isinstance(scale, int): scale_shape = jnp.array(scale).shape self.scale_shape = scale_shape ## Layer Size setup @@ -67,9 +71,15 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs): self.modulator = Compartment(restVals + 1.0) ## to be set/consumed self.mask = Compartment(restVals + 1.0) - @transition(output_compartments=["dshift", "dtarget", "dScale", "L", "mask"]) - @staticmethod - def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplacian error cell output + @compilable + def advance_state(self, dt): ## compute Laplacian error cell output + # Get the variables + shift = self.shift.get() + target = self.target.get() + Scale = self.Scale.get() + modulator = self.modulator.get() + mask = self.mask.get() + # Moves Laplacian cell dynamics one step forward. Specifically, this routine emulates the error unit # behavior of the local cost functional: # FIXME: Currently, below does: L(targ, shift) = -||targ - shift||_1/scale @@ -85,21 +95,13 @@ def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplaci dshift = dshift * modulator * mask dtarget = dtarget * modulator * mask mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t - return dshift, dtarget, dScale, jnp.squeeze(L), mask - - @transition(output_compartments=["dshift", "dtarget", "dScale", "target", "shift", "modulator", "L", "mask"]) - @staticmethod - def reset(batch_size, n_units, scale_shape): - restVals = jnp.zeros((batch_size, n_units)) - dshift = restVals - dtarget = restVals - dScale = jnp.zeros(scale_shape) - target = restVals - shift = restVals - modulator = shift + 1. - L = 0. - mask = jnp.ones((batch_size, n_units)) - return dshift, dtarget, dScale, target, shift, modulator, L, mask + + # Update compartments + self.dshift.set(dshift) + self.dtarget.set(dtarget) + self.dScale.set(dScale) + self.L.set(jnp.squeeze(L)) + self.mask.set(mask) @classmethod def help(cls): ## component help function @@ -131,11 +133,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 63a9fe3b..a76346b4 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -11,6 +11,9 @@ from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2, step_rk4 +from ngcsimlib.logger import info +from ngcsimlib.parser import compilable + def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics z_leak = jnp.sign(z) ## d/dx of Laplace is signum dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) @@ -198,7 +201,6 @@ def __init__( self.n_units = n_units self.batch_size = batch_size - omega_0 = None if act_fx == "sine": omega_0 = kwargs["omega_0"] @@ -211,46 +213,43 @@ def __init__( self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - @transition(output_compartments=["j", "j_td", "z", "zF"]) - @staticmethod - def advance_state( - dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful, - output_scale, j, j_td, z): + @compilable + def advance_state(self, dt): + # Get the compartment values + j = self.j.get() + j_td = self.j_td.get() + z = self.z.get() + #if tau_m > 0.: - if is_stateful: + if self.is_stateful: ### run a step of integration over neuronal dynamics ## Notes: ## self.pressure <-- "top-down" expectation / contextual pressure ## self.current <-- "bottom-up" data-dependent signal - dfx_val = dfx(z) + dfx_val = self.dfx(z) j = _modulate(j, dfx_val) - j = j * resist_scale + j = j * self.resist_scale tmp_z = _run_cell(dt, j, j_td, z, - tau_m, leak_gamma=priorLeakRate, - integType=intgFlag, priorType=priorType) + self.tau_m, leak_gamma=self.priorLeakRate, + integType=self.intgFlag, priorType=self.priorType) ## apply optional thresholding sub-dynamics - if thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, thr_lmbda) - elif thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, thr_lmbda) + if self.thresholdType == "soft_threshold": + tmp_z = threshold_soft(tmp_z, self.thr_lmbda) + elif self.thresholdType == "cauchy_threshold": + tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda) z = tmp_z ## pre-activation function value(s) - zF = fx(z) * output_scale ## post-activation function value(s) + zF = self.fx(z) * self.output_scale ## post-activation function value(s) else: ## run in "stateless" mode (when no membrane time constant provided) j_total = j + j_td z = _run_cell_stateless(j_total) - zF = fx(z) * output_scale - return j, j_td, z, zF - - @transition(output_compartments=["j", "j_td", "z", "zF"]) - @staticmethod - def reset(batch_size, shape): #n_units - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) - restVals = jnp.zeros(_shape) - return tuple([restVals for _ in range(4)]) + zF = self.fx(z) * self.output_scale + # Update compartments + self.j.set(j) + self.j_td.set(j_td) + self.z.set(z) + self.zF.set(zF) def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays @@ -308,11 +307,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index fe9670c3..a9d43fac 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -1,9 +1,13 @@ -from ngclearn import resolver, Component, Compartment +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit -from ngcsimlib.compilers.process import transition from ngclearn.utils import tensorstats +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable + class RewardErrorCell(JaxComponent): ## Reward prediction error cell """ A reward prediction error (RPE) cell. @@ -51,38 +55,41 @@ def __init__(self, name, n_units, alpha, ema_window_len=10, self.accum_reward = Compartment(restVals) ## accumulated reward signal(s) self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken - @transition(output_compartments=["mu", "rpe", "n_ep_steps", "accum_reward"]) - @staticmethod - def advance_state(dt, use_online_predictor, alpha, mu, rpe, reward, - n_ep_steps, accum_reward): + @compilable + def advance_state(self, dt): + # Get the variables + mu = self.mu.get() + reward = self.reward.get() + n_ep_steps = self.n_ep_steps.get() + accum_reward = self.accum_reward.get() + ## compute/update RPE and predictor values accum_reward = accum_reward + reward rpe = reward - mu - if use_online_predictor: - mu = mu * (1. - alpha) + reward * alpha + if self.use_online_predictor: + mu = mu * (1. - self.alpha) + reward * self.alpha n_ep_steps = n_ep_steps + 1 - return mu, rpe, n_ep_steps, accum_reward - @transition(output_compartments=["mu"]) - @staticmethod - def evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu, - accum_reward): - if use_online_predictor: + # Update compartments + self.mu.set(mu) + self.rpe.set(rpe) + self.n_ep_steps.set(n_ep_steps) + self.accum_reward.set(accum_reward) + + @compilable + def evolve(self, dt): + # Get the variables + mu = self.mu.get() + n_ep_steps = self.n_ep_steps.get() + accum_reward = self.accum_reward.get() + + if self.use_online_predictor: ## total episodic reward signal r = accum_reward/n_ep_steps - mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r - return mu - - @transition(output_compartments=["mu", "rpe", "accum_reward", "n_ep_steps"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - mu = restVals - rpe = restVals - accum_reward = restVals - n_ep_steps = jnp.zeros((batch_size, 1)) - return mu, rpe, accum_reward, n_ep_steps + mu = (1. - 1./self.ema_window_len) * mu + (1./self.ema_window_len) * r + # Update compartment + self.mu.set(mu) @classmethod def help(cls): ## component help function @@ -116,11 +123,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) @@ -128,3 +135,9 @@ def __repr__(self): line = "None" lines += f" {f'({c})'.ljust(maxlen)}{line}\n" return lines + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = RewardErrorCell("X", 9, 0.03) + print(X) diff --git a/ngclearn/components/neurons/spiking/sLIFCell.py b/ngclearn/components/neurons/spiking/sLIFCell.py index 76736aec..b77a9c4f 100644 --- a/ngclearn/components/neurons/spiking/sLIFCell.py +++ b/ngclearn/components/neurons/spiking/sLIFCell.py @@ -1,15 +1,16 @@ +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import step_euler from ngclearn.utils.surrogate_fx import secant_lif_estimator -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.logger import info from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable @jit def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics @@ -132,7 +133,7 @@ def __init__( ## create simple recurrent inhibitory pressure self.inh_R = resist_inh ## lateral inhibitory magnitude - key, subkey = random.split(self.key.value) + key, subkey = random.split(self.key.get()) self.inh_weights = random.uniform(subkey, (n_units, n_units), minval=0.025, maxval=1.) self.inh_weights = self.inh_weights * (1. - jnp.eye(n_units)) @@ -162,12 +163,8 @@ def __init__( self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s) self.surrogate = Compartment(restVals + 1.) ## surrogate signal - @transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"]) - @staticmethod - def advance_state( - t, dt, inh_weights, R_m, inh_R, d_spike_fx, tau_m, spike_fx, refract_T, thrGain, - thrLeak, rho_b, sticky_spikes, v_min, j, s, v, thr, rfr, tols - ): + @compilable + def advance_state(self, t, dt): ##################################################################################### #The following 3 lines of code modify electrical current j via application of a #scalar membrane resistance value and an approximate form of lateral inhibition. @@ -180,20 +177,31 @@ def advance_state( #| R_m: membrane resistance (to multiply/scale j by), #| inh_R: inhibitory resistance to scale lateral inhibitory current by; if inh_R = 0, # NO lateral inhibitory pressure will be applied - j = j * R_m - if inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied - j = j - (jnp.matmul(spikes, inh_weights) * inh_R) + + # First, get the relevant compartment values + j = self.j.get() + # s = self.s.get() # NOTE: This is unused + tols = self.tols.get() + v = self.v.get() + thr = self.thr.get() + rfr = self.rfr.get() + surrogate = self.surrogate.get() + ## modify electrical current j via membrane resistance and lateral inhibition + + j = j * self.R_m + if self.inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied + j = j - (jnp.matmul(spikes, self.inh_weights) * self.inh_R) ##################################################################################### - surrogate = d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes + surrogate = self.d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes ## transition to: voltage(t+dt), spikes, threshold(t+dt), refractory_variables(t+dt) - v_params = (j, rfr, tau_m, refract_T) + v_params = (j, rfr, self.tau_m, self.refract_T) _, _v = step_euler(0., v, _dfv, dt, v_params) - spikes = spike_fx(_v, thr) + spikes = self.spike_fx(_v, thr) #_v = _hyperpolarize(_v, spikes) _v = (1. - spikes) * _v ## hyper-polarize cells - new_thr = _update_threshold(dt, thr, spikes, thrGain, thrLeak, rho_b) - _rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, refract_T, sticky_spikes) + new_thr = _update_threshold(dt, thr, spikes, self.thrGain, self.thrLeak, self.rho_b) + _rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, self.refract_T, self.sticky_spikes) v = _v s = spikes thr = new_thr @@ -201,34 +209,48 @@ def advance_state( ## update tols tols = (1. - s) * tols + (s * t) - return j, s, tols, v, thr, rfr, surrogate - - @transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"]) - @staticmethod - def reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr): - restVals = jnp.zeros((batch_size, n_units)) + # return j, s, tols, v, thr, rfr, surrogate + self.j.set(j) + self.s.set(s) + self.tols.set(tols) + self.v.set(v) + self.thr.set(thr) + self.rfr.set(rfr) + self.surrogate.set(surrogate) + + @compilable + def reset(self): + # refract_T, thr_persist, threshold0, batch_size, n_units, thr + restVals = jnp.zeros((self.batch_size, self.n_units)) voltage = restVals - refract = restVals + refract_T + refract = restVals + self.refract_T current = restVals surrogate = restVals + 1. timeOfLastSpike = restVals spikes = restVals - if not thr_persist: ## if thresh non-persistent, reset to base value - thr = threshold0 + 0 - return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate + if not self.thr_persist: ## if thresh non-persistent, reset to base value + thr = self.threshold0 + 0 + # return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate + self.j.set(current) + self.s.set(spikes) + self.tols.set(timeOfLastSpike) + self.v.set(voltage) + self.thr.set(thr) + self.rfr.set(refract) + self.surrogate.set(surrogate) def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" if self.thr_persist == False: jnp.savez(file_name, threshold=self.threshold0) # save threshold0 else: - jnp.savez(file_name, threshold=self.thr.value) # save the actual threshold param/compartment + jnp.savez(file_name, threshold=self.thr.get()) # save the actual threshold param/compartment def load(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" data = jnp.load(file_name) self.thr.set(data['threshold']) - self.threshold0 = self.thr.value + 0 + self.threshold0 = self.thr.get() + 0 @classmethod def help(cls): ## component help function @@ -270,11 +292,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index a074c30e..5f95848b 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -3,11 +3,10 @@ from functools import partial from ngclearn.utils import tensorstats from ngcsimlib.deprecators import deprecate_args -from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -#from ngcsimlib.component import Component +from ngcsimlib.logger import info, warn from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable @partial(jit, static_argnums=[5,6]) def _apply_kernel(tf_curr, s, t, tau_w, win_len, krn_start, krn_end): @@ -67,21 +66,20 @@ def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs) ## window of spike times self.tf = Compartment(jnp.zeros((self.win_len, self.batch_size, self.n_units))) - @transition(output_compartments=["epsp", "tf"]) - @staticmethod - def advance_state(t, tau_w, win_len, inputs, tf): + @compilable + def advance_state(self, t): + # Get the variables + inputs = self.inputs.get() + tf = self.tf.get() + s = inputs ## update spike time window and corresponding window volume - tf, epsp = _apply_kernel(tf, s, t, tau_w, win_len, krn_start=0, - krn_end=win_len-1) #0:win_len-1) - return epsp, tf - - @transition(output_compartments=["inputs", "epsp", "tf"]) - @staticmethod - def reset(batch_size, n_units, win_len): - restVals = jnp.zeros((batch_size, n_units)) - restTensor = jnp.zeros([win_len, batch_size, n_units], jnp.float32) # tf - return restVals, restVals, restTensor # inputs, epsp, tf + tf, epsp = _apply_kernel(tf, s, t, self.tau_w, self.win_len, krn_start=0, + krn_end=self.win_len-1) #0:win_len-1) + + # Update compartments + self.epsp.set(epsp) + self.tf.set(tf) @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index faaee5c9..cb221320 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -1,8 +1,11 @@ from jax import random, numpy as jnp, jit from functools import partial from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment -from ngcsimlib.compilers.process import transition + +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable + from ngclearn.components.synapses import DenseSynapse from ngclearn.utils import tensorstats from ngcsimlib.deprecators import deprecate_args @@ -218,38 +221,35 @@ def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda post_wght=post_wght) return dW, db - @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"]) - @staticmethod - def evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): + @compilable + def evolve(self): + # Get the variables + pre = self.pre.get() + post = self.post.get() + weights = self.weights.get() + biases = self.biases.get() + opt_params = self.opt_params.get() + ## calculate synaptic update values dWeights, dBiases = HebbianSynapse._compute_update( - w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, + self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght, pre, post, weights ) ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) + if self.bias_init != None: + opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases]) else: # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) + opt_params, [weights] = self.opt(opt_params, [weights], [dWeights]) ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) + weights = _enforce_constraints(weights, self.w_bound, is_nonnegative=self.is_nonnegative) + + # Update compartments + self.opt_params.set(opt_params) + self.weights.set(weights) + self.biases.set(biases) + self.dWeights.set(dWeights) + self.dBiases.set(dBiases) @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index 1415f51a..364ad3d8 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -1,13 +1,18 @@ +# %% + import matplotlib.pyplot as plt from jax import random, numpy as jnp, jit from functools import partial from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngclearn import resolver, Component, Compartment + +from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable + from ngclearn.components.synapses.patched import PatchedSynapse from ngclearn.utils import tensorstats -from ngcsimlib.compilers.process import transition -@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) +# @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1., post_wght=1.): @@ -69,7 +74,7 @@ def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1., return dW * signVal, db * signVal -@partial(jit, static_argnums=[1,2, 3]) +# @partial(jit, static_argnums=[1,2, 3]) def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True): """ Enforces constraints that the (synaptic) efficacies/values within matrix @@ -225,10 +230,10 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig self.dWeights = Compartment(jnp.zeros(self.shape)) self.dBiases = Compartment(jnp.zeros(self.shape[1])) - #key, subkey = random.split(self.key.value) + #key, subkey = random.split(self.key.get()) self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) + [self.weights.get(), self.biases.get()] + if bias_init else [self.weights.get()])) @staticmethod def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, @@ -241,39 +246,35 @@ def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db - @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"]) - @staticmethod - def evolve(block_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, bias_init, pre, post, weights, biases, opt_params): + @compilable + def evolve(self): + # Get the variables + pre = self.pre.get() + post = self.post.get() + weights = self.weights.get() + biases = self.biases.get() + opt_params = self.opt_params.get() + ## calculate synaptic update values dWeights, dBiases = HebbianPatchedSynapse._compute_update( - block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, - pre_wght, post_wght, pre, post, weights + self.block_mask, self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, + self.pre_wght, self.post_wght, pre, post, weights ) ## conduct a step of optimization - get newly evolved synaptic weight value matrix - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) + if self.bias_init != None: + opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases]) else: # ignore db since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) + opt_params, [weights] = self.opt(opt_params, [weights], [dWeights]) ## ensure synaptic efficacies adhere to constraints - weights = _enforce_constraints(weights, block_mask, w_bound, is_nonnegative=is_nonnegative) - return opt_params, weights, biases, dWeights, dBiases - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.zeros(shape), # dW - jnp.zeros(shape[1]), # db - ) + weights = _enforce_constraints(weights, self.block_mask, self.w_bound, is_nonnegative=self.is_nonnegative) + # Update compartments + self.opt_params.set(opt_params) + self.weights.set(weights) + self.biases.set(biases) + self.dWeights.set(dWeights) + self.dBiases.set(dBiases) @classmethod def help(cls): ## component help function @@ -326,11 +327,11 @@ def help(cls): ## component help function def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) @@ -340,18 +341,12 @@ def __repr__(self): return lines - - - - - - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: Wab = HebbianPatchedSynapse("Wab", (9, 30), 3, (0, 0), optim_type='adam', sign_value=-1.0, prior=("l1l2", 0.001)) print(Wab) - plt.imshow(Wab.weights.value, cmap='gray') + plt.imshow(Wab.weights.get(), cmap='gray') plt.show() diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index 43d1dc16..bbf23d78 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -1,11 +1,15 @@ +# %% + import matplotlib.pyplot as plt from jax import random, numpy as jnp, jit -from ngclearn import resolver, Component, Compartment from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils import tensorstats -from ngcsimlib.compilers.process import transition from ngclearn.utils.weight_distribution import initialize_params + from ngcsimlib.logger import info +from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable + import math @@ -66,7 +70,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable with number of inputs by number of outputs) n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse) - + stride_shape: Stride shape of overlapping synaptic weight value matrix (Default: (0, 0)) @@ -104,7 +108,7 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=N self.n_sub_models = n_sub_models self.sub_stride = stride_shape - tmp_key, *subkeys = random.split(self.key.value, 4) + tmp_key, *subkeys = random.split(self.key.get(), 4) if self.weight_init is None: info(self.name, "is using default weight initializer!") self.weight_init = {"dist": "fan_in_gaussian"} @@ -137,20 +141,17 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=N (1, self.shape[1])) if bias_init else 0.0) - @transition(output_compartments=["outputs"]) - @staticmethod - def advance_state(Rscale, inputs, weights, biases): - outputs = (jnp.matmul(inputs, weights) * Rscale) + biases - return outputs - - @transition(output_compartments=["inputs", "outputs"]) - @staticmethod - def reset(batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - inputs = preVals - outputs = postVals - return inputs, outputs + @compilable + def advance_state(self): + # Get the variables + inputs = self.inputs.get() + weights = self.weights.get() + biases = self.biases.get() + + outputs = (jnp.matmul(inputs, weights) * self.Rscale) + biases + + # Update compartment + self.outputs.set(outputs) def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" @@ -202,11 +203,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) @@ -216,21 +217,11 @@ def __repr__(self): return lines - - - - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: Wab = PatchedSynapse("Wab", (9, 30), 3) print(Wab) - plt.imshow(Wab.weights.value, cmap='gray') + plt.imshow(Wab.weights.get(), cmap='gray') plt.show() - - - - - - diff --git a/ngclearn/utils/optim/adam.py b/ngclearn/utils/optim/adam.py index 12b1d756..8fa9cba6 100644 --- a/ngclearn/utils/optim/adam.py +++ b/ngclearn/utils/optim/adam.py @@ -1,8 +1,8 @@ # %% -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.resolver import resolver +# from ngcsimlib.component import Component +# from ngcsimlib.compartment import Compartment +# from ngcsimlib.resolver import resolver import numpy as np from jax import jit, numpy as jnp, random, nn, lax From 1ac6f2c17d55e82b098554755f12cf06e2ba250b Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 10 Nov 2025 14:38:41 -0500 Subject: [PATCH 031/121] commented out deprecator in hebb-syn and exp-kernel --- ngclearn/components/other/expKernel.py | 2 +- .../synapses/convolution/traceSTDPConvSynapse.py | 10 ++++------ .../synapses/convolution/traceSTDPDeconvSynapse.py | 6 +++--- ngclearn/components/synapses/hebbian/hebbianSynapse.py | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index 5f95848b..21a10c37 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -2,7 +2,7 @@ from jax import numpy as jnp, random, jit from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args from ngcsimlib.logger import info, warn from ngcsimlib.compartment import Compartment diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py index 3f62b119..a0f74537 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py @@ -14,8 +14,8 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable """ - A synaptic convolutional cable that adjusts its filter efficacies via a - trace-based form of spike-timing-dependent plasticity (STDP). + A specialized synaptic convolutional cable that adjusts its filter efficacies via a trace-based form of + spike-timing-dependent plasticity (STDP). | --- Synapse Compartments: --- | inputs - input (takes in external signals) @@ -131,8 +131,7 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): dy = (_dx.shape[2] - _x.shape[2]) self.x_delta_shape = (dx, dy) - #@staticmethod - def _compute_update(self): #pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace + def _compute_update(self): ## Compute long-term potentiation to filters dW_ltp = calc_dK_conv( self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, delta_shape=self.delta_shape, @@ -162,8 +161,7 @@ def evolve(self): self.dWeights.set(dWeights) @compilable - def backtransmit(self): - ## action-backpropagating routine + def backtransmit(self): ## action-backpropagating co-routine ## calc dInputs - adjustment w.r.t. input signal k_size, k_size, n_in_chan, n_out_chan = self.shape # antiPad = None diff --git a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py index f0c6cedf..514f8611 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py @@ -12,8 +12,8 @@ class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional cable """ - A synaptic deconvolutional (transposed convolutional) cable that adjusts its - filter efficacies via a trace-based form of spike-timing-dependent plasticity (STDP). + A specialized synaptic deconvolutional (transposed convolutional) cable that adjusts its filter efficacies via a + trace-based form of spike-timing-dependent plasticity (STDP). | --- Synapse Compartments: --- | inputs - input (takes in external signals) @@ -154,7 +154,7 @@ def evolve(self): self.dWeights.set(dWeights) @compilable - def backtransmit(self): ## action-backpropagating routine + def backtransmit(self): ## action-backpropagating co-routine ## calc dInputs dInputs = calc_dX_deconv( self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index cb221320..9447019f 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -8,7 +8,7 @@ from ngclearn.components.synapses import DenseSynapse from ngclearn.utils import tensorstats -from ngcsimlib.deprecators import deprecate_args +from ngcsimlib import deprecate_args @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1., From 5a58b87d9f30d991ebafb212a7172f0abfbbb92b Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Mon, 10 Nov 2025 14:51:16 -0500 Subject: [PATCH 032/121] update hebbian synapse --- ngclearn/components/synapses/hebbian/hebbianSynapse.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 9447019f..245392dd 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -1,3 +1,5 @@ +# %% + from jax import random, numpy as jnp, jit from functools import partial from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn @@ -297,11 +299,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) From cac52079267acc9e9926b8c35edd09f3674c5702 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Mon, 10 Nov 2025 14:59:07 -0500 Subject: [PATCH 033/121] update hebbian synapse --- ngclearn/components/synapses/hebbian/hebbianSynapse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 245392dd..0007366f 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -166,7 +166,7 @@ class HebbianSynapse(DenseSynapse): """ # Define Functions - @deprecate_args(_rebind=False, w_decay='prior') + # @deprecate_args(_rebind=False, w_decay='prior') def __init__( self, name, shape, eta=0., weight_init=None, bias_init=None, w_bound=1., is_nonnegative=False, prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., @@ -210,8 +210,8 @@ def __init__( #key, subkey = random.split(self.key.value) self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) + [self.weights.get(), self.biases.get()] + if bias_init else [self.weights.get()])) @staticmethod def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, From 0d7c24b75cd1770ae396e539111d3961a2a04748 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Mon, 10 Nov 2025 15:57:25 -0500 Subject: [PATCH 034/121] working reinforce synapse --- .../synapses/modulated/REINFORCESynapse.py | 182 +++++++++++------- 1 file changed, 116 insertions(+), 66 deletions(-) diff --git a/ngclearn/components/synapses/modulated/REINFORCESynapse.py b/ngclearn/components/synapses/modulated/REINFORCESynapse.py index 92b72d88..378becf0 100644 --- a/ngclearn/components/synapses/modulated/REINFORCESynapse.py +++ b/ngclearn/components/synapses/modulated/REINFORCESynapse.py @@ -1,7 +1,9 @@ +# %% + from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component +from ngcsimlib.logger import info from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable from ngclearn.utils.model_utils import clip, d_clip import jax import jax.numpy as jnp @@ -17,11 +19,59 @@ def gaussian_logpdf(event, mean, stddev): quadratic = (event - mean)**2 / scale_sqrd return - 0.5 * (log_normalizer + quadratic) + +def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev): + learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32) + # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim) + W_mu, W_logstd = jnp.split(weights, 2, axis=-1) + # Forward pass + activation = act_fx(inputs) + mean = activation @ W_mu + fx_mean = mu_act_fx(mean) + logstd = activation @ W_logstd + clip_logstd = clip(logstd, -10.0, 2.0) + std = jnp.exp(clip_logstd) + std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick + # Sample using reparameterization trick + epsilon = jax.random.normal(seed, fx_mean.shape) + sample = epsilon * std + fx_mean + sample = jnp.clip(sample, mu_out_min, mu_out_max) + outputs = sample # the actual action that we take + # Compute log probability density of the Gaussian + log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1) + # Compute objective (negative REINFORCE objective) + objective = (-log_prob * rewards).mean() * 1e-2 + + # Backward pass + batch_size = inputs.shape[0] # B + dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1) + + # Compute gradients manually based on the derivation + # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2 + dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2) + dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A) + dL_dWmu = activation.T @ dL_dmean + + # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1) + dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3 + dL_dstd = dL_dlogp * dlog_prob_dlogstd + # Apply gradient clipping for logstd + dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std + dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A) + dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev + + # Update weights, negate the gradient because gradient ascent in ngc-learn + dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1) + # Finally, return metrics if needed + return dW, objective, outputs + + + class REINFORCESynapse(DenseSynapse): """ A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse uses Gaussian distributions for generating actions and performs gradient-based updates. - + | --- Synapse Compartments: --- | inputs - input (takes in external signals) | outputs - output signals (sampled actions from Gaussian distribution) @@ -89,7 +139,7 @@ def __init__( self.scalar_stddev = scalar_stddev ## Compartment setup - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) # self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate # For eligiblity traces later self.objective = Compartment(jnp.zeros(())) self.outputs = Compartment(jnp.zeros((batch_size, output_dim))) @@ -101,72 +151,50 @@ def __init__( self.learning_mask = Compartment(jnp.zeros(())) self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42)) - @staticmethod - def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev): - learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32) - # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim) - W_mu, W_logstd = jnp.split(weights, 2, axis=-1) - # Forward pass - activation = act_fx(inputs) - mean = activation @ W_mu - fx_mean = mu_act_fx(mean) - logstd = activation @ W_logstd - clip_logstd = clip(logstd, -10.0, 2.0) - std = jnp.exp(clip_logstd) - std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick - # Sample using reparameterization trick - epsilon = jax.random.normal(seed, fx_mean.shape) - sample = epsilon * std + fx_mean - sample = jnp.clip(sample, mu_out_min, mu_out_max) - outputs = sample # the actual action that we take - # Compute log probability density of the Gaussian - log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1) - # Compute objective (negative REINFORCE objective) - objective = (-log_prob * rewards).mean() * 1e-2 - - # Backward pass - batch_size = inputs.shape[0] # B - dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1) - - # Compute gradients manually based on the derivation - # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2 - dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2) - dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A) - dL_dWmu = activation.T @ dL_dmean - - # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1) - dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3 - dL_dstd = dL_dlogp * dlog_prob_dlogstd - # Apply gradient clipping for logstd - dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std - dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A) - dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev - - # Update weights, negate the gradient because gradient ascent in ngc-learn - dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1) - # Finally, return metrics if needed - return dW, objective, outputs - - @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"]) - @staticmethod - def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev): + + # @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"]) + # @staticmethod + @compilable + def evolve(self, dt): + + # Get compartment values + weights = self.weights.get() + dWeights = self.dWeights.get() + objective = self.objective.get() + outputs = self.outputs.get() + accumulated_gradients = self.accumulated_gradients.get() + step_count = self.step_count.get() + seed = self.seed.get() + inputs = self.inputs.get() + rewards = self.rewards.get() + + # Main logic main_seed, sub_seed = jax.random.split(seed) - dWeights, objective, outputs = REINFORCESynapse._compute_update( - dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev + dWeights, objective, outputs = _compute_update( + dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx, self.mu_out_min, self.mu_out_max, self.scalar_stddev ) ## do a gradient ascent update/shift - weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0 + weights = (weights + dWeights * self.eta) * self.learning_mask + weights * (1.0 - self.learning_mask) # update the weights only where learning_mask is 1.0 ## enforce non-negativity eps = 0.0 # 0.01 # 0.001 - weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound)) + weights = jnp.clip(weights, eps, self.w_bound - eps) # jnp.abs(w_bound)) step_count += 1 - accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients - step_count = step_count * (1 - learning_mask) # reset the step count to 0 when we have learned - return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed + accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients + step_count = step_count * (1 - self.learning_mask) # reset the step count to 0 when we have learned + + # Set updated compartment values + self.weights.set(weights) + self.dWeights.set(dWeights) + self.objective.set(objective) + self.outputs.set(outputs) + self.accumulated_gradients.set(accumulated_gradients) + self.step_count.set(step_count) + self.seed.set(main_seed) - @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"]) - @staticmethod - def reset(batch_size, shape): + # @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"]) + # @staticmethod + @compilable + def reset(self, batch_size, shape): preVals = jnp.zeros((batch_size, shape[0])) postVals = jnp.zeros((batch_size, shape[1])) inputs = preVals @@ -177,7 +205,17 @@ def reset(batch_size, shape): accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2)) step_count = jnp.zeros(()) seed = jax.random.PRNGKey(42) - return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed + + + self.inputs.set(inputs) + self.outputs.set(outputs) + self.objective.set(objective) + self.rewards.set(rewards) + self.dWeights.set(dWeights) + self.accumulated_gradients.set(accumulated_gradients) + self.step_count.set(step_count) + self.seed.set(seed) + @classmethod def help(cls): ## component help function @@ -223,11 +261,11 @@ def help(cls): ## component help function return info def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) @@ -235,3 +273,15 @@ def __repr__(self): line = "None" lines += f" {f'({c})'.ljust(maxlen)}{line}\n" return lines + + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + syn = REINFORCESynapse( + name="reinforce_syn", + shape=(3, 2) + ) + # Wab = syn.weights.get() + print(syn) + From 68f663b375187c5bd85c5f894a5e0ad7fc0a2484 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 10 Nov 2025 18:20:22 -0500 Subject: [PATCH 035/121] minor edits to exp-kernel/wtas-cell --- .../components/neurons/spiking/WTASCell.py | 2 +- ngclearn/components/other/expKernel.py | 13 +++++++++++-- ngclearn/components/synapses/denseSynapse.py | 19 +++++++++++-------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 6ae97097..16a7e4e2 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -50,7 +50,7 @@ class WTASCell(JaxComponent): ## winner-take-all spiking cell thr_jitter: scale of uniform jitter to add to initialization of thresholds """ - # Define Functions + #@deprecate_args(thr_base="thrBase") def __init__( self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002, refract_time=0., thr_jitter=0.05, **kwargs diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index 21a10c37..28cf77b8 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -74,13 +74,22 @@ def advance_state(self, t): s = inputs ## update spike time window and corresponding window volume - tf, epsp = _apply_kernel(tf, s, t, self.tau_w, self.win_len, krn_start=0, - krn_end=self.win_len-1) #0:win_len-1) + tf, epsp = _apply_kernel( + tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1 + ) #0:win_len-1) # Update compartments self.epsp.set(epsp) self.tf.set(tf) + @compilable + def reset(self): + restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp + restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf + self.inputs.set(restVals) + self.epsp.set(restVals) + self.tf.set(restTensor) + @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 76a0778f..5b0e71b9 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils import tensorstats -from ngclearn.utils.weight_distribution import initialize_params +from ngclearn.utils.distribution_generator import DistributionGenerator from ngcsimlib.logger import info from ngcsimlib.compartment import Compartment @@ -58,10 +57,13 @@ def __init__( if self.weight_init is None: info(self.name, "is using default weight initializer!") - self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} - weights = initialize_params(subkeys[0], self.weight_init, shape) + # self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} + # weights = initialize_params(subkeys[0], self.weight_init, shape) + self.weight_init = DistributionGenerator.uniform(0.025, 0.8) + #weights = initialize_params(subkeys[0], self.weight_init, shape) + weights = self.weight_init(shape, subkeys[0]) - if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed + if 0. < p_conn < 1.: ## Modifier/constraint: only non-zero and <1 probs allowed p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape) weights = weights * p_mask ## sparsify matrix @@ -76,9 +78,10 @@ def __init__( if self.bias_init is None: info(self.name, "is using default bias value of zero (no bias " "kernel provided)!") - self.biases = Compartment(initialize_params(subkeys[2], bias_init, - (1, shape[1])) - if bias_init else 0.0) + self.biases = Compartment(self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0) + # self.biases = Compartment(initialize_params(subkeys[2], bias_init, + # (1, shape[1])) + # if bias_init else 0.0) @compilable def advance_state(self): From 55b0219b06955c27807b241239bd66216c3585b7 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Tue, 11 Nov 2025 09:30:00 -0500 Subject: [PATCH 036/121] update requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e689ce87..abf22239 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.22.0 scikit-learn>=1.6.1 scipy>=1.7.0 matplotlib>=3.8.0 -patchify +# patchify # patchify has issues with pip installation jax>=0.4.28 jaxlib>=0.4.28 ngcsimlib>=1.0.0 From 56a059ff8d1ea8a1bfd5061eed07b55ad47c1ade Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 11 Nov 2025 11:36:27 -0500 Subject: [PATCH 037/121] refactored conv/deconv-hebb-syn and tests passed --- .../components/neurons/spiking/__init__.py | 3 +- ngclearn/components/synapses/__init__.py | 48 +++--- .../synapses/convolution/__init__.py | 5 +- .../convolution/hebbianConvSynapse.py | 138 +++++++++--------- .../convolution/hebbianDeconvSynapse.py | 119 ++++++++------- .../convolution/test_hebbianConvSynapse.py | 61 +++----- .../convolution/test_hebbianDeconvSynapse.py | 67 +++------ 7 files changed, 197 insertions(+), 244 deletions(-) diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py index 690087b7..1466af9a 100644 --- a/ngclearn/components/neurons/spiking/__init__.py +++ b/ngclearn/components/neurons/spiking/__init__.py @@ -1,5 +1,5 @@ ## point to standard spiking cell component types -from .sLIFCell import SLIFCell +# from .sLIFCell import SLIFCell from .LIFCell import LIFCell from .IFCell import IFCell from .WTASCell import WTASCell @@ -9,3 +9,4 @@ from .izhikevichCell import IzhikevichCell from .RAFCell import RAFCell from .hodgkinHuxleyCell import HodgkinHuxleyCell + diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index d646001e..4c060ac9 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -7,32 +7,32 @@ from .exponentialSynapse import ExponentialSynapse from .doubleExpSynapse import DoupleExpSynapse from .alphaSynapse import AlphaSynapse -# -# ## dense synaptic components + +## dense synaptic components # from .hebbian.hebbianSynapse import HebbianSynapse -# from .hebbian.traceSTDPSynapse import TraceSTDPSynapse -# from .hebbian.expSTDPSynapse import ExpSTDPSynapse -# from .hebbian.eventSTDPSynapse import EventSTDPSynapse -# from .hebbian.BCMSynapse import BCMSynapse -# -# -# ## conv/deconv synaptic components -# from .convolution.convSynapse import ConvSynapse -# from .convolution.staticConvSynapse import StaticConvSynapse -# from .convolution.hebbianConvSynapse import HebbianConvSynapse -# from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse -# from .convolution.deconvSynapse import DeconvSynapse -# from .convolution.staticDeconvSynapse import StaticDeconvSynapse -# from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse -# from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse -# -# -# ## modulated synaptic components -# from .modulated.MSTDPETSynapse import MSTDPETSynapse +from .hebbian.traceSTDPSynapse import TraceSTDPSynapse +from .hebbian.expSTDPSynapse import ExpSTDPSynapse +from .hebbian.eventSTDPSynapse import EventSTDPSynapse +from .hebbian.BCMSynapse import BCMSynapse + + +## conv/deconv synaptic components +from .convolution.convSynapse import ConvSynapse +from .convolution.staticConvSynapse import StaticConvSynapse +from .convolution.hebbianConvSynapse import HebbianConvSynapse +from .convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse +from .convolution.deconvSynapse import DeconvSynapse +from .convolution.staticDeconvSynapse import StaticDeconvSynapse +from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse +from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse + + +## modulated synaptic components +from .modulated.MSTDPETSynapse import MSTDPETSynapse # from .modulated.REINFORCESynapse import REINFORCESynapse -# -# ## patched synaptic components + +## patched synaptic components # from .patched.patchedSynapse import PatchedSynapse # from .patched.staticPatchedSynapse import StaticPatchedSynapse # from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse -# + diff --git a/ngclearn/components/synapses/convolution/__init__.py b/ngclearn/components/synapses/convolution/__init__.py index 01f3bced..0724f25f 100755 --- a/ngclearn/components/synapses/convolution/__init__.py +++ b/ngclearn/components/synapses/convolution/__init__.py @@ -2,7 +2,8 @@ from .staticConvSynapse import StaticConvSynapse from .deconvSynapse import DeconvSynapse from .staticDeconvSynapse import StaticDeconvSynapse -#from .hebbianConvSynapse import HebbianConvSynapse -# from .hebbianDeconvSynapse import HebbianDeconvSynapse +from .hebbianConvSynapse import HebbianConvSynapse +from .hebbianDeconvSynapse import HebbianDeconvSynapse from .traceSTDPConvSynapse import TraceSTDPConvSynapse from .traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse + diff --git a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py index ff45f76b..9ffdf49b 100755 --- a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py @@ -1,13 +1,11 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - -from .convSynapse import ConvSynapse +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist + +from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse + from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding, _conv_valid_transpose_padding) from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv, @@ -17,8 +15,7 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable """ - A synaptic convolutional cable that adjusts its efficacies via a two-factor - Hebbian adjustment rule. + A specialized synaptic convolutional cable that adjusts its efficacies via a two-factor Hebbian adjustment rule. | --- Synapse Compartments: --- | inputs - input (takes in external signals) @@ -88,10 +85,11 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable """ # Define Functions - def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, - stride=1, padding=None, resist_scale=1., w_bound=0., - is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd", - batch_size=1, **kwargs): + def __init__( + self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None, + resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd", + batch_size=1, **kwargs + ): super().__init__( name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale, stride=stride, padding=padding, batch_size=batch_size, **kwargs @@ -107,9 +105,9 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non ######################### set up compartments ########################## ## Compartment setup and shape computation - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.dInputs = Compartment(jnp.zeros(self.in_shape)) - self.dBiases = Compartment(self.biases.value * 0) + self.dBiases = Compartment(self.biases.get() * 0) self.pre = Compartment(jnp.zeros(self.in_shape)) self.post = Compartment(jnp.zeros(self.out_shape)) @@ -120,80 +118,76 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non self.antiPad = None k_size, k_size, n_in_chan, n_out_chan = self.shape if padding == "SAME": - self.antiPad = _conv_same_transpose_padding(self.post.value.shape[1], + self.antiPad = _conv_same_transpose_padding(self.post.get().shape[1], self.x_size, k_size, stride) elif padding == "VALID": - self.antiPad = _conv_valid_transpose_padding(self.post.value.shape[1], + self.antiPad = _conv_valid_transpose_padding(self.post.get().shape[1], self.x_size, k_size, stride) ######################################################################## ## set up outer optimization compartments self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) + [self.weights.get(), self.biases.get()] + if bias_init else [self.weights.get()]) + ) def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): k_size, k_size, n_in_chan, n_out_chan = shape _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan)) - _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0 + _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0 _dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args) ## get filter update correction - dx = _dK.shape[0] - weights.value.shape[0] - dy = _dK.shape[1] - weights.value.shape[1] + dx = _dK.shape[0] - weights.get().shape[0] + dy = _dK.shape[1] - weights.get().shape[1] self.delta_shape = (max(dx, 0), max(dy, 0)) ## get input update correction - _dx = _calc_dX_conv(weights.value, _d, stride_size=stride, - anti_padding=pad_args) + _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride, anti_padding=pad_args) dx = (_dx.shape[1] - _x.shape[1]) dy = (_dx.shape[2] - _x.shape[2]) self.x_delta_shape = (dx, dy) - @staticmethod - def _compute_update( - sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights - ): ## synaptic kernel adjustment calculation co-routine + def _compute_update(self): #sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights + ## synaptic kernel adjustment calculation co-routine ## compute adjustment to filters - dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape, stride_size=stride, padding=pad_args) - dWeights = dWeights * sign_value - if w_decay > 0.: ## apply synaptic decay - dWeights = dWeights - weights * w_decay + dWeights = calc_dK_conv( + self.pre.get(), self.post.get(), delta_shape=self.delta_shape, stride_size=self.stride, padding=self.pad_args + ) + dWeights = dWeights * self.sign_value + if self.w_decay > 0.: ## apply synaptic decay + dWeights = dWeights - self.weights.get() * self.w_decay ## compute adjustment to base-rates (if applicable) dBiases = 0. # jnp.zeros((1,1)) - if bias_init != None: - dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value + if self.bias_init != None: + dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value return dWeights, dBiases - @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"]) - @staticmethod - def evolve( - opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, stride, pad_args, delta_shape, pre, post, - weights, biases, opt_params - ): + @compilable + def evolve(self): ## calc dFilters / dBiases - update to filters and biases - dWeights, dBiases = HebbianConvSynapse._compute_update( - sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights - ) - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) + dWeights, dBiases = self._compute_update() + if self.bias_init is not None: + opt_params, [weights, biases] = self.opt(self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases]) else: ## ignore dBiases since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) - + opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights]) + biases = None ## apply any enforced filter constraints - if w_bounds > 0.: - if is_nonnegative: - weights = jnp.clip(weights, 0., w_bounds) + if self.w_bounds > 0.: + if self.is_nonnegative: + weights = jnp.clip(weights, 0., self.w_bounds) else: - weights = jnp.clip(weights, -w_bounds, w_bounds) - return opt_params, weights, biases, dWeights, dBiases - - @transition(output_compartments=["dInputs"]) - @staticmethod - def backtransmit( - sign_value, x_size, shape, stride, padding, x_delta_shape, antiPad, post, weights - ): ## action-backpropagating routine + weights = jnp.clip(weights, -self.w_bounds, self.w_bounds) + + self.opt_params.set(opt_params) + self.weights.set(weights) + self.biases.set(biases) + self.dWeights.set(dWeights) + self.dBiases.set(dBiases) + + @compilable + def backtransmit(self): ## action-backpropagating co-routine ## calc dInputs - adjustment w.r.t. input signal - k_size, k_size, n_in_chan, n_out_chan = shape + k_size, k_size, n_in_chan, n_out_chan = self.shape # antiPad = None # if padding == "SAME": # antiPad = _conv_same_transpose_padding(post.shape[1], x_size, @@ -201,22 +195,20 @@ def backtransmit( # elif padding == "VALID": # antiPad = _conv_valid_transpose_padding(post.shape[1], x_size, # k_size, stride) - dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad) + dInputs = calc_dX_conv(self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, anti_padding=self.antiPad) ## flip sign of back-transmitted signal (if applicable) - dInputs = dInputs * sign_value - return dInputs - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - pre = preVals - post = postVals - dInputs = preVals - return inputs, outputs, pre, post, dInputs + dInputs = dInputs * self.sign_value + self.dInputs.set(dInputs) + + @compilable + def reset(self): #in_shape, out_shape): + preVals = jnp.zeros(self.in_shape.get()) + postVals = jnp.zeros(self.out_shape.get()) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre.set(preVals) + self.post.set(postVals) + self.dInputs.set(preVals) @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py index f203400a..86b7ec6e 100755 --- a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py @@ -1,13 +1,11 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component from ngcsimlib.compartment import Compartment - -from .deconvSynapse import DeconvSynapse +from ngcsimlib.parser import compilable from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats import ngclearn.utils.weight_distribution as dist + +from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse + from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv, _calc_dK_deconv, calc_dX_deconv, calc_dK_deconv) @@ -15,8 +13,8 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional cable """ - A synaptic deconvolutional (transposed convolutional) cable that adjusts its - efficacies via a two-factor Hebbian adjustment rule. + A specialized synaptic deconvolutional (transposed convolutional) cable that adjusts its efficacies via a + two-factor Hebbian adjustment rule. | --- Synapse Compartments: --- | inputs - input (takes in external signals) @@ -104,11 +102,11 @@ def __init__( ## optimization / adjustment properties (given learning dynamics above) self.opt = get_opt_step_fn(optim_type, eta=self.eta) - self.dWeights = Compartment(self.weights.value * 0) + self.dWeights = Compartment(self.weights.get() * 0) self.dInputs = Compartment(jnp.zeros(self.in_shape)) self.pre = Compartment(jnp.zeros(self.in_shape)) self.post = Compartment(jnp.zeros(self.out_shape)) - self.dBiases = Compartment(self.biases.value * 0) + self.dBiases = Compartment(self.biases.get() * 0) ######################################################################## ## Shape error correction -- do shape correction inference (for local updates) @@ -118,84 +116,85 @@ def __init__( ## set up outer optimization compartments self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.value, self.biases.value] - if bias_init else [self.weights.value])) + [self.weights.get(), self.biases.get()] + if bias_init else [self.weights.get()]) + ) def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): k_size, k_size, n_in_chan, n_out_chan = shape _x = jnp.zeros((batch_size, x_size, x_size, n_in_chan)) - _d = deconv2d(_x, self.weights.value, stride_size=self.stride, + _d = deconv2d(_x, self.weights.get(), stride_size=self.stride, padding=self.padding) * 0 _dK = _calc_dK_deconv(_x, _d, stride_size=self.stride, out_size=k_size) ## get filter update correction - dx = _dK.shape[0] - self.weights.value.shape[0] - dy = _dK.shape[1] - self.weights.value.shape[1] + dx = _dK.shape[0] - self.weights.get().shape[0] + dy = _dK.shape[1] - self.weights.get().shape[1] self.delta_shape = (abs(dx), abs(dy)) ## get input update correction - _dx = _calc_dX_deconv(self.weights.value, _d, stride_size=self.stride, + _dx = _calc_dX_deconv(self.weights.get(), _d, stride_size=self.stride, padding=self.padding) dx = (_dx.shape[1] - _x.shape[1]) # abs() dy = (_dx.shape[2] - _x.shape[2]) self.x_delta_shape = (dx, dy) - @staticmethod - def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights): - k_size, k_size, n_in_chan, n_out_chan = shape + def _compute_update(self): + k_size, k_size, n_in_chan, n_out_chan = self.shape ## compute adjustment to filters dWeights = calc_dK_deconv( - pre, post, delta_shape=delta_shape, stride_size=stride, out_size=k_size, padding=padding + self.pre.get(), self.post.get(), delta_shape=self.delta_shape, stride_size=self.stride, out_size=k_size, + padding=self.padding ) - dWeights = dWeights * sign_value - if w_decay > 0.: ## apply synaptic decay - dWeights = dWeights - weights * w_decay + dWeights = dWeights * self.sign_value + if self.w_decay > 0.: ## apply synaptic decay + dWeights = dWeights - self.weights.get() * self.w_decay ## compute adjustment to base-rates (if applicable) dBiases = 0. # jnp.zeros((1,1)) - if bias_init != None: - dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value + if self.bias_init != None: + dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value return dWeights, dBiases - @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"]) - @staticmethod - def evolve( - opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, shape, stride, padding, delta_shape, - pre, post, weights, biases, opt_params - ): - dWeights, dBiases = HebbianDeconvSynapse._compute_update( - sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights - ) - if bias_init != None: - opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases]) + @compilable + def evolve(self): + dWeights, dBiases = self._compute_update() + if self.bias_init != None: + opt_params, [weights, biases] = self.opt(self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases]) else: ## ignore dBiases since no biases configured - opt_params, [weights] = opt(opt_params, [weights], [dWeights]) + opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights]) + biases = None ## apply any enforced filter constraints - if w_bounds > 0.: - if is_nonnegative: - weights = jnp.clip(weights, 0., w_bounds) + if self.w_bounds > 0.: + if self.is_nonnegative: + weights = jnp.clip(weights, 0., self.w_bounds) else: - weights = jnp.clip(weights, -w_bounds, w_bounds) - return opt_params, weights, biases, dWeights, dBiases + weights = jnp.clip(weights, -self.w_bounds, self.w_bounds) - @transition(output_compartments=["dInputs"]) - @staticmethod - def backtransmit(sign_value, stride, padding, x_delta_shape, pre, post, weights): ## action-backpropagating routine + self.opt_params.set(opt_params) + self.weights.set(weights) + self.biases.set(biases) + self.dWeights.set(dWeights) + self.dBiases.set(dBiases) + + @compilable + def backtransmit(self): ## action-backpropagating co-routine ## calc dInputs - dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape, stride_size=stride, padding=padding) + dInputs = calc_dX_deconv( + self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, + padding=self.padding + ) ## flip sign of back-transmitted signal (if applicable) - dInputs = dInputs * sign_value - return dInputs - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"]) - @staticmethod - def reset(in_shape, out_shape): - preVals = jnp.zeros(in_shape) - postVals = jnp.zeros(out_shape) - inputs = preVals - outputs = postVals - pre = preVals - post = postVals - dInputs = preVals - return inputs, outputs, pre, post, dInputs + dInputs = dInputs * self.sign_value + self.dInputs.set(dInputs) + + @compilable + def reset(self): #in_shape, out_shape): + preVals = jnp.zeros(self.in_shape.get()) + postVals = jnp.zeros(self.out_shape.get()) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre.set(preVals) + self.post.set(postVals) + self.dInputs.set(preVals) @classmethod def help(cls): ## component help function diff --git a/tests/components/synapses/convolution/test_hebbianConvSynapse.py b/tests/components/synapses/convolution/test_hebbianConvSynapse.py index db6cd662..02d9fcad 100644 --- a/tests/components/synapses/convolution/test_hebbianConvSynapse.py +++ b/tests/components/synapses/convolution/test_hebbianConvSynapse.py @@ -1,17 +1,12 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import HebbianConvSynapse + +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist -from ngcsimlib.compilers import compile_command, wrap_command +from ngclearn.components.synapses.convolution.hebbianConvSynapse import HebbianConvSynapse from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context - def test_HebbianConvSynapse1(): name = "hebb_conv_ctx" ## create seeding keys @@ -36,41 +31,24 @@ def test_HebbianConvSynapse1(): stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") + evolve_process = (MethodProcess("evolve_process") >> a.evolve) - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") - backtransmit_process = (Process("btransmit_proc") - >> a.backtransmit) - ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit") + backtransmit_process = (MethodProcess("backtransmit_process") + >> a.backtransmit) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit") - ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit") - """ x = jnp.ones(x_shape) - ctx.reset() + reset_process.run() # ctx.reset() a.inputs.set(x) - ctx.run(t=1., dt=dt) - y = a.outputs.value + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + y = a.outputs.get() y_truth = jnp.array( [[[[4.],[2.]], @@ -79,17 +57,16 @@ def test_HebbianConvSynapse1(): assert_array_equal(y, y_truth) # print(y) + # print("y.Tr:\n", y_truth) # print("======") - # print("NGC-Learn.shape = ", node.outputs.value.shape) + # print("NGC-Learn.shape = ", node.outputs.get().shape) a.pre.set(x) a.post.set(y) - ctx.adapt(t=1., dt=dt) - dK = a.dWeights.value - #print(dK) - ctx.backtransmit(t=1., dt=dt) - dx = a.dInputs.value - #print(dx) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + dK = a.dWeights.get() + backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt) + dx = a.dInputs.get() dK_truth = jnp.array( [[[[9.]], [[6.]]], @@ -102,6 +79,10 @@ def test_HebbianConvSynapse1(): [[6.], [9.]]]] ) + # print(dK) + # print("dK.Tr:\n", dK_truth) + # print(dx) + # print("dx.Tr:\n", dx_truth) assert_array_equal(dK, dK_truth) assert_array_equal(dx, dx_truth) diff --git a/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py b/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py index a91e69d4..5adf7884 100644 --- a/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py +++ b/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py @@ -1,17 +1,12 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) -from ngclearn.components import HebbianDeconvSynapse + +from ngclearn import Context, MethodProcess import ngclearn.utils.weight_distribution as dist -from ngcsimlib.compilers import compile_command, wrap_command +from ngclearn.components.synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context - def test_HebbianDeconvSynapse1(): name = "hebb_deconv_ctx" ## create seeding keys @@ -36,43 +31,24 @@ def test_HebbianDeconvSynapse1(): stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0] ) - #""" - evolve_process = (Process("evolve_proc") - >> a.evolve) - #ctx.wrap_and_add_command(evolve_process.pure, name="run") - ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt") + evolve_process = (MethodProcess("evolve_process") + >> a.evolve) - backtransmit_process = (Process("btransmit_proc") + backtransmit_process = (MethodProcess("backtransmit_process") >> a.backtransmit) - ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit") - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt") - backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit") - ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit") - """ x = jnp.ones(x_shape) - ctx.reset() + reset_process.run() # ctx.reset() a.inputs.set(x) - ctx.run(t=1., dt=dt) - y = a.outputs.value + advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt) + y = a.outputs.get() y_truth = jnp.array( [[[[1.],[2.]], @@ -80,18 +56,17 @@ def test_HebbianDeconvSynapse1(): ) assert_array_equal(y, y_truth) - #print(y) - #print("======") + # print(y) + # print("y.Tr:\n", y_truth) + # print("======") - # print("NGC-Learn.shape = ", node.outputs.value.shape) + # print("NGC-Learn.shape = ", node.outputs.get().shape) a.pre.set(x) a.post.set(y) - ctx.adapt(t=1., dt=dt) - dK = a.dWeights.value - #print(dK) - ctx.backtransmit(t=1., dt=dt) - dx = a.dInputs.value - #print(dx) + evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt) + dK = a.dWeights.get() + backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt) + dx = a.dInputs.get() dK_truth = jnp.array( [[[[4.]], [[6.]]], @@ -104,6 +79,10 @@ def test_HebbianDeconvSynapse1(): [[6.], [4.]]]] ) + # print(dK) + # print("dK.Tr:\n", dK_truth) + # print(dx) + # print("dx.Tr:\n", dx_truth) assert_array_equal(dK, dK_truth) assert_array_equal(dx, dx_truth) From 01454f051dc8da37fbceeae624069acc78e7b042 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Tue, 11 Nov 2025 13:51:33 -0500 Subject: [PATCH 038/121] update hebbian synapse reset bug --- .../components/synapses/hebbian/hebbianSynapse.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 0007366f..2d333fa8 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -253,6 +253,17 @@ def evolve(self): self.dWeights.set(dWeights) self.dBiases.set(dBiases) + @compilable + def reset(self, batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + self.inputs.set(preVals) # inputs + self.outputs.set(postVals) # outputs + self.pre.set(preVals) # pre + self.post.set(postVals) # post + self.dWeights.set(jnp.zeros(shape)) # dW + self.dBiases.set(jnp.zeros(shape[1])) # db + @classmethod def help(cls): ## component help function properties = { From a685fcd3d8bbc32f800638b841016c89c811bb4c Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Tue, 11 Nov 2025 14:03:22 -0500 Subject: [PATCH 039/121] update reset methods --- .../neurons/graded/bernoulliErrorCell.py | 2 +- .../neurons/graded/gaussianErrorCell.py | 26 +++++++++++++++++++ .../neurons/graded/laplacianErrorCell.py | 20 ++++++++++++++ .../components/neurons/graded/rateCell.py | 12 +++++++++ .../neurons/graded/rewardErrorCell.py | 12 +++++++++ ngclearn/components/other/expKernel.py | 4 +-- 6 files changed, 73 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/neurons/graded/bernoulliErrorCell.py b/ngclearn/components/neurons/graded/bernoulliErrorCell.py index 376fa41f..9b63e8e7 100755 --- a/ngclearn/components/neurons/graded/bernoulliErrorCell.py +++ b/ngclearn/components/neurons/graded/bernoulliErrorCell.py @@ -12,7 +12,7 @@ class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution - of a mismatch signal. Specifically, this cell operates as a factorized multivariate + of a mismatch signal. Specifically, this cell operates as a factorized multivariate Bernoulli distribution. | --- Cell Input Compartments: --- diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 63e10a65..0d78bddf 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -108,6 +108,32 @@ def advance_state(self, dt): ## compute Gaussian error cell output self.L.set(jnp.squeeze(L)) self.mask.set(mask) + # @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"]) + # @staticmethod + @compilable + def reset(self, batch_size, shape, sigma_shape): ## reset core components/statistics + _shape = (batch_size, shape[0]) + if len(shape) > 1: + _shape = (batch_size, shape[0], shape[1], shape[2]) + restVals = jnp.zeros(_shape) + dmu = restVals + dtarget = restVals + dSigma = jnp.zeros(sigma_shape) + target = restVals + mu = restVals + modulator = mu + 1. + L = 0. #jnp.zeros((1, 1)) + mask = jnp.ones(_shape) + + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.target.set(target) + self.mu.set(mu) + self.modulator.set(modulator) + self.L.set(L) + self.mask.set(mask) + @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py index e3717d1c..04d3b0b8 100755 --- a/ngclearn/components/neurons/graded/laplacianErrorCell.py +++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py @@ -103,6 +103,26 @@ def advance_state(self, dt): ## compute Laplacian error cell output self.L.set(jnp.squeeze(L)) self.mask.set(mask) + def reset(self, batch_size, n_units, scale_shape): + restVals = jnp.zeros((batch_size, n_units)) + dshift = restVals + dtarget = restVals + dScale = jnp.zeros(scale_shape) + target = restVals + shift = restVals + modulator = shift + 1. + L = 0. + mask = jnp.ones((batch_size, n_units)) + + self.dshift.set(dshift) + self.dtarget.set(dtarget) + self.dScale.set(dScale) + self.target.set(target) + self.shift.set(shift) + self.modulator.set(modulator) + self.L.set(L) + self.mask.set(mask) + @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index a76346b4..31a52ad4 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -251,6 +251,18 @@ def advance_state(self, dt): self.z.set(z) self.zF.set(zF) + @compilable + def reset(self, batch_size, shape): #n_units + _shape = (batch_size, shape[0]) + if len(shape) > 1: + _shape = (batch_size, shape[0], shape[1], shape[2]) + restVals = jnp.zeros(_shape) + self.j.set(restVals) + self.j_td.set(restVals) + self.z.set(restVals) + self.zF.set(restVals) + + def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays tau_m = (self.tau_m if isinstance(self.tau_m, float) diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index a9d43fac..f31ee332 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -91,6 +91,18 @@ def evolve(self, dt): # Update compartment self.mu.set(mu) + @compilable + def reset(self, batch_size, n_units): + restVals = jnp.zeros((batch_size, n_units)) + mu = restVals + rpe = restVals + accum_reward = restVals + n_ep_steps = jnp.zeros((batch_size, 1)) + self.mu.set(mu) + self.rpe.set(rpe) + self.accum_reward.set(accum_reward) + self.n_ep_steps.set(n_ep_steps) + @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index 28cf77b8..9901dade 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -71,13 +71,13 @@ def advance_state(self, t): # Get the variables inputs = self.inputs.get() tf = self.tf.get() - + s = inputs ## update spike time window and corresponding window volume tf, epsp = _apply_kernel( tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1 ) #0:win_len-1) - + # Update compartments self.epsp.set(epsp) self.tf.set(tf) From de100283c34ed8615f0bd4c844102954676e98d6 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Tue, 11 Nov 2025 14:15:04 -0500 Subject: [PATCH 040/121] update patched synapse reset --- .../synapses/patched/hebbianPatchedSynapse.py | 11 +++++++++++ .../components/synapses/patched/patchedSynapse.py | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index 364ad3d8..ac5d5b95 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -276,6 +276,17 @@ def evolve(self): self.dWeights.set(dWeights) self.dBiases.set(dBiases) + @compilable + def reset(self, batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + self.inputs.set(preVals) # inputs + self.outputs.set(postVals) # outputs + self.pre.set(preVals) # pre + self.post.set(postVals) # post + self.dWeights.set(jnp.zeros(shape)) # dW + self.dBiases.set(jnp.zeros(shape[1])) # db + @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index bbf23d78..d5872dd9 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -153,6 +153,15 @@ def advance_state(self): # Update compartment self.outputs.set(outputs) + @compilable + def reset(self, batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + self.inputs.set(inputs) + self.outputs.set(outputs) + def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" if self.bias_init != None: From 80f24171bf160ee9b86fb6330225fa956b65addd Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Tue, 11 Nov 2025 15:16:41 -0500 Subject: [PATCH 041/121] add `not self.inputs.targeted and ` to required components. Fixing general `__repr__` bug in `jaxcomponent` --- ngclearn/components/input_encoders/bernoulliCell.py | 2 +- ngclearn/components/input_encoders/latencyCell.py | 2 +- ngclearn/components/input_encoders/phasorCell.py | 2 +- ngclearn/components/jaxComponent.py | 2 +- ngclearn/components/other/expKernel.py | 2 +- ngclearn/components/other/varTrace.py | 4 +++- ngclearn/components/synapses/hebbian/hebbianSynapse.py | 2 +- ngclearn/components/synapses/modulated/REINFORCESynapse.py | 2 +- ngclearn/components/synapses/patched/hebbianPatchedSynapse.py | 2 +- ngclearn/components/synapses/patched/patchedSynapse.py | 2 +- 10 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 6d1ca713..8661dd9d 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -48,7 +48,7 @@ def advance_state(self, t): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - self.inputs.set(restVals) + not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index 374bea78..d0087b07 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -211,7 +211,7 @@ def advance_state(self, t): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - self.inputs.set(restVals) + not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) self.mask.set(restVals) diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index a9ca1425..9a2175fa 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -88,7 +88,7 @@ def advance_state(self, t, dt): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - self.inputs.set(restVals) + not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) self.angles.set(restVals) diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 56247900..afa680bc 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -63,7 +63,7 @@ def __repr__(self): maxlen = max(len(c) for c in comps) + 5 lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" for c in comps: - stats = tensorstats(getattr(self, c).value) + stats = tensorstats(getattr(self, c).get()) if stats is not None: line = [f"{k}: {v}" for k, v in stats.items()] line = ", ".join(line) diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index 9901dade..1295fd62 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -86,7 +86,7 @@ def advance_state(self, t): def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf - self.inputs.set(restVals) + not self.inputs.targeted and self.inputs.set(restVals) self.epsp.set(restVals) self.tf.set(restTensor) diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index 8e83bc2d..1e624ba9 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -1,3 +1,5 @@ +# %% + from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from functools import partial @@ -124,7 +126,7 @@ def advance_state(self, dt): @compilable def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) - self.inputs.set(restVals) + not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.trace.set(restVals) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 2d333fa8..287b35d6 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -257,7 +257,7 @@ def evolve(self): def reset(self, batch_size, shape): preVals = jnp.zeros((batch_size, shape[0])) postVals = jnp.zeros((batch_size, shape[1])) - self.inputs.set(preVals) # inputs + not self.inputs.targeted and self.inputs.set(preVals) # inputs self.outputs.set(postVals) # outputs self.pre.set(preVals) # pre self.post.set(postVals) # post diff --git a/ngclearn/components/synapses/modulated/REINFORCESynapse.py b/ngclearn/components/synapses/modulated/REINFORCESynapse.py index 378becf0..f64b0dc8 100644 --- a/ngclearn/components/synapses/modulated/REINFORCESynapse.py +++ b/ngclearn/components/synapses/modulated/REINFORCESynapse.py @@ -207,7 +207,7 @@ def reset(self, batch_size, shape): seed = jax.random.PRNGKey(42) - self.inputs.set(inputs) + not self.inputs.targeted and self.inputs.set(inputs) self.outputs.set(outputs) self.objective.set(objective) self.rewards.set(rewards) diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index ac5d5b95..86f55f87 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -280,7 +280,7 @@ def evolve(self): def reset(self, batch_size, shape): preVals = jnp.zeros((batch_size, shape[0])) postVals = jnp.zeros((batch_size, shape[1])) - self.inputs.set(preVals) # inputs + not self.inputs.targeted and self.inputs.set(preVals) # inputs self.outputs.set(postVals) # outputs self.pre.set(preVals) # pre self.post.set(postVals) # post diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index d5872dd9..3960aee2 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -159,7 +159,7 @@ def reset(self, batch_size, shape): postVals = jnp.zeros((batch_size, shape[1])) inputs = preVals outputs = postVals - self.inputs.set(inputs) + not self.inputs.targeted and self.inputs.set(inputs) self.outputs.set(outputs) def save(self, directory, **kwargs): From e163c37e5f536300b350e13aa9755c9ca79852e7 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 12 Nov 2025 14:38:11 -0500 Subject: [PATCH 042/121] minor edit to lif/modulated-syn init file --- ngclearn/components/neurons/spiking/LIFCell.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index c0d049cf..850f24a8 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -224,7 +224,6 @@ def reset(self): self.rfr.set(restVals + self.refract_T) self.tols.set(restVals) - @classmethod def help(cls): ## component help function properties = { From 03371ec8d803545eee92494b2556d21c9c4e04c4 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 12 Nov 2025 17:26:49 -0500 Subject: [PATCH 043/121] fixed some minor bugs in rate-coded cells/hebb-syn --- docs/index.rst | 4 +- .../neurons/graded/bernoulliErrorCell.py | 20 +------ .../neurons/graded/gaussianErrorCell.py | 24 ++------ .../neurons/graded/laplacianErrorCell.py | 22 ++------ .../components/neurons/graded/rateCell.py | 55 +++++++++---------- .../neurons/graded/rewardErrorCell.py | 6 +- .../synapses/hebbian/hebbianSynapse.py | 13 +++-- 7 files changed, 51 insertions(+), 93 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 969753d7..bace89f2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,8 +6,8 @@ Welcome to ngc-learn's documentation! ===================================== **ngc-learn** is a Python library for building, simulating, and analyzing -biomimetic computational models, arbitrary predictive processing/coding models, -and spiking neural networks. This toolkit is built on top of +biomimetic and NeuroAI computational models, arbitrary predictive processing/coding models, +spiking neural networks, and general dynamical systems. This toolkit is built on top of `JAX `_ and is distributed under the 3-Clause BSD license. .. toctree:: diff --git a/ngclearn/components/neurons/graded/bernoulliErrorCell.py b/ngclearn/components/neurons/graded/bernoulliErrorCell.py index 9b63e8e7..f6666015 100755 --- a/ngclearn/components/neurons/graded/bernoulliErrorCell.py +++ b/ngclearn/components/neurons/graded/bernoulliErrorCell.py @@ -110,10 +110,10 @@ def advance_state(self, dt): ## compute Bernoulli error cell output # @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"]) @compilable - def reset(self, batch_size): ## reset core components/statistics - _shape = (batch_size, self.shape[0]) + def reset(self): ## reset core components/statistics + _shape = (self.batch_size, self.shape[0]) if len(self.shape) > 1: - _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) ## "rest"/reset values dp = restVals dtarget = restVals @@ -161,20 +161,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 0d78bddf..d757800c 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -111,14 +111,14 @@ def advance_state(self, dt): ## compute Gaussian error cell output # @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"]) # @staticmethod @compilable - def reset(self, batch_size, shape, sigma_shape): ## reset core components/statistics - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) + def reset(self): ## reset core components/statistics + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) dmu = restVals dtarget = restVals - dSigma = jnp.zeros(sigma_shape) + dSigma = jnp.zeros(self.sigma_shape) target = restVals mu = restVals modulator = mu + 1. @@ -164,20 +164,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py index 04d3b0b8..251b5061 100755 --- a/ngclearn/components/neurons/graded/laplacianErrorCell.py +++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py @@ -103,16 +103,16 @@ def advance_state(self, dt): ## compute Laplacian error cell output self.L.set(jnp.squeeze(L)) self.mask.set(mask) - def reset(self, batch_size, n_units, scale_shape): - restVals = jnp.zeros((batch_size, n_units)) + def reset(self): ## reset core components/statistics + restVals = jnp.zeros((self.batch_size, self.n_units)) dshift = restVals dtarget = restVals - dScale = jnp.zeros(scale_shape) + dScale = jnp.zeros(self.scale_shape) target = restVals shift = restVals modulator = shift + 1. L = 0. - mask = jnp.ones((batch_size, n_units)) + mask = jnp.ones((self.batch_size, self.n_units)) self.dshift.set(dshift) self.dtarget.set(dtarget) @@ -152,20 +152,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 31a52ad4..cb88c430 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -252,40 +252,39 @@ def advance_state(self, dt): self.zF.set(zF) @compilable - def reset(self, batch_size, shape): #n_units - _shape = (batch_size, shape[0]) - if len(shape) > 1: - _shape = (batch_size, shape[0], shape[1], shape[2]) + def reset(self): #, batch_size, shape): #n_units + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) self.j.set(restVals) self.j_td.set(restVals) self.z.set(restVals) self.zF.set(restVals) - - def save(self, directory, **kwargs): - ## do a protected save of constants, depending on whether they are floats or arrays - tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) - priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - else jnp.ones([[self.priorLeakRate]])) - resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - else jnp.ones([[self.resist_scale]])) - - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - tau_m=tau_m, priorLeakRate=priorLeakRate, - resist_scale=resist_scale) #, key=self.key.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - ## constants loaded in - self.tau_m = data['tau_m'] - self.priorLeakRate = data['priorLeakRate'] - self.resist_scale = data['resist_scale'] - #if seeded: - # self.key.set(data['key']) + # def save(self, directory, **kwargs): + # ## do a protected save of constants, depending on whether they are floats or arrays + # tau_m = (self.tau_m if isinstance(self.tau_m, float) + # else jnp.ones([[self.tau_m]])) + # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) + # else jnp.ones([[self.priorLeakRate]])) + # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) + # else jnp.ones([[self.resist_scale]])) + # + # file_name = directory + "/" + self.name + ".npz" + # jnp.savez(file_name, + # tau_m=tau_m, priorLeakRate=priorLeakRate, + # resist_scale=resist_scale) #, key=self.key.value) + # + # def load(self, directory, seeded=False, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # ## constants loaded in + # self.tau_m = data['tau_m'] + # self.priorLeakRate = data['priorLeakRate'] + # self.resist_scale = data['resist_scale'] + # #if seeded: + # # self.key.set(data['key']) @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index f31ee332..f2bc7b1a 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -92,12 +92,12 @@ def evolve(self, dt): self.mu.set(mu) @compilable - def reset(self, batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) + def reset(self): ## reset core components/statistics + restVals = jnp.zeros((self.batch_size, self.n_units)) mu = restVals rpe = restVals accum_reward = restVals - n_ep_steps = jnp.zeros((batch_size, 1)) + n_ep_steps = jnp.zeros((self.batch_size, 1)) self.mu.set(mu) self.rpe.set(rpe) self.accum_reward.set(accum_reward) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 287b35d6..e94fb279 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -254,15 +254,16 @@ def evolve(self): self.dBiases.set(dBiases) @compilable - def reset(self, batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - not self.inputs.targeted and self.inputs.set(preVals) # inputs + def reset(self): #, batch_size, shape): + preVals = jnp.zeros((self.batch_size, self.shape[0])) + postVals = jnp.zeros((self.batch_size, self.shape[1])) + #not self.inputs.targeted and self.inputs.set(preVals) # inputs + self.inputs.set(preVals) self.outputs.set(postVals) # outputs self.pre.set(preVals) # pre self.post.set(postVals) # post - self.dWeights.set(jnp.zeros(shape)) # dW - self.dBiases.set(jnp.zeros(shape[1])) # db + self.dWeights.set(jnp.zeros(self.shape)) # dW + self.dBiases.set(jnp.zeros(self.shape[1])) # db @classmethod def help(cls): ## component help function From c01a619716c2269a8b55230c426c2913ab8538c0 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Wed, 12 Nov 2025 17:34:23 -0500 Subject: [PATCH 044/121] update code --- .../neurons/graded/test_RateCell.py | 24 ++------- .../synapses/hebbian/test_hebbianSynapse.py | 53 +++++++------------ 2 files changed, 23 insertions(+), 54 deletions(-) diff --git a/tests/components/neurons/graded/test_RateCell.py b/tests/components/neurons/graded/test_RateCell.py index bbd91d2b..95260c95 100644 --- a/tests/components/neurons/graded/test_RateCell.py +++ b/tests/components/neurons/graded/test_RateCell.py @@ -1,18 +1,12 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import RateCell -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import Context, MethodProcess def test_RateCell1(): @@ -26,17 +20,9 @@ def test_RateCell1(): threshold=("none", 0.), integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - - @Context.dynamicCommand def clamp(x): a.j.set(x) @@ -46,11 +32,11 @@ def clamp(x): y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) + advance_process.run(t=ts * 1., dt=dt) outs.append(a.z.value) outs = jnp.concatenate(outs, axis=1) # print(outs) diff --git a/tests/components/synapses/hebbian/test_hebbianSynapse.py b/tests/components/synapses/hebbian/test_hebbianSynapse.py index 35a2b191..1b39ff5a 100644 --- a/tests/components/synapses/hebbian/test_hebbianSynapse.py +++ b/tests/components/synapses/hebbian/test_hebbianSynapse.py @@ -1,18 +1,13 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context + import numpy as np np.random.seed(42) from ngclearn.components import HebbianSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from numpy.testing import assert_array_equal +from ngclearn import Context, MethodProcess def test_hebbianSynapse(): @@ -29,37 +24,23 @@ def test_hebbianSynapse(): with Context(name) as ctx: a = HebbianSynapse( - name="a", - shape=shape, + name="a", + shape=shape, resist_scale=resist_scale, batch_size=batch_size, prior = ("gaussian", 0.01) ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - evolve_process = (Process("evolve_proc") >> a.evolve) - ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve") - - # Compile and add commands - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(advance_cmd)), name="run") - # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve") - - @Context.dynamicCommand + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) + evolve_process = (MethodProcess("evolve_proc") >> a.evolve) + def clamp_inputs(x): a.inputs.set(x) - @Context.dynamicCommand def clamp_pre(x): a.pre.set(x) - @Context.dynamicCommand def clamp_post(x): a.post.set(x) @@ -70,16 +51,18 @@ def clamp_post(x): in_pre = jnp.ones((1, 10)) * 1.0 in_post = jnp.ones((1, 5)) * 0.75 - ctx.reset() + reset_process.run() clamp_pre(in_pre) clamp_post(in_post) - ctx.run(t=1. * dt, dt=dt) - ctx.evolve(t=1. * dt, dt=dt) + advance_process.run(t=1. * dt, dt=dt) + evolve_process.run(t=1. * dt, dt=dt) - print(a.weights.value) + print(a.weights.get()) # Basic assertions to check learning dynamics - assert a.weights.value.shape == (10, 5), "" - assert a.weights.value[0, 0] == 0.5, "" + assert a.weights.get().shape == (10, 5), "" + assert a.weights.get()[0, 0] == 0.5, "" + +test_hebbianSynapse() + -# test_hebbianSynapse() \ No newline at end of file From 36c56a66010359bf59e74751d035ace591eacaf2 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 12 Nov 2025 18:08:31 -0500 Subject: [PATCH 045/121] minor patches to components, including hebb-syn/conv/deconv and reward-cell --- ngclearn/components/neurons/__init__.py | 2 +- .../components/neurons/graded/rateCell.py | 14 -------- .../neurons/graded/rewardErrorCell.py | 14 -------- .../convolution/hebbianConvSynapse.py | 5 ++- .../convolution/hebbianDeconvSynapse.py | 5 ++- .../components/synapses/hebbian/__init__.py | 2 +- .../synapses/hebbian/hebbianSynapse.py | 32 ++++++------------- 7 files changed, 15 insertions(+), 59 deletions(-) diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index e7165d7e..1d8bb919 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -5,7 +5,7 @@ from .graded.bernoulliErrorCell import BernoulliErrorCell from .graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types -from .spiking.sLIFCell import SLIFCell +#from .spiking.sLIFCell import SLIFCell from .spiking.IFCell import IFCell from .spiking.LIFCell import LIFCell from .spiking.WTASCell import WTASCell diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index cb88c430..bff095d2 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -317,20 +317,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index f2bc7b1a..479b5c74 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -134,20 +134,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py index 9ffdf49b..16db48e5 100755 --- a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py @@ -127,9 +127,8 @@ def __init__( ######################################################################## ## set up outer optimization compartments - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.get(), self.biases.get()] - if bias_init else [self.weights.get()]) + self.opt_params = Compartment( + get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]) ) def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): diff --git a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py index 86b7ec6e..64bb5313 100755 --- a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py @@ -115,9 +115,8 @@ def __init__( ######################################################################## ## set up outer optimization compartments - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.get(), self.biases.get()] - if bias_init else [self.weights.get()]) + self.opt_params = Compartment( + get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]) ) def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights): diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index 99ebec99..f39d556f 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -1,4 +1,4 @@ -#from .hebbianSynapse import HebbianSynapse +from .hebbianSynapse import HebbianSynapse from .traceSTDPSynapse import TraceSTDPSynapse from .expSTDPSynapse import ExpSTDPSynapse from .eventSTDPSynapse import EventSTDPSynapse diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index e94fb279..a33616b7 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -172,8 +172,7 @@ def __init__( prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., resist_scale=1., batch_size=1, **kwargs ): - super().__init__(name, shape, weight_init, bias_init, resist_scale, - p_conn, batch_size=batch_size, **kwargs) + super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs) if w_decay > 0.: prior = ('l2', w_decay) @@ -209,13 +208,14 @@ def __init__( self.dBiases = Compartment(jnp.zeros(shape[1])) #key, subkey = random.split(self.key.value) - self.opt_params = Compartment(get_opt_init_fn(optim_type)( - [self.weights.get(), self.biases.get()] - if bias_init else [self.weights.get()])) + self.opt_params = Compartment( + get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]) + ) @staticmethod - def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, - post_wght, pre, post, weights): + def _compute_update( + w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights + ): ## calculate synaptic update values dW, db = _calc_update( pre, post, weights, w_bound, is_nonnegative=is_nonnegative, @@ -257,8 +257,8 @@ def evolve(self): def reset(self): #, batch_size, shape): preVals = jnp.zeros((self.batch_size, self.shape[0])) postVals = jnp.zeros((self.batch_size, self.shape[1])) - #not self.inputs.targeted and self.inputs.set(preVals) # inputs - self.inputs.set(preVals) + if not self.inputs.targeted: + self.inputs.set(preVals) self.outputs.set(postVals) # outputs self.pre.set(preVals) # pre self.post.set(postVals) # post @@ -310,20 +310,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: From eeba01242282aecc6cb644c95bf5ee17636d2740 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 12 Nov 2025 18:10:57 -0500 Subject: [PATCH 046/121] minor patches to components, including hebb-syn/conv/deconv and reward-cell --- ngclearn/components/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 96f8a2cf..3f8eda3f 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -8,7 +8,7 @@ from .neurons.graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types -from .neurons.spiking.sLIFCell import SLIFCell +#from .neurons.spiking.sLIFCell import SLIFCell from .neurons.spiking.IFCell import IFCell from .neurons.spiking.LIFCell import LIFCell from .neurons.spiking.WTASCell import WTASCell @@ -53,13 +53,13 @@ from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse ## point to modulated component types from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse -from .synapses.modulated.REINFORCESynapse import REINFORCESynapse +#from .synapses.modulated.REINFORCESynapse import REINFORCESynapse ## point to monitors from .monitor import Monitor ## point to patched component types -from .synapses.patched.patchedSynapse import PatchedSynapse -from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse -from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse +# from .synapses.patched.patchedSynapse import PatchedSynapse +# from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse +# from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse From b98fd1ae5f075e6296ad89bc79d6ad2e209997ca Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 14 Nov 2025 10:19:25 -0500 Subject: [PATCH 047/121] update testing for graded neurons and input encoders --- .../input_encoders/bernoulliCell.py | 4 +- .../components/input_encoders/latencyCell.py | 4 +- .../components/input_encoders/phasorCell.py | 4 +- .../components/input_encoders/poissonCell.py | 5 ++- .../input_encoders/test_bernoulliCell.py | 21 ++++------ .../input_encoders/test_latencyCell.py | 38 +++++++------------ .../input_encoders/test_phasorCell.py | 24 ++++-------- .../input_encoders/test_poissonCell.py | 23 ++++------- .../neurons/graded/test_bernoulliErrorCell.py | 32 ++++------------ .../neurons/graded/test_gaussianErrorCell.py | 30 ++++----------- .../neurons/graded/test_laplacianErrorCell.py | 35 ++++------------- .../neurons/graded/test_rewardErrorCell.py | 34 ++++------------- 12 files changed, 78 insertions(+), 176 deletions(-) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 8661dd9d..1a5c6dca 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -48,7 +48,9 @@ def advance_state(self, t): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - not self.inputs.targeted and self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index d0087b07..c21c5e08 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -211,7 +211,9 @@ def advance_state(self, t): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - not self.inputs.targeted and self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) self.mask.set(restVals) diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 9a2175fa..4db8e894 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -88,7 +88,9 @@ def advance_state(self, t, dt): @compilable def reset(self): restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) - not self.inputs.targeted and self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) self.angles.set(restVals) diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 5eeb057b..47869f1a 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -60,8 +60,9 @@ def advance_state(self, t, dt): @compilable def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) - if not self.inputs.targeted: - self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) diff --git a/tests/components/input_encoders/test_bernoulliCell.py b/tests/components/input_encoders/test_bernoulliCell.py index f73951b7..a3ba5a9f 100644 --- a/tests/components/input_encoders/test_bernoulliCell.py +++ b/tests/components/input_encoders/test_bernoulliCell.py @@ -1,15 +1,13 @@ +# %% + from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import BernoulliCell #from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngclearn.utils import JaxProcess -from ngcsimlib.context import Context -#from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_bernoulliCell1(): @@ -23,16 +21,13 @@ def test_bernoulliCell1(): with Context(name) as ctx: a = BernoulliCell(name="a", n_units=1, key=subkeys[0]) - advance_process = (JaxProcess("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) @@ -40,11 +35,11 @@ def clamp(x): x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts*1., dt=dt) + clamp(x_t) + advance_process.run(t=ts*1., dt=dt) outs.append(a.outputs.value) outs = jnp.concatenate(outs, axis=1) diff --git a/tests/components/input_encoders/test_latencyCell.py b/tests/components/input_encoders/test_latencyCell.py index 19843e54..ad45145e 100644 --- a/tests/components/input_encoders/test_latencyCell.py +++ b/tests/components/input_encoders/test_latencyCell.py @@ -1,17 +1,11 @@ +# %% + from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import LatencyCell -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal - -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch - +from ngclearn import MethodProcess, Context def test_latencyCell1(): name = "latency_ctx" @@ -29,23 +23,19 @@ def test_latencyCell1(): ) ## create and compile core simulation commands - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="advance") - calc_spike_times_process = (Process("calc_sptimes_proc") + calc_spike_times_process = (MethodProcess("calc_sptimes_proc") >> a.calc_spike_times) - ctx.wrap_and_add_command(jit(calc_spike_times_process.pure), name="calc_spike_times") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) ## input spike train - inputs = jnp.asarray([[0.02, 0.5, 1., 0.0]]) + x_t = jnp.asarray([[0.02, 0.5, 1., 0.0]]) targets = np.zeros((T, 4)) targets[0, 2] = 1. @@ -55,14 +45,14 @@ def clamp(x): targets = jnp.array(targets) ## gold-standard solution to check against outs = [] - ctx.reset() - ctx.clamp(inputs) - ctx.calc_spike_times() + reset_process.run() + clamp(x_t) + calc_spike_times_process.run() for ts in range(T): - ctx.clamp(inputs) - ctx.advance(t=ts * dt, dt=dt) + clamp(x_t) + advance_process.run(t=ts * dt, dt=dt) ## naively extract simple statistics at time ts and print them to I/O - s = a.outputs.value + s = a.outputs.get() outs.append(s) #print(" {}: s {} ".format(ts, jnp.squeeze(s))) outs = jnp.concatenate(outs, axis=0) @@ -70,4 +60,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, targets) -#test_latencyCell1() +test_latencyCell1() diff --git a/tests/components/input_encoders/test_phasorCell.py b/tests/components/input_encoders/test_phasorCell.py index 2f4735ac..d9091888 100644 --- a/tests/components/input_encoders/test_phasorCell.py +++ b/tests/components/input_encoders/test_phasorCell.py @@ -1,16 +1,9 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import PhasorCell -#from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal - -from ngcsimlib.compilers.process import Process, transition -#from ngcsimlib.component import Component -#from ngcsimlib.compartment import Compartment -#from ngcsimlib.context import Context -#from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_phasorCell1(): @@ -24,16 +17,13 @@ def test_phasorCell1(): with Context(name) as ctx: a = PhasorCell(name="a", n_units=1, target_freq=1000., disable_phasor=True, key=subkeys[0]) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) @@ -41,12 +31,12 @@ def clamp(x): x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.outputs.value) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.outputs.get()) #print(a.outputs.value) outs = jnp.concatenate(outs, axis=1) #print(outs) diff --git a/tests/components/input_encoders/test_poissonCell.py b/tests/components/input_encoders/test_poissonCell.py index 10c05867..f21f062a 100644 --- a/tests/components/input_encoders/test_poissonCell.py +++ b/tests/components/input_encoders/test_poissonCell.py @@ -1,16 +1,10 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import PoissonCell -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_poissonCell1(): @@ -24,16 +18,13 @@ def test_poissonCell1(): with Context(name) as ctx: a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0]) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) @@ -41,12 +32,12 @@ def clamp(x): x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.outputs.value) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.outputs.get()) outs = jnp.concatenate(outs, axis=1) ## output should equal input diff --git a/tests/components/neurons/graded/test_bernoulliErrorCell.py b/tests/components/neurons/graded/test_bernoulliErrorCell.py index 897c6ef3..22e70e3d 100644 --- a/tests/components/neurons/graded/test_bernoulliErrorCell.py +++ b/tests/components/neurons/graded/test_bernoulliErrorCell.py @@ -1,19 +1,10 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import BernoulliErrorCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal - -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch - +from ngclearn import MethodProcess, Context def test_bernoulliErrorCell(): np.random.seed(42) @@ -25,21 +16,12 @@ def test_bernoulliErrorCell(): a = BernoulliErrorCell( name="a", n_units=1, batch_size=1, input_logits=False, shape=None ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) - @Context.dynamicCommand def clamp(x): a.p.set(x) - @Context.dynamicCommand def clamp_target(x): a.target.set(x) @@ -50,13 +32,13 @@ def clamp_target(x): y_seq = jnp.asarray([[-2.8193381, -4976.9263, -2.1224928, -2939.0425, -1233.3916, -0.24662945, -708.30042, 0.28213939, 3550.8477, 1.3651246]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) + clamp(x_t) target_xt = jnp.array([[target_seq[0, ts]]]) - ctx.clamp_target(target_xt) - ctx.run(t=ts * 1., dt=dt) + clamp_target(target_xt) + advance_process.run(t=ts * 1., dt=dt) outs.append(a.dp.value) outs = jnp.concatenate(outs, axis=1) # print(outs) diff --git a/tests/components/neurons/graded/test_gaussianErrorCell.py b/tests/components/neurons/graded/test_gaussianErrorCell.py index 1dd2a2e1..b816a31d 100644 --- a/tests/components/neurons/graded/test_gaussianErrorCell.py +++ b/tests/components/neurons/graded/test_gaussianErrorCell.py @@ -1,18 +1,11 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import GaussianErrorCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_gaussianErrorCell(): @@ -25,21 +18,12 @@ def test_gaussianErrorCell(): a = GaussianErrorCell( name="a", n_units=1, batch_size=1, sigma=1.0, shape=None ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - - @Context.dynamicCommand def clamp_mu(x): a.mu.set(x) - @Context.dynamicCommand def clamp_target(x): a.target.set(x) @@ -53,13 +37,13 @@ def clamp_target(x): dmu_outs = [] L_outs = [] - ctx.reset() + reset_process.run() for ts in range(mu_seq.shape[1]): mu_t = jnp.array([[mu_seq[0, ts]]]) ## get data at time t - ctx.clamp_mu(mu_t) + clamp_mu(mu_t) target_t = jnp.array([[target_seq[0, ts]]]) - ctx.clamp_target(target_t) - ctx.run(t=ts * 1., dt=dt) + clamp_target(target_t) + advance_process.run(t=ts * 1., dt=dt) dmu_outs.append(a.dmu.value) L_outs.append(a.L.value) diff --git a/tests/components/neurons/graded/test_laplacianErrorCell.py b/tests/components/neurons/graded/test_laplacianErrorCell.py index 4167bad9..3b19624e 100644 --- a/tests/components/neurons/graded/test_laplacianErrorCell.py +++ b/tests/components/neurons/graded/test_laplacianErrorCell.py @@ -1,19 +1,10 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import LaplacianErrorCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal - -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch - +from ngclearn import MethodProcess, Context def test_laplacianErrorCell(): np.random.seed(42) @@ -25,25 +16,15 @@ def test_laplacianErrorCell(): a = LaplacianErrorCell( name="a", n_units=1, batch_size=1, scale=1.0, shape=None ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) - @Context.dynamicCommand def clamp_modulator(x): a.modulator.set(x) - @Context.dynamicCommand def clamp_shift(x): a.shift.set(x) - @Context.dynamicCommand def clamp_target(x): a.target.set(x) @@ -59,15 +40,15 @@ def clamp_target(x): dshift_outs = [] L_outs = [] - ctx.reset() + reset_process.run() for ts in range(shift_seq.shape[1]): shift_t = jnp.array([[shift_seq[0, ts]]]) ## get data at time t - ctx.clamp_shift(shift_t) + clamp_shift(shift_t) modulator_t = jnp.array([[modulator_seq[0, ts]]]) - ctx.clamp_modulator(modulator_t) + clamp_modulator(modulator_t) target_t = jnp.array([[target_seq[0, ts]]]) - ctx.clamp_target(target_t) - ctx.run(t=ts * 1., dt=dt) + clamp_target(target_t) + advance_process.run(t=ts * 1., dt=dt) dshift_outs.append(a.dshift.value) # print(f"a.L.value: {a.L.value}") # print(f"a.shift.value: {a.shift.value}") diff --git a/tests/components/neurons/graded/test_rewardErrorCell.py b/tests/components/neurons/graded/test_rewardErrorCell.py index 6ecb7710..6fa328ab 100644 --- a/tests/components/neurons/graded/test_rewardErrorCell.py +++ b/tests/components/neurons/graded/test_rewardErrorCell.py @@ -1,18 +1,11 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import RewardErrorCell -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_rewardErrorCell(): @@ -27,21 +20,10 @@ def test_rewardErrorCell(): name="a", n_units=1, alpha=alpha, ema_window_len=10, use_online_predictor=True, batch_size=1 ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - evolve_process = (Process("evolve_proc") >> a.evolve) - ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) + evolve_process = (MethodProcess("evolve_proc") >> a.evolve) - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - # ctx.add_command(wrap_command(jit(ctx.evolve)), name="evolve") - - @Context.dynamicCommand def clamp_reward(x): a.reward.set(x) @@ -71,17 +53,17 @@ def clamp_reward(x): mu_outs = [] rpe_outs = [] accum_reward_outs = [] - ctx.reset() + reset_process.run() for ts in range(reward_seq.shape[1]): reward_t = jnp.array([[reward_seq[0, ts]]]) ## get reward at time t - ctx.clamp_reward(reward_t) - ctx.run(t=ts * 1., dt=dt) + clamp_reward(reward_t) + advance_process.run(t=ts * 1., dt=dt) mu_outs.append(a.mu.value) rpe_outs.append(a.rpe.value) accum_reward_outs.append(a.accum_reward.value) # Test evolve function - ctx.evolve(t=10 * 1., dt=dt) + evolve_process.run(t=10 * 1., dt=dt) final_mu = a.mu.value # print(f"final_mu: {final_mu}") From edc1803d779445d21f970205a6fb628de53d7213 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 14 Nov 2025 10:30:36 -0500 Subject: [PATCH 048/121] update phasor cell --- .../components/input_encoders/phasorCell.py | 106 +++++++++++++----- .../input_encoders/test_phasorCell.py | 2 +- 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 4db8e894..77c7d9c1 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -3,10 +3,10 @@ import jax from typing import Union +from ngcsimlib.logger import info, warn from ngcsimlib.compartment import Compartment from ngcsimlib.parser import compilable - class PhasorCell(JaxComponent): """ A phasor cell that emits a pulse at a regular interval. @@ -33,25 +33,19 @@ class PhasorCell(JaxComponent): # Define Functions def __init__( - self, name: str, n_units: int, target_freq: float = 63.75, - batch_size: int = 1, key: Union[jax.Array, None] = None): - super().__init__(name=name, key=key) - - _key, subkey = random.split(self.key.get(), 2) - self.key.set(_key) + self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs): + super().__init__(name, **kwargs) ## Phasor meta-parameters - self.target_freq = Compartment(target_freq, fixed=True) ## maximum frequency (in Hertz/Hz) - self.base_scale = Compartment(random.poisson(subkey[0], lam=target_freq, shape=(batch_size, n_units)) / target_freq, fixed=True) + self.target_freq = target_freq ## maximum frequency (in Hertz/Hz) ## Layer Size Setup - self.batch_size = Compartment(batch_size, fixed=True) - self.n_units = Compartment(n_units, fixed=True) - - - + self.batch_size = batch_size + self.n_units = n_units + _key, *subkey = random.split(self.key.get(), 3) + self.key.set(_key) ## Compartment setup - restVals = jnp.zeros((batch_size, n_units)) + restVals = jnp.zeros((self.batch_size, self.n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input # compartment @@ -60,44 +54,100 @@ def __init__( self.tols = Compartment(initial_value=restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike self.angles = Compartment(restVals, display_name="Angles", units="deg") - + # self.base_scale = random.uniform(subkey, self.angles.value.shape, + # minval=0.75, maxval=1.25) + # self.base_scale = ((random.normal(subkey, self.angles.value.shape) * 0.15) + 1) + # alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1) + # beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq + + self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.get().shape) / target_freq + self.disable_phasor = disable_phasor + + def validate(self, dt=None, **validation_kwargs): + valid = super().validate(**validation_kwargs) + if dt is None: + warn(f"{self.name} requires a validation kwarg of `dt`") + return False + ## check for unstable combinations of dt and target-frequency + # meta-params + events_per_timestep = (dt / 1000.) * self.target_freq ## + # compute scaled probability + if events_per_timestep > 1.: + valid = False + warn( + f"{self.name} will be unable to make as many temporal events " + f"as " + f"requested! ({events_per_timestep} events/timestep) Unstable " + f"combination of dt = {dt} and target_freq = " + f"{self.target_freq} " + f"being used!" + ) + return valid + + # @transition(output_compartments=["outputs", "tols", "key", "angles"]) + # @staticmethod @compilable - def advance_state(self, t, dt): + def advance_state(self, t, dt, ): + + inputs = self.inputs.get() + angles = self.angles.get() + tols = self.tols.get() + ms_per_second = 1000 # ms/s - events_per_ms = self.target_freq.get() / ms_per_second # e/s s/ms -> e/ms + events_per_ms = self.target_freq / ms_per_second # e/s s/ms -> e/ms ms_per_event = 1 / events_per_ms # ms/e time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e angle_per_event = 2 * jnp.pi # rad / e angle_per_timestep = angle_per_event / time_step_per_event # rad / e # * e/ts -> rad / ts key, *subkey = random.split(self.key.get(), 3) + # scatter = random.uniform(subkey, angles.shape, minval=0.5, + # maxval=1.5) * base_scale - scatter = ((random.normal(subkey[0], self.angles.get().shape) * 0.2) + 1) * self.base_scale.get() + scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * self.base_scale scattered_update = angle_per_timestep * scatter - scaled_scattered_update = scattered_update * self.inputs.get() - - updated_angles = self.angles.get() + scaled_scattered_update - self.outputs.set(jnp.where(updated_angles > angle_per_event, 1., 0.)) + scaled_scattered_update = scattered_update * inputs - self.angles.set(jnp.where(updated_angles > angle_per_event, + updated_angles = angles + scaled_scattered_update + outputs = jnp.where(updated_angles > angle_per_event, 1., 0.) + updated_angles = jnp.where(updated_angles > angle_per_event, updated_angles - angle_per_event, - updated_angles)) + updated_angles) + if self.disable_phasor: + outputs = inputs + 0 + tols = tols * (1. - outputs) + t * outputs + + self.outputs.set(outputs) + self.tols.set(tols) + self.key.set(key) + self.angles.set(updated_angles) - self.tols.set(self.tols.get() * (1. - self.outputs.get()) + t * self.outputs.get()) + # @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"]) + # @staticmethod @compilable def reset(self): - restVals = jnp.zeros((self.batch_size.get(), self.n_units.get())) + restVals = jnp.zeros((self.batch_size, self.n_units)) + key, *subkey = random.split(self.key.get(), 3) + # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.tols.set(restVals) self.angles.set(restVals) - key, _ = random.split(self.key.get(), 2) self.key.set(key) + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + jnp.savez(file_name, key=self.key.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.key.set(data['key']) + @classmethod def help(cls): ## component help function properties = { diff --git a/tests/components/input_encoders/test_phasorCell.py b/tests/components/input_encoders/test_phasorCell.py index d9091888..d170970b 100644 --- a/tests/components/input_encoders/test_phasorCell.py +++ b/tests/components/input_encoders/test_phasorCell.py @@ -44,4 +44,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, x_seq) -#test_phasorCell1() +test_phasorCell1() From b96139fb04ec72c9af0954a5db12f20921bc9758 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 14 Nov 2025 10:33:10 -0500 Subject: [PATCH 049/121] update test bernoulli cell and poisson cell --- tests/components/input_encoders/test_bernoulliCell.py | 4 ++-- tests/components/input_encoders/test_poissonCell.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/components/input_encoders/test_bernoulliCell.py b/tests/components/input_encoders/test_bernoulliCell.py index a3ba5a9f..04349964 100644 --- a/tests/components/input_encoders/test_bernoulliCell.py +++ b/tests/components/input_encoders/test_bernoulliCell.py @@ -40,11 +40,11 @@ def clamp(x): x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t clamp(x_t) advance_process.run(t=ts*1., dt=dt) - outs.append(a.outputs.value) + outs.append(a.outputs.get()) outs = jnp.concatenate(outs, axis=1) ## output should equal input assert_array_equal(outs, x_seq) #print(outs) -#test_bernoulliCell1() +test_bernoulliCell1() diff --git a/tests/components/input_encoders/test_poissonCell.py b/tests/components/input_encoders/test_poissonCell.py index f21f062a..93c10d35 100644 --- a/tests/components/input_encoders/test_poissonCell.py +++ b/tests/components/input_encoders/test_poissonCell.py @@ -43,4 +43,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, x_seq) -#test_poissonCell1() +test_poissonCell1() From c80f2b5a93288a10f9af15518552aa70b4006247 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 14 Nov 2025 10:49:46 -0500 Subject: [PATCH 050/121] update components and their related test cases --- ngclearn/components/other/expKernel.py | 4 +- ngclearn/components/other/varTrace.py | 4 +- .../synapses/patched/hebbianPatchedSynapse.py | 15 +++--- .../synapses/patched/patchedSynapse.py | 16 +++--- tests/components/other/test_expKernel.py | 28 +++------- tests/components/other/test_varTrace.py | 28 ++++------ .../patched/test_hebbianPatchedSynapse.py | 54 ++++++------------- .../synapses/patched/test_patchedSynapse.py | 38 +++++-------- 8 files changed, 72 insertions(+), 115 deletions(-) diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index 1295fd62..d434ced5 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -86,7 +86,9 @@ def advance_state(self, t): def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf - not self.inputs.targeted and self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.epsp.set(restVals) self.tf.set(restTensor) diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index 1e624ba9..e4e051f1 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -126,7 +126,9 @@ def advance_state(self, dt): @compilable def reset(self): restVals = jnp.zeros((self.batch_size, self.n_units)) - not self.inputs.targeted and self.inputs.set(restVals) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(restVals) self.outputs.set(restVals) self.trace.set(restVals) diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index 86f55f87..64aabd92 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -277,15 +277,18 @@ def evolve(self): self.dBiases.set(dBiases) @compilable - def reset(self, batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) - not self.inputs.targeted and self.inputs.set(preVals) # inputs + def reset(self): + preVals = jnp.zeros((self.batch_size, self.shape[0])) + postVals = jnp.zeros((self.batch_size, self.shape[1])) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs self.outputs.set(postVals) # outputs self.pre.set(preVals) # pre self.post.set(postVals) # post - self.dWeights.set(jnp.zeros(shape)) # dW - self.dBiases.set(jnp.zeros(shape[1])) # db + self.dWeights.set(jnp.zeros(self.shape)) # dW + self.dBiases.set(jnp.zeros(self.shape[1])) # db + @classmethod def help(cls): ## component help function diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index 3960aee2..3ea00475 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -154,21 +154,23 @@ def advance_state(self): self.outputs.set(outputs) @compilable - def reset(self, batch_size, shape): - preVals = jnp.zeros((batch_size, shape[0])) - postVals = jnp.zeros((batch_size, shape[1])) + def reset(self): + preVals = jnp.zeros((self.batch_size, self.shape[0])) + postVals = jnp.zeros((self.batch_size, self.shape[1])) inputs = preVals outputs = postVals - not self.inputs.targeted and self.inputs.set(inputs) + # BUG: the self.inputs here does not have the targeted field + # NOTE: Quick workaround is to check if targeted is in the input or not + hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(inputs) self.outputs.set(outputs) def save(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" if self.bias_init != None: - jnp.savez(file_name, weights=self.weights.value, - biases=self.biases.value) + jnp.savez(file_name, weights=self.weights.get(), + biases=self.biases.get()) else: - jnp.savez(file_name, weights=self.weights.value) + jnp.savez(file_name, weights=self.weights.get()) def load(self, directory, **kwargs): file_name = directory + "/" + self.name + ".npz" diff --git a/tests/components/other/test_expKernel.py b/tests/components/other/test_expKernel.py index 0ece0bad..9375da66 100644 --- a/tests/components/other/test_expKernel.py +++ b/tests/components/other/test_expKernel.py @@ -1,16 +1,8 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import ExpKernel -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal - -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context def test_expKernel1(): name = "expKernel_ctx" @@ -25,16 +17,12 @@ def test_expKernel1(): name="a", n_units=1, dt=1., tau_w=500., nu=4., key=subkeys[0] ) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) @@ -44,16 +32,16 @@ def clamp(x): y_seq = jnp.asarray([[0., 1., 0.998002, 0.996008, 1.9940181]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.epsp.value) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.epsp.get()) outs = jnp.concatenate(outs, axis=1) #print(outs) ## output should equal input np.testing.assert_allclose(outs, y_seq, atol=1e-8) -#test_expKernel1() +test_expKernel1() diff --git a/tests/components/other/test_varTrace.py b/tests/components/other/test_varTrace.py index 88444588..8b8ba84d 100644 --- a/tests/components/other/test_varTrace.py +++ b/tests/components/other/test_varTrace.py @@ -1,16 +1,11 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import VarTrace -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context + def test_varTrace1(): name = "trace_ctx" @@ -26,35 +21,32 @@ def test_varTrace1(): key=subkeys[0] ) - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.inputs.set(x) ## input spike train x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32) ## desired output pulses - y_seq = x_seq * trace_increment + y_seq = x_seq * trace_increment outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.outputs.value) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.outputs.get()) outs = jnp.concatenate(outs, axis=1) #print(outs) ## output should equal input assert_array_equal(outs, y_seq) -#test_varTrace1() +test_varTrace1() diff --git a/tests/components/synapses/patched/test_hebbianPatchedSynapse.py b/tests/components/synapses/patched/test_hebbianPatchedSynapse.py index d0997c82..8c2e6396 100644 --- a/tests/components/synapses/patched/test_hebbianPatchedSynapse.py +++ b/tests/components/synapses/patched/test_hebbianPatchedSynapse.py @@ -1,19 +1,12 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import HebbianPatchedSynapse -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch - +from ngclearn import MethodProcess, Context def test_hebbianPatchedSynapse(): np.random.seed(42) @@ -31,58 +24,45 @@ def test_hebbianPatchedSynapse(): with Context(name) as ctx: a = HebbianPatchedSynapse( - name="a", - shape=shape, - n_sub_models=n_sub_models, + name="a", + shape=shape, + n_sub_models=n_sub_models, stride_shape=stride_shape, resist_scale=resist_scale, batch_size=batch_size ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - evolve_process = (Process("evolve_proc") >> a.evolve) - ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve") - - # Compile and add commands - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(advance_cmd)), name="run") - # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve") - # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve") - - @Context.dynamicCommand + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) + evolve_process = (MethodProcess("evolve_proc") >> a.evolve) + def clamp_inputs(x): a.inputs.set(x) - @Context.dynamicCommand def clamp_pre(x): a.pre.set(x) - @Context.dynamicCommand def clamp_post(x): a.post.set(x) - a.weights.set(jnp.ones((12, 12)) * 0.5) + a.weights.set(jnp.ones((12, 12)) * 0.5) in_pre = jnp.ones((10, 12)) * 1.0 in_post = jnp.ones((10, 12)) * 0.75 - ctx.reset() + reset_process.run() clamp_pre(in_pre) clamp_post(in_post) - ctx.run(t=1. * dt, dt=dt) - ctx.evolve(t=1. * dt, dt=dt) + advance_process.run(t=1. * dt, dt=dt) + evolve_process.run(t=1. * dt, dt=dt) - print(a.weights.value) + print(a.weights.get()) # Basic assertions to check learning dynamics - assert a.weights.value.shape == (12, 12), "" - assert a.weights.value[0, 0] == 0.5, "" + assert a.weights.get().shape == (12, 12), "" + assert a.weights.get()[0, 0] == 0.5, "" + +test_hebbianPatchedSynapse() -# test_hebbianPatchedSynapse() \ No newline at end of file diff --git a/tests/components/synapses/patched/test_patchedSynapse.py b/tests/components/synapses/patched/test_patchedSynapse.py index 8dd99d06..8574819f 100644 --- a/tests/components/synapses/patched/test_patchedSynapse.py +++ b/tests/components/synapses/patched/test_patchedSynapse.py @@ -1,18 +1,12 @@ # %% from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import PatchedSynapse -from ngcsimlib.compilers import compile_command, wrap_command -from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context + def test_patchedSynapse(): @@ -39,31 +33,25 @@ def test_patchedSynapse(): bias_init={"dist": "constant", "value": 0.0} ) - advance_process = (Process("advance_proc") >> a.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - - # Compile and add commands - # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset") - # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - # ctx.add_command(wrap_command(jit(advance_cmd)), name="run") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) - @Context.dynamicCommand def clamp_inputs(x): a.inputs.set(x) inputs_seq = jnp.asarray(np.random.randn(1, 12)) - weights = a.weights.value - biases = a.biases.value + weights = a.weights.get() + biases = a.biases.get() expected_outputs = (jnp.matmul(inputs_seq, weights) * resist_scale) + biases outputs_outs = [] - ctx.reset() - ctx.clamp_inputs(inputs_seq) - ctx.run(t=0., dt=dt) - outputs_outs.append(a.outputs.value) + reset_process.run() + clamp_inputs(inputs_seq) + advance_process.run(t=0., dt=dt) + outputs_outs.append(a.outputs.get()) outputs_outs = jnp.concatenate(outputs_outs, axis=1) # Verify outputs match expected values np.testing.assert_allclose(outputs_outs, expected_outputs, atol=1e-5) + +test_patchedSynapse() + From ac2ec138c7dab78f66c9c5e15954bbcb8690793f Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 14 Nov 2025 12:45:53 -0500 Subject: [PATCH 051/121] fixed monitor bugs from v2, tweaked unit-tests for input-encoders/latency-cell --- ngclearn/components/base_monitor.py | 13 +++++++------ ngclearn/components/monitor.py | 2 +- .../components/input_encoders/test_bernoulliCell.py | 2 +- tests/components/input_encoders/test_latencyCell.py | 2 +- tests/components/input_encoders/test_phasorCell.py | 2 +- tests/components/input_encoders/test_poissonCell.py | 2 +- .../neurons/graded/test_rewardErrorCell.py | 10 +++++----- 7 files changed, 17 insertions(+), 16 deletions(-) diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py index 8d7c71d0..b50435e3 100644 --- a/ngclearn/components/base_monitor.py +++ b/ngclearn/components/base_monitor.py @@ -1,8 +1,8 @@ import json -from ngclearn import Component, Compartment, transition +from ngclearn import Component, Compartment #, transition from ngclearn import numpy as np -from ngcsimlib.utils import get_current_path +#from ngcsimlib.utils import get_current_path from ngcsimlib.logger import warn, critical import matplotlib.pyplot as plt @@ -68,7 +68,7 @@ def _record_internal(compartments): "monitor found in ngclearn.components or " "ngclearn.components.lava (If using lava)") - @transition(None, True) + #@transition(None, True) @staticmethod def reset(component): """ @@ -95,7 +95,7 @@ def _reset(**kwargs): # pure func, output compartments, args, params, input compartments return _reset, output_compartments, [], [], output_compartments - @transition(None, True) + #@transition(None, True) @staticmethod def record(component): output_compartments = [] @@ -265,8 +265,9 @@ def load(self, directory, **kwargs): for comp_path, shape in vals["stores"].items(): compartment_path = comp_path.split("/")[-1] - new_path = get_current_path() + "/" + "/".join( - compartment_path.split("*")[-3:-1]) + new_path = "" + # new_path = get_current_path() + "/" + "/".join( + # compartment_path.split("*")[-3:-1]) cs, end = self._add_path(new_path) diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py index 3b373cf3..d3916f7a 100644 --- a/ngclearn/components/monitor.py +++ b/ngclearn/components/monitor.py @@ -1,5 +1,5 @@ from ngclearn.components.base_monitor import Base_Monitor -from ngclearn import transition +#from ngclearn import transition class Monitor(Base_Monitor): """ diff --git a/tests/components/input_encoders/test_bernoulliCell.py b/tests/components/input_encoders/test_bernoulliCell.py index 04349964..43c616e7 100644 --- a/tests/components/input_encoders/test_bernoulliCell.py +++ b/tests/components/input_encoders/test_bernoulliCell.py @@ -47,4 +47,4 @@ def clamp(x): assert_array_equal(outs, x_seq) #print(outs) -test_bernoulliCell1() +#test_bernoulliCell1() diff --git a/tests/components/input_encoders/test_latencyCell.py b/tests/components/input_encoders/test_latencyCell.py index ad45145e..4abf0552 100644 --- a/tests/components/input_encoders/test_latencyCell.py +++ b/tests/components/input_encoders/test_latencyCell.py @@ -60,4 +60,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, targets) -test_latencyCell1() +#test_latencyCell1() diff --git a/tests/components/input_encoders/test_phasorCell.py b/tests/components/input_encoders/test_phasorCell.py index d170970b..d9091888 100644 --- a/tests/components/input_encoders/test_phasorCell.py +++ b/tests/components/input_encoders/test_phasorCell.py @@ -44,4 +44,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, x_seq) -test_phasorCell1() +#test_phasorCell1() diff --git a/tests/components/input_encoders/test_poissonCell.py b/tests/components/input_encoders/test_poissonCell.py index 93c10d35..f21f062a 100644 --- a/tests/components/input_encoders/test_poissonCell.py +++ b/tests/components/input_encoders/test_poissonCell.py @@ -43,4 +43,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, x_seq) -test_poissonCell1() +#test_poissonCell1() diff --git a/tests/components/neurons/graded/test_rewardErrorCell.py b/tests/components/neurons/graded/test_rewardErrorCell.py index 6fa328ab..e465d07c 100644 --- a/tests/components/neurons/graded/test_rewardErrorCell.py +++ b/tests/components/neurons/graded/test_rewardErrorCell.py @@ -58,13 +58,13 @@ def clamp_reward(x): reward_t = jnp.array([[reward_seq[0, ts]]]) ## get reward at time t clamp_reward(reward_t) advance_process.run(t=ts * 1., dt=dt) - mu_outs.append(a.mu.value) - rpe_outs.append(a.rpe.value) - accum_reward_outs.append(a.accum_reward.value) + mu_outs.append(a.mu.get()) + rpe_outs.append(a.rpe.get()) + accum_reward_outs.append(a.accum_reward.get()) # Test evolve function evolve_process.run(t=10 * 1., dt=dt) - final_mu = a.mu.value + final_mu = a.mu.get() # print(f"final_mu: {final_mu}") mu_outs = jnp.concatenate(mu_outs, axis=1) @@ -85,4 +85,4 @@ def clamp_reward(x): expected_final_mu = (1 - 1/10) * mu_outs[0, -1] + (1/10) * (accum_reward_outs[0, -1] / 10) np.testing.assert_allclose(final_mu, expected_final_mu, atol=1e-5) -# test_rewardErrorCell() \ No newline at end of file +#test_rewardErrorCell() From 0e3f674a959779523efdaf191a19e8bb241a058d Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Fri, 14 Nov 2025 12:58:20 -0500 Subject: [PATCH 052/121] update test case for test_sLIFCell.py --- .../neurons/spiking/test_sLIFCell.py | 37 +++++-------------- 1 file changed, 9 insertions(+), 28 deletions(-) diff --git a/tests/components/neurons/spiking/test_sLIFCell.py b/tests/components/neurons/spiking/test_sLIFCell.py index b1b5f517..6c63c028 100644 --- a/tests/components/neurons/spiking/test_sLIFCell.py +++ b/tests/components/neurons/spiking/test_sLIFCell.py @@ -1,16 +1,11 @@ from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context import numpy as np np.random.seed(42) from ngclearn.components import SLIFCell -from ngcsimlib.compilers import compile_command, wrap_command from numpy.testing import assert_array_equal -from ngcsimlib.compilers.process import Process, transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngcsimlib.context import Context -from ngcsimlib.utils.compartment import Get_Compartment_Batch +from ngclearn import MethodProcess, Context + def test_sLIFCell1(): name = "slif_ctx" @@ -25,26 +20,12 @@ def test_sLIFCell1(): name="a", n_units=1, tau_m=50., resist_m=10., thr=0.3, key=subkeys[0] ) - #""" - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> a.advance_state) - #ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> a.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") - #""" - - """ - reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset") - ctx.add_command(wrap_command(jit(ctx.reset)), name="reset") - advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state") - ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run") - """ ## set up non-compiled utility commands - @Context.dynamicCommand def clamp(x): a.j.set(x) @@ -54,15 +35,15 @@ def clamp(x): y_seq = jnp.asarray([[0., 1., 0., 0., 0., 1., 0.]], dtype=jnp.float32) outs = [] - ctx.reset() + reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) - ctx.run(t=ts * 1., dt=dt) - outs.append(a.s.value) + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.s.get()) outs = jnp.concatenate(outs, axis=1) ## output should equal input assert_array_equal(outs, y_seq) -#test_sLIFCell1() +test_sLIFCell1() From e0c75fab43fe96499e2eb7b0abca88e7fd08c030 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 14 Nov 2025 12:59:19 -0500 Subject: [PATCH 053/121] some cleanup --- README.md | 8 +- docs/source/ngclearn.commands.rst | 10 -- .../ngclearn.components.lava.neurons.rst | 21 --- docs/source/ngclearn.components.lava.rst | 31 ---- .../ngclearn.components.lava.synapses.rst | 37 ----- .../ngclearn.components.lava.traces.rst | 21 --- .../ngclearn.components.neurons.graded.rst | 8 ++ docs/source/ngclearn.components.rst | 1 - docs/source/ngclearn.rst | 1 - docs/source/ngclearn.utils.rst | 28 +++- docs/source/ngclearn.utils.viz.rst | 16 +++ ngclearn/__init__.py | 39 ++--- ngclearn/components/__init__.py | 10 +- ngclearn/components/neurons/__init__.py | 2 +- .../components/neurons/spiking/__init__.py | 2 +- .../components/neurons/spiking/sLIFCell.py | 18 +-- ngclearn/modules/regression/elastic_net.py | 134 +++++++++++------- .../neurons/spiking/test_sLIFCell.py | 2 +- 18 files changed, 164 insertions(+), 225 deletions(-) delete mode 100644 docs/source/ngclearn.commands.rst delete mode 100644 docs/source/ngclearn.components.lava.neurons.rst delete mode 100644 docs/source/ngclearn.components.lava.rst delete mode 100644 docs/source/ngclearn.components.lava.synapses.rst delete mode 100644 docs/source/ngclearn.components.lava.traces.rst diff --git a/README.md b/README.md index 7355b0bb..357b00a8 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ ngc-learn requires: 1) Python (>=3.10) 2) NumPy (>=1.22.0) 3) SciPy (>=1.7.0) -4) ngcsimlib (>=1.0.0), (visit official page here) +4) ngcsimlib (>=2.0.0), (visit official page here) 5) JAX (>=0.4.28) (to enable GPU use, make sure to install one of the CUDA variants) --- -ngc-learn 2.0.0 and later require Python 3.10 or newer as well as ngcsimlib >=1.0.0. +ngc-learn 3.0.0 and later require Python 3.10 or newer as well as ngcsimlib >=2.0.0. ngc-learn's plotting capabilities (routines within `ngclearn.utils.viz`) require Matplotlib (>=3.8.0) and imageio (>=2.31.5) and both plotting and density estimation tools (routines within ``ngclearn.utils.density``) will require Scikit-learn (>=0.24.2). @@ -75,7 +75,7 @@ Python 3.11.4 (main, MONTH DAY YEAR, TIME) [GCC XX.X.X] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import ngclearn >>> ngclearn.__version__ -'2.0.0' +'3.0.0' ``` Note: For access to the previous Tensorflow-2 version of ngc-learn (of @@ -122,7 +122,7 @@ $ python install -e . **Version:**
-2.0.2 +3.0.0 Author: Alexander G. Ororbia II
diff --git a/docs/source/ngclearn.commands.rst b/docs/source/ngclearn.commands.rst deleted file mode 100644 index 7b0c40c1..00000000 --- a/docs/source/ngclearn.commands.rst +++ /dev/null @@ -1,10 +0,0 @@ -ngclearn.commands package -========================= - -Module contents ---------------- - -.. automodule:: ngclearn.commands - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/ngclearn.components.lava.neurons.rst b/docs/source/ngclearn.components.lava.neurons.rst deleted file mode 100644 index 9126f5e4..00000000 --- a/docs/source/ngclearn.components.lava.neurons.rst +++ /dev/null @@ -1,21 +0,0 @@ -ngclearn.components.lava.neurons package -======================================== - -Submodules ----------- - -ngclearn.components.lava.neurons.LIFCell module ------------------------------------------------ - -.. automodule:: ngclearn.components.lava.neurons.LIFCell - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: ngclearn.components.lava.neurons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/ngclearn.components.lava.rst b/docs/source/ngclearn.components.lava.rst deleted file mode 100644 index 6b8be426..00000000 --- a/docs/source/ngclearn.components.lava.rst +++ /dev/null @@ -1,31 +0,0 @@ -ngclearn.components.lava package -================================ - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - ngclearn.components.lava.neurons - ngclearn.components.lava.synapses - ngclearn.components.lava.traces - -Submodules ----------- - -ngclearn.components.lava.monitor module ---------------------------------------- - -.. automodule:: ngclearn.components.lava.monitor - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: ngclearn.components.lava - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/ngclearn.components.lava.synapses.rst b/docs/source/ngclearn.components.lava.synapses.rst deleted file mode 100644 index 2babbb40..00000000 --- a/docs/source/ngclearn.components.lava.synapses.rst +++ /dev/null @@ -1,37 +0,0 @@ -ngclearn.components.lava.synapses package -========================================= - -Submodules ----------- - -ngclearn.components.lava.synapses.hebbianSynapse module -------------------------------------------------------- - -.. automodule:: ngclearn.components.lava.synapses.hebbianSynapse - :members: - :undoc-members: - :show-inheritance: - -ngclearn.components.lava.synapses.staticSynapse module ------------------------------------------------------- - -.. automodule:: ngclearn.components.lava.synapses.staticSynapse - :members: - :undoc-members: - :show-inheritance: - -ngclearn.components.lava.synapses.traceSTDPSynapse module ---------------------------------------------------------- - -.. automodule:: ngclearn.components.lava.synapses.traceSTDPSynapse - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: ngclearn.components.lava.synapses - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/ngclearn.components.lava.traces.rst b/docs/source/ngclearn.components.lava.traces.rst deleted file mode 100644 index e2dbe697..00000000 --- a/docs/source/ngclearn.components.lava.traces.rst +++ /dev/null @@ -1,21 +0,0 @@ -ngclearn.components.lava.traces package -======================================= - -Submodules ----------- - -ngclearn.components.lava.traces.gatedTrace module -------------------------------------------------- - -.. automodule:: ngclearn.components.lava.traces.gatedTrace - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: ngclearn.components.lava.traces - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/ngclearn.components.neurons.graded.rst b/docs/source/ngclearn.components.neurons.graded.rst index d62a5b7e..9eb532c8 100644 --- a/docs/source/ngclearn.components.neurons.graded.rst +++ b/docs/source/ngclearn.components.neurons.graded.rst @@ -28,6 +28,14 @@ ngclearn.components.neurons.graded.laplacianErrorCell module :undoc-members: :show-inheritance: +ngclearn.components.neurons.graded.leakyNoiseCell module +-------------------------------------------------------- + +.. automodule:: ngclearn.components.neurons.graded.leakyNoiseCell + :members: + :undoc-members: + :show-inheritance: + ngclearn.components.neurons.graded.rateCell module -------------------------------------------------- diff --git a/docs/source/ngclearn.components.rst b/docs/source/ngclearn.components.rst index e3209782..05c87f7d 100644 --- a/docs/source/ngclearn.components.rst +++ b/docs/source/ngclearn.components.rst @@ -8,7 +8,6 @@ Subpackages :maxdepth: 4 ngclearn.components.input_encoders - ngclearn.components.lava ngclearn.components.neurons ngclearn.components.other ngclearn.components.synapses diff --git a/docs/source/ngclearn.rst b/docs/source/ngclearn.rst index 814817bb..15e44840 100644 --- a/docs/source/ngclearn.rst +++ b/docs/source/ngclearn.rst @@ -7,7 +7,6 @@ Subpackages .. toctree:: :maxdepth: 4 - ngclearn.commands ngclearn.components ngclearn.modules ngclearn.operations diff --git a/docs/source/ngclearn.utils.rst b/docs/source/ngclearn.utils.rst index b442e626..82816960 100644 --- a/docs/source/ngclearn.utils.rst +++ b/docs/source/ngclearn.utils.rst @@ -16,6 +16,14 @@ Subpackages Submodules ---------- +ngclearn.utils.JaxProcessesMixin module +--------------------------------------- + +.. automodule:: ngclearn.utils.JaxProcessesMixin + :members: + :undoc-members: + :show-inheritance: + ngclearn.utils.data\_loader module ---------------------------------- @@ -24,18 +32,18 @@ ngclearn.utils.data\_loader module :undoc-members: :show-inheritance: -ngclearn.utils.io\_utils module -------------------------------- +ngclearn.utils.distribution\_generator module +--------------------------------------------- -.. automodule:: ngclearn.utils.io_utils +.. automodule:: ngclearn.utils.distribution_generator :members: :undoc-members: :show-inheritance: -ngclearn.utils.jaxProcess module --------------------------------- +ngclearn.utils.io\_utils module +------------------------------- -.. automodule:: ngclearn.utils.jaxProcess +.. automodule:: ngclearn.utils.io_utils :members: :undoc-members: :show-inheritance: @@ -56,6 +64,14 @@ ngclearn.utils.model\_utils module :undoc-members: :show-inheritance: +ngclearn.utils.patch module +--------------------------- + +.. automodule:: ngclearn.utils.patch + :members: + :undoc-members: + :show-inheritance: + ngclearn.utils.patch\_utils module ---------------------------------- diff --git a/docs/source/ngclearn.utils.viz.rst b/docs/source/ngclearn.utils.viz.rst index 0a48f7a8..4c926118 100644 --- a/docs/source/ngclearn.utils.viz.rst +++ b/docs/source/ngclearn.utils.viz.rst @@ -4,6 +4,22 @@ ngclearn.utils.viz package Submodules ---------- +ngclearn.utils.viz.compartment\_plot module +------------------------------------------- + +.. automodule:: ngclearn.utils.viz.compartment_plot + :members: + :undoc-members: + :show-inheritance: + +ngclearn.utils.viz.compartment\_raster module +--------------------------------------------- + +.. automodule:: ngclearn.utils.viz.compartment_raster + :members: + :undoc-members: + :show-inheritance: + ngclearn.utils.viz.dim\_reduce module ------------------------------------- diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index 8d0c7e10..fa1d4030 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -7,7 +7,7 @@ if sys.version_info.minor < 10: import warnings warnings.warn( - "Running ngclearn and jax in a python version prior to 3.10 may have unintended consequences. Compatability " + "Running ngclearn and jax in a python version prior to 3.10 may have unintended consequences. Compatibility " "with python 3.8 is maintained to allow for lava-nc components and should only be used with those") #required = {'ngcsimlib', 'jax', 'jaxlib'} ## list of core ngclearn dependencies @@ -30,25 +30,26 @@ from ngclearn.utils import JointProcess, MethodProcess from ngcsimlib.context import Context, ContextObjectTypes +from ngcsimlib import Component from ngcsimlib.compartment import Compartment -# from ngclearn.utils.jaxProcess import JaxProcess from ngcsimlib import logger -# if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": -# if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload -# configure() -# logger.init_logging() -# from ngcsimlib.configManager import get_config -# pkg_config = get_config("packages") -# if pkg_config is not None: -# use_base_numpy = pkg_config.get("use_base_numpy", False) -# if use_base_numpy: -# import numpy as numpy -# else: -# from jax import numpy -# else: -# from jax import numpy -# -# -# preload_modules() +if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": + if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload + # configure() + # logger.init_logging() + # from ngcsimlib.configManager import get_config + # pkg_config = get_config("packages") + # if pkg_config is not None: + # use_base_numpy = pkg_config.get("use_base_numpy", False) + # if use_base_numpy: + # import numpy as numpy + # else: + # from jax import numpy + # else: + # from jax import numpy + # + # + # preload_modules() + a = 2 diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 3f8eda3f..96f8a2cf 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -8,7 +8,7 @@ from .neurons.graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types -#from .neurons.spiking.sLIFCell import SLIFCell +from .neurons.spiking.sLIFCell import SLIFCell from .neurons.spiking.IFCell import IFCell from .neurons.spiking.LIFCell import LIFCell from .neurons.spiking.WTASCell import WTASCell @@ -53,13 +53,13 @@ from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse ## point to modulated component types from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse -#from .synapses.modulated.REINFORCESynapse import REINFORCESynapse +from .synapses.modulated.REINFORCESynapse import REINFORCESynapse ## point to monitors from .monitor import Monitor ## point to patched component types -# from .synapses.patched.patchedSynapse import PatchedSynapse -# from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse -# from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse +from .synapses.patched.patchedSynapse import PatchedSynapse +from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse +from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index 1d8bb919..e7165d7e 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -5,7 +5,7 @@ from .graded.bernoulliErrorCell import BernoulliErrorCell from .graded.rewardErrorCell import RewardErrorCell ## point to standard spiking cell component types -#from .spiking.sLIFCell import SLIFCell +from .spiking.sLIFCell import SLIFCell from .spiking.IFCell import IFCell from .spiking.LIFCell import LIFCell from .spiking.WTASCell import WTASCell diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py index 1466af9a..b4c0b3db 100644 --- a/ngclearn/components/neurons/spiking/__init__.py +++ b/ngclearn/components/neurons/spiking/__init__.py @@ -1,5 +1,5 @@ ## point to standard spiking cell component types -# from .sLIFCell import SLIFCell +from .sLIFCell import SLIFCell from .LIFCell import LIFCell from .IFCell import IFCell from .WTASCell import WTASCell diff --git a/ngclearn/components/neurons/spiking/sLIFCell.py b/ngclearn/components/neurons/spiking/sLIFCell.py index b77a9c4f..9d8d1fae 100644 --- a/ngclearn/components/neurons/spiking/sLIFCell.py +++ b/ngclearn/components/neurons/spiking/sLIFCell.py @@ -190,7 +190,7 @@ def advance_state(self, t, dt): j = j * self.R_m if self.inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied - j = j - (jnp.matmul(spikes, self.inh_weights) * self.inh_R) + j = j - (jnp.matmul(self.s.get(), self.inh_weights) * self.inh_R) ##################################################################################### surrogate = self.d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes @@ -230,12 +230,12 @@ def reset(self): spikes = restVals if not self.thr_persist: ## if thresh non-persistent, reset to base value thr = self.threshold0 + 0 + self.thr.set(thr) # return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate self.j.set(current) self.s.set(spikes) self.tols.set(timeOfLastSpike) self.v.set(voltage) - self.thr.set(thr) self.rfr.set(refract) self.surrogate.set(surrogate) @@ -291,20 +291,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - def __repr__(self): - comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).get()) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py index 9cec8948..faf4fbea 100644 --- a/ngclearn/modules/regression/elastic_net.py +++ b/ngclearn/modules/regression/elastic_net.py @@ -1,13 +1,12 @@ -from jax import random, jit import numpy as np from ngclearn.utils import weight_distribution as dist -from ngclearn import Context, numpy as jnp -from ngclearn.components import (RateCell, - HebbianSynapse, - GaussianErrorCell, - StaticSynapse) -from ngclearn.utils.model_utils import scanner +from ngclearn import numpy as jnp +from jax import numpy as jnp, random, jit +from ngclearn import Context, MethodProcess +from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse +from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell +from ngcsimlib.global_state import stateManager class Iterative_ElasticNet(): """ @@ -87,43 +86,66 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W.batch_size = batch_size self.err.batch_size = batch_size # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.err.mu << self.W.outputs - self.W.post << self.err.dmu + self.W.outputs >> self.err.mu + self.err.dmu >> self.W.post # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses - self.err, ## finally, execute error neurons - compile_key="advance_state") - evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve") - reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset") - # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.dynamic() - - def dynamic(self): ## create dynamic commands forself.circuit - W, err = self.circuit.get_components("W", "err") - self.self = W - self.err = err - - @Context.dynamicCommand - def batch_set(batch_size): - self.W.batch_size = batch_size - self.err.batch_size = batch_size - @Context.dynamicCommand - def clamps(y_scaled, X): - self.W.inputs.set(X) - self.W.pre.set(X) - self.err.target.set(y_scaled) - - self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve") - self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance") - self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset") - - - @scanner - def _process(compartment_values, args): - _t, _dt = args - compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt) - return compartment_values, compartment_values[self.W.weights.path] + advance = (MethodProcess(name="advance_state") + >> self.W.advance_state + >> self.err.advance_state) + self.advance = advance + + evolve = (MethodProcess(name="evolve") + >> self.W.evolve) + self.evolve = evolve + + reset = (MethodProcess(name="reset") + >> self.err.reset + >> self.W.reset) + self.reset = reset + + # advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses + # self.err, ## finally, execute error neurons + # compile_key="advance_state") + # evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve") + # reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset") + # # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # self.dynamic() + + def batch_set(self, batch_size): + self.W.batch_size = batch_size + self.err.batch_size = batch_size + + def clamp(self, y_scaled, X): + self.W.inputs.set(X) + self.W.pre.set(X) + self.err.target.set(y_scaled) + + # def dynamic(self): ## create dynamic commands forself.circuit + # W, err = self.circuit.get_components("W", "err") + # self.self = W + # self.err = err + # + # @Context.dynamicCommand + # def batch_set(batch_size): + # self.W.batch_size = batch_size + # self.err.batch_size = batch_size + # + # @Context.dynamicCommand + # def clamps(y_scaled, X): + # self.W.inputs.set(X) + # self.W.pre.set(X) + # self.err.target.set(y_scaled) + # + # self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve") + # self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance") + # self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset") + # + # @scanner + # def _process(compartment_values, args): + # _t, _dt = args + # compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt) + # return compartment_values, compartment_values[self.W.weights.path] def thresholding(self, scale=1.): @@ -138,16 +160,28 @@ def thresholding(self, scale=1.): def fit(self, y, X): - self.circuit.reset() - self.circuit.clamps(y_scaled=y, X=X) + self.reset.run() + self.clamp(y_scaled=y, X=X) for i in range(self.epochs): - self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)])) - self.circuit.evolve(t=self.T, dt=self.dt) - - self.coef_ = np.array(self.W.weights.value) - - return self.coef_, self.err.mu.value, self.err.L.value + inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt)) + stateManager.state, outputs = self.advance.scan(inputs) + self.evolve.run(t=self.T, dt=self.dt) + + self.coef_ = np.array(self.W.weights.get()) + + return self.coef_, self.err.mu.get(), self.err.L.get() + + # self.circuit.reset() + # self.circuit.clamps(y_scaled=y, X=X) + # + # for i in range(self.epochs): + # self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)])) + # self.circuit.evolve(t=self.T, dt=self.dt) + # + # self.coef_ = np.array(self.W.weights.value) + # + # return self.coef_, self.err.mu.value, self.err.L.value diff --git a/tests/components/neurons/spiking/test_sLIFCell.py b/tests/components/neurons/spiking/test_sLIFCell.py index 6c63c028..697f1790 100644 --- a/tests/components/neurons/spiking/test_sLIFCell.py +++ b/tests/components/neurons/spiking/test_sLIFCell.py @@ -46,4 +46,4 @@ def clamp(x): ## output should equal input assert_array_equal(outs, y_seq) -test_sLIFCell1() +#test_sLIFCell1() From bdb0ce211235c93ea3a842d00f04fc4e7d0704c0 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 18:10:06 -0500 Subject: [PATCH 054/121] made revisions to components/clean-up; added back in deprecators --- ngclearn/components/__init__.py | 3 -- .../components/input_encoders/latencyCell.py | 3 +- .../components/input_encoders/phasorCell.py | 1 - ngclearn/components/jaxComponent.py | 3 +- .../neurons/graded/laplacianErrorCell.py | 1 - .../components/neurons/graded/rateCell.py | 1 - ngclearn/components/neurons/spiking/IFCell.py | 2 +- .../components/neurons/spiking/RAFCell.py | 2 +- .../components/neurons/spiking/WTASCell.py | 2 +- .../components/neurons/spiking/adExCell.py | 2 +- .../components/neurons/spiking/quadLIFCell.py | 2 +- .../components/neurons/spiking/sLIFCell.py | 5 ++- ngclearn/components/other/expKernel.py | 1 - ngclearn/components/other/varTrace.py | 1 - ngclearn/components/synapses/__init__.py | 6 +-- ngclearn/components/synapses/alphaSynapse.py | 1 - .../synapses/convolution/convSynapse.py | 1 - .../synapses/convolution/deconvSynapse.py | 1 - .../convolution/hebbianConvSynapse.py | 1 - .../convolution/hebbianDeconvSynapse.py | 1 - .../convolution/traceSTDPConvSynapse.py | 1 - .../convolution/traceSTDPDeconvSynapse.py | 1 - ngclearn/components/synapses/denseSynapse.py | 1 - .../components/synapses/doubleExpSynapse.py | 1 - .../components/synapses/exponentialSynapse.py | 1 - .../components/synapses/hebbian/BCMSynapse.py | 1 - .../synapses/hebbian/eventSTDPSynapse.py | 5 +-- .../synapses/hebbian/expSTDPSynapse.py | 1 - .../synapses/hebbian/hebbianSynapse.py | 7 ++- .../synapses/hebbian/traceSTDPSynapse.py | 1 - .../synapses/modulated/MSTDPETSynapse.py | 4 +- ngclearn/utils/distribution_generator.py | 43 ++++++++++--------- .../neurons/spiking/test_quadLIFCell.py | 3 +- 33 files changed, 45 insertions(+), 65 deletions(-) diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 96f8a2cf..9e7b5481 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -55,9 +55,6 @@ from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse from .synapses.modulated.REINFORCESynapse import REINFORCESynapse -## point to monitors -from .monitor import Monitor - ## point to patched component types from .synapses.patched.patchedSynapse import PatchedSynapse from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index c21c5e08..82b16e52 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -143,7 +143,6 @@ class LatencyCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - # Define Functions def __init__( self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, first_spike_time: float = 0., linearize: bool = False, @@ -261,4 +260,4 @@ def help(cls): ## component help function X = LatencyCell("X", 9) print(X) print(X.calc_spike_times.compiled.code) - print(X.advance_state.compiled.code) \ No newline at end of file + print(X.advance_state.compiled.code) diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 77c7d9c1..ccfbb15d 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -31,7 +31,6 @@ class PhasorCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - # Define Functions def __init__( self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs): super().__init__(name, **kwargs) diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index afa680bc..6e2ccec7 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -70,4 +70,5 @@ def __repr__(self): else: line = "None" lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines \ No newline at end of file + return lines + diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py index 251b5061..56bd5c12 100755 --- a/ngclearn/components/neurons/graded/laplacianErrorCell.py +++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py @@ -37,7 +37,6 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel to a constant/fixed `scale` """ - # Define Functions def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs): super().__init__(name, **kwargs) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index bff095d2..a03401fc 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -160,7 +160,6 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) """ - # Define Functions def __init__( self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index cb94827c..ec87053a 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -86,7 +86,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell the value of `v_rest` (default: True) """ - #@deprecate_args(thr_jitter=None) + @deprecate_args(thr_jitter=None) def __init__( self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0., integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 4b9f25dd..102a97c4 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -97,7 +97,7 @@ class RAFCell(JaxComponent): at an increase in computational cost (and simulation time) """ - #@deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor") + @deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor") def __init__( self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., dampen_factor=-1., v_reset=0., w_reset=0., v0=0., w0=0., resist_v=1., integration_type="euler", batch_size=1, **kwargs diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 16a7e4e2..1d8f0a0e 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -50,7 +50,7 @@ class WTASCell(JaxComponent): ## winner-take-all spiking cell thr_jitter: scale of uniform jitter to add to initialization of thresholds """ - #@deprecate_args(thr_base="thrBase") + @deprecate_args(thrBase="thr_base") def __init__( self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002, refract_time=0., thr_jitter=0.05, **kwargs diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index 0b7b6792..ef05d2c2 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -94,7 +94,7 @@ class AdExCell(JaxComponent): ## adaptive exponential integrate-and-fire cell at an increase in computational cost (and simulation time) """ - #@deprecate_args(v_thr="thr") + @deprecate_args(v_thr="thr") def __init__( self, name, n_units, tau_m=15., resist_m=1., tau_w=400., v_sharpness=2., intrinsic_mem_thr=-55., thr=5., v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=-70., w0=0., integration_type="euler", batch_size=1, **kwargs diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index af39434b..b8b93982 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -118,7 +118,7 @@ class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell v_min: minimum voltage to clamp dynamics to (Default: None) """ ## batch_size arg? - #@deprecate_args(thr_jitter=None, critical_v="critical_V") + @deprecate_args(thr_jitter=None, critical_V="critical_v") def __init__( self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1., tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", diff --git a/ngclearn/components/neurons/spiking/sLIFCell.py b/ngclearn/components/neurons/spiking/sLIFCell.py index 9d8d1fae..c644e6a2 100644 --- a/ngclearn/components/neurons/spiking/sLIFCell.py +++ b/ngclearn/components/neurons/spiking/sLIFCell.py @@ -102,7 +102,9 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell refract_time: relative refractory period time (ms; Default: 1 ms) - rho_b: threshold sparsity factor (Default: 0) + rho_b: threshold sparsity factor (Default: 0); note that setting rho_b > 0 will + force the adaptive threshold to follow dynamics that ignore `thr_grain` and + `thr_leak` sticky_spikes: if True, spike variables will be pinned to action potential value (i.e, 1) throughout duration of the refractory period; this recovers @@ -113,7 +115,6 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell batch_size: batch size dimension of this cell (Default: 1) """ - # Define Functions def __init__( self, name, n_units, tau_m, resist_m, thr, resist_inh=0., thr_persist=False, thr_gain=0.0, thr_leak=0.0, rho_b=0., refract_time=0., sticky_spikes=False, thr_jitter=0.05, batch_size=1, **kwargs diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index d434ced5..a7b25f6a 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -48,7 +48,6 @@ class ExpKernel(JaxComponent): ## exponential kernel batch_size: batch size dimension of this cell (Default: 1) """ - # Define Functions def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs): super().__init__(name, **kwargs) diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index e4e051f1..f1ddc2bc 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -77,7 +77,6 @@ class VarTrace(JaxComponent): ## low-pass filter batch_size: batch size dimension of this cell (Default: 1) """ - # Define Functions def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp", n_nearest_spikes=0, batch_size=1, key=None): super().__init__(name, key) diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 4c060ac9..2c9c9f70 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -32,7 +32,7 @@ # from .modulated.REINFORCESynapse import REINFORCESynapse ## patched synaptic components -# from .patched.patchedSynapse import PatchedSynapse -# from .patched.staticPatchedSynapse import StaticPatchedSynapse -# from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse +from .patched.patchedSynapse import PatchedSynapse +from .patched.staticPatchedSynapse import StaticPatchedSynapse +from .patched.hebbianPatchedSynapse import HebbianPatchedSynapse diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py index 8d639b4b..fc529b3f 100644 --- a/ngclearn/components/synapses/alphaSynapse.py +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -60,7 +60,6 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable """ - # Define Functions def __init__( self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs diff --git a/ngclearn/components/synapses/convolution/convSynapse.py b/ngclearn/components/synapses/convolution/convSynapse.py index ed6b83de..62b6ee3a 100755 --- a/ngclearn/components/synapses/convolution/convSynapse.py +++ b/ngclearn/components/synapses/convolution/convSynapse.py @@ -45,7 +45,6 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1., batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/convolution/deconvSynapse.py b/ngclearn/components/synapses/convolution/deconvSynapse.py index a81563b1..32f1dfc8 100755 --- a/ngclearn/components/synapses/convolution/deconvSynapse.py +++ b/ngclearn/components/synapses/convolution/deconvSynapse.py @@ -46,7 +46,6 @@ class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1., batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py index 16db48e5..4b45d2ce 100755 --- a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py @@ -84,7 +84,6 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd", diff --git a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py index 64bb5313..35def788 100755 --- a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py @@ -83,7 +83,6 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd", diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py index a0f74537..c9b4e5f2 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py @@ -71,7 +71,6 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1, padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py index 514f8611..a6286cd8 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py @@ -69,7 +69,6 @@ class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional batch_size: batch size dimension of this component """ - # Define Functions def __init__( self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1, padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 5b0e71b9..99996e65 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -38,7 +38,6 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable (lower values yield sparse structure) """ - # Define Functions def __init__( self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs ): diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py index 03135f8c..b5a9a3f0 100644 --- a/ngclearn/components/synapses/doubleExpSynapse.py +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -62,7 +62,6 @@ class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cabl """ - # Define Functions def __init__( self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py index d29ec9da..d8ba9b5f 100644 --- a/ngclearn/components/synapses/exponentialSynapse.py +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -59,7 +59,6 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable """ - # Define Functions def __init__( self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs diff --git a/ngclearn/components/synapses/hebbian/BCMSynapse.py b/ngclearn/components/synapses/hebbian/BCMSynapse.py index e4f8ddc4..c31bba12 100755 --- a/ngclearn/components/synapses/hebbian/BCMSynapse.py +++ b/ngclearn/components/synapses/hebbian/BCMSynapse.py @@ -64,7 +64,6 @@ class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable this to < 1. will result in a sparser synaptic structure """ - # Define Functions def __init__( self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0., weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py index b92522fe..265445e4 100755 --- a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py @@ -54,10 +54,9 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP this to < 1. will result in a sparser synaptic structure """ - # Define Functions def __init__( - self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., presyn_win_len=2., w_bound=1., weight_init=None, - resist_scale=1., p_conn=1., batch_size=1, **kwargs + self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., presyn_win_len=2., w_bound=1., + weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs ): super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs) diff --git a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py index 2a44c3f9..74312c6f 100644 --- a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py @@ -65,7 +65,6 @@ class ExpSTDPSynapse(DenseSynapse): weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None) """ - # Define Functions def __init__( self, name, shape, A_plus, A_minus, exp_beta, eta=1., pretrace_target=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index a33616b7..ff3b796e 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -165,12 +165,11 @@ class HebbianSynapse(DenseSynapse): this to < 1. will result in a sparser synaptic structure """ - # Define Functions - # @deprecate_args(_rebind=False, w_decay='prior') + @deprecate_args(_rebind=False, w_decay='prior') def __init__( self, name, shape, eta=0., weight_init=None, bias_init=None, w_bound=1., is_nonnegative=False, - prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1., - resist_scale=1., batch_size=1, **kwargs + prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., + p_conn=1., resist_scale=1., batch_size=1, **kwargs ): super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs) diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index dd51ecf5..59098ed8 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -67,7 +67,6 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None) """ - # Define Functions def __init__( self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs diff --git a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py index 45c99a7b..bbd7dae3 100755 --- a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py +++ b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py @@ -74,8 +74,8 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit """ def __init__( - self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0., - weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs + self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., + tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs ): super().__init__( # call to parent trace-stdp component name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init, diff --git a/ngclearn/utils/distribution_generator.py b/ngclearn/utils/distribution_generator.py index f62c32b1..15845334 100644 --- a/ngclearn/utils/distribution_generator.py +++ b/ngclearn/utils/distribution_generator.py @@ -9,22 +9,18 @@ class DistributionParams(TypedDict, total=False): """ - Extra parameters to be used when generating distributions. - - Attributes: - amin: sets the lower bound of the distribution - amax: sets the upper bound of the distribution - lower_triangle: keeps the lower triangle, sets the rest to zero - upper_triangle: keeps the upper triangle, sets the rest to zero - hollow: produces a hollow distribution (zeros along the diagonal) - eye: produces an eye distribution (zeros the off-diagonal) - col_mask: - single value, keeps n random columns - list values, keeps the provided column indices - row_mask: - single value, keeps n random rows - list values, keeps the provided row indices - use_numpy: use default numpy + Extra parameters to be used when generating distributions. (Attributes listed below) + + Args: + amin: sets the lower bound of the distribution + amax: sets the upper bound of the distribution + lower_triangle: keeps the lower triangle, sets the rest to zero + upper_triangle: keeps the upper triangle, sets the rest to zero + hollow: produces a hollow distribution (zeros along the diagonal) + eye: produces an eye distribution (zeros the off-diagonal) + col_mask: single value, keeps n random columns; list values, keeps the provided column indices + row_mask: single value, keeps n random rows; list values, keeps the provided row indices + use_numpy: use default numpy """ amin: float amax: float @@ -54,7 +50,8 @@ def constant(value: float, **params: Unpack[ value: the constant value to fill the array with **params: the extra distribution parameters - Returns: a distribution initializer + Returns: + a distribution initializer """ using_np = params.get("use_numpy", False) if using_np: @@ -87,7 +84,8 @@ def uniform(low: float = 0.0, high: float = 1.0, **params: Unpack[ high: upper bound of the uniform distribution (exclusive) **params: the extra distribution parameters - Returns: a distribution initializer + Returns: + a distribution initializer """ using_np = params.get("use_numpy", False) @@ -133,7 +131,8 @@ def gaussian(mean: float = 0.0, std: float = 1.0, **params: Unpack[ std: the standard deviation of the normal distribution **params: the extra distribution parameters - Returns: a distribution initializer + Returns: + a distribution initializer """ using_numpy = params.get("use_numpy", False) @@ -177,7 +176,8 @@ def fan_in_uniform( Args: **params: extra distribution parameters - Returns: a distribution initializer + Returns: + a distribution initializer """ using_numpy = params.get("use_numpy", False) @@ -236,7 +236,8 @@ def fan_in_gaussian( Args: **params: extra distribution parameters - Returns: a distribution initializer + Returns: + a distribution initializer """ using_numpy = params.get("use_numpy", False) diff --git a/tests/components/neurons/spiking/test_quadLIFCell.py b/tests/components/neurons/spiking/test_quadLIFCell.py index 81414b9c..58756dba 100644 --- a/tests/components/neurons/spiking/test_quadLIFCell.py +++ b/tests/components/neurons/spiking/test_quadLIFCell.py @@ -13,10 +13,11 @@ def test_quadLIFCell1(): dkey = random.PRNGKey(1234) dkey, *subkeys = random.split(dkey, 6) dt = 1. # ms + critical_V = 1. # ---- build a simple Poisson cell system ---- with Context(name) as ctx: a = QuadLIFCell( - name="a", n_units=1, tau_m=30., resist_m=1., key=subkeys[0] + name="a", n_units=1, tau_m=30., resist_m=1., critical_V=critical_V, key=subkeys[0] ) # """ From d1c5e773943740bb0d0e312353ed1c23fd6a01eb Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 18:12:52 -0500 Subject: [PATCH 055/121] removed lava sub-module, and removed monitor/base-monitor legacy components --- ngclearn/components/base_monitor.py | 331 ------------------ ngclearn/components/lava/__init__.py | 11 - ngclearn/components/lava/monitor.py | 32 -- ngclearn/components/lava/neurons/LIFCell.py | 177 ---------- ngclearn/components/lava/neurons/__init__.py | 1 - ngclearn/components/lava/synapses/__init__.py | 3 - .../lava/synapses/hebbianSynapse.py | 159 --------- .../components/lava/synapses/staticSynapse.py | 122 ------- .../lava/synapses/traceSTDPSynapse.py | 181 ---------- ngclearn/components/lava/traces/__init__.py | 1 - ngclearn/components/lava/traces/gatedTrace.py | 69 ---- ngclearn/components/monitor.py | 31 -- 12 files changed, 1118 deletions(-) delete mode 100644 ngclearn/components/base_monitor.py delete mode 100644 ngclearn/components/lava/__init__.py delete mode 100644 ngclearn/components/lava/monitor.py delete mode 100644 ngclearn/components/lava/neurons/LIFCell.py delete mode 100644 ngclearn/components/lava/neurons/__init__.py delete mode 100644 ngclearn/components/lava/synapses/__init__.py delete mode 100644 ngclearn/components/lava/synapses/hebbianSynapse.py delete mode 100755 ngclearn/components/lava/synapses/staticSynapse.py delete mode 100755 ngclearn/components/lava/synapses/traceSTDPSynapse.py delete mode 100755 ngclearn/components/lava/traces/__init__.py delete mode 100755 ngclearn/components/lava/traces/gatedTrace.py delete mode 100644 ngclearn/components/monitor.py diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py deleted file mode 100644 index b50435e3..00000000 --- a/ngclearn/components/base_monitor.py +++ /dev/null @@ -1,331 +0,0 @@ -import json - -from ngclearn import Component, Compartment #, transition -from ngclearn import numpy as np -#from ngcsimlib.utils import get_current_path -from ngcsimlib.logger import warn, critical - -import matplotlib.pyplot as plt - - -class Base_Monitor(Component): - """ - An abstract base for monitors for both ngclearn and ngclava. Compartments - wired directly into this component will have their value tracked during - `advance_state` loops automatically. - - Note the monitor only works for compiled methods currently - - - Using default window length: - myMonitor << myComponent.myCompartment - - Using custom window length: - myMonitor.watch(myComponent.myCompartment, customWindowLength) - - To get values out of the monitor either path to the stored value - directly, or pass in a compartment directly. All - paths are the same as their local path variable. - - Using a compartment: - myMonitor.view(myComponent.myCompartment) - - Using a path: - myMonitor.get_store(myComponent.myCompartment.path).value - - There can only be one monitor in existence at a time due to the way it - interacts with resolvers and the compilers - for ngclearn. - - Args: - name: The name of the component. - - default_window_length: The default window length. - """ - auto_resolve = False - - @staticmethod - def build_reset(component): - return Base_Monitor.reset(component) - - @staticmethod - def build_advance_state(component): - return Base_Monitor.record(component) - - @staticmethod - def _record_internal(compartments): - """ - A method to build the method to advance the stored values. - - Args: - compartments: A list of compartments to store values - - Returns: The method to advance the stored values. - - """ - critical( - "build_advance() is not defined on this monitor, use either the " - "monitor found in ngclearn.components or " - "ngclearn.components.lava (If using lava)") - - #@transition(None, True) - @staticmethod - def reset(component): - """ - A method to build the method to reset the stored values. - Args: - component: The component to resolve - - Returns: the reset resolver - """ - output_compartments = [] - compartments = [] - for comp in component.compartments: - output_compartments.append(comp.split("/")[-1] + "*store") - compartments.append(comp.split("/")[-1]) - - @staticmethod - def _reset(**kwargs): - return_vals = [] - for comp in compartments: - current_store = kwargs[comp + "*store"] - return_vals.append(np.zeros(current_store.shape)) - return return_vals if len(compartments) > 1 else return_vals[0] - - # pure func, output compartments, args, params, input compartments - return _reset, output_compartments, [], [], output_compartments - - #@transition(None, True) - @staticmethod - def record(component): - output_compartments = [] - compartments = [] - for comp in component.compartments: - output_compartments.append(comp.split("/")[-1] + "*store") - compartments.append(comp.split("/")[-1]) - - _advance = component._record_internal(compartments) - - return _advance, output_compartments, [], [], compartments + output_compartments - - def __init__(self, name, default_window_length=100, **kwargs): - super().__init__(name, **kwargs) - self.store = {} - self.compartments = [] - self._sources = [] - self.default_window_length = default_window_length - - def __lshift__(self, other): - if isinstance(other, Compartment): - self.watch(other, self.default_window_length) - else: - warn("Only Compartments can be monitored not", type(other)) - - def watch(self, compartment, window_length): - """ - Sets the monitor to watch a specific compartment, for a specified - window length. - - Args: - compartment: the compartment object to monitor - - window_length: the window length - """ - cs, end = self._add_path(compartment.path) - - if hasattr(compartment.value, "dtype"): - dtype = compartment.value.dtype - else: - dtype = type(compartment.value) - - if hasattr(compartment.value, "shape"): - shape = compartment.value.shape - else: - shape = (1,) - new_comp = Compartment(np.zeros(shape, dtype=dtype)) - new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype)) - - comp_key = "*".join(compartment.path.split("/")) - store_comp_key = comp_key + "*store" - - new_comp._setup(self, comp_key) - new_comp_store._setup(self, store_comp_key) - - new_comp << compartment - - cs[end] = new_comp_store - setattr(self, comp_key, new_comp) - setattr(self, store_comp_key, new_comp_store) - self.compartments.append(new_comp.path) - self._sources.append(compartment) - # self._update_resolver() - - def halt(self, compartment): - """ - Stops the monitor from watching a specific compartment. It is important - to note that it does not stop previously compiled methods. It does not - remove it from the stored values, so it can still be viewed. - Args: - compartment: The compartment object to stop watching - """ - if compartment not in self._sources: - return - - comp_key = "*".join(compartment.path.split("/")) - store_comp_key = comp_key + "*store" - - self.compartments.remove(getattr(self, comp_key).path) - self._sources.remove(compartment) - - delattr(self, comp_key) - delattr(self, store_comp_key) - self._update_resolver() - - def halt_all(self): - """ - Stops the monitor from watching all compartments. - """ - for compartment in self._sources: - self.halt(compartment) - - # def _update_resolver(self): - # output_compartments = [] - # compartments = [] - # for comp in self.compartments: - # output_compartments.append(comp.split("/")[-1] + "*store") - # compartments.append(comp.split("/")[-1]) - # - # args = [] - # parameters = [] - # - # add_component_resolver(self.__class__.__name__, "advance_state", - # (self.build_advance(compartments), - # output_compartments)) - # add_resolver_meta(self.__class__.__name__, "advance_state", - # (args, parameters, - # compartments + [o for o in output_compartments], - # False)) - - # add_component_resolver(self.__class__.__name__, "reset", ( - # self.build_reset(compartments), output_compartments)) - # add_resolver_meta(self.__class__.__name__, "reset", - # (args, parameters, [o for o in output_compartments], - # False)) - - def _add_path(self, path): - _path = path.split("/")[1:] - end = _path.pop(-1) - - current_store = self.store - for p in _path: - if p not in current_store.keys(): - current_store[p] = {} - current_store = current_store[p] - - return current_store, end - - def view(self, compartment): - """ - Gets the value associated with the specified compartment - - Args: - compartment: The compartment to extract the stored value of - - Returns: The stored value, None if not monitoring that compartment - - """ - _path = compartment.path.split("/")[1:] - store = self.get_store(_path) - return store.value if store is not None else store - - def get_store(self, path): - current_store = self.store - for p in path: - if p not in current_store.keys(): - return None - current_store = current_store[p] - return current_store - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".json" - _dict = {"sources": {}, "stores": {}} - for key in self.compartments: - n = key.split("/")[-1] - _dict["sources"][key] = self.__dict__[n].value.shape - _dict["stores"][key + "*store"] = self.__dict__[ - n + "*store"].value.shape - - with open(file_name, "w") as f: - json.dump(_dict, f) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".json" - with open(file_name, "r") as f: - vals = json.load(f) - - for comp_path, shape in vals["stores"].items(): - compartment_path = comp_path.split("/")[-1] - new_path = "" - # new_path = get_current_path() + "/" + "/".join( - # compartment_path.split("*")[-3:-1]) - - cs, end = self._add_path(new_path) - - new_comp = Compartment(np.zeros(shape)) - new_comp._setup(self, compartment_path) - - cs[end] = new_comp - setattr(self, compartment_path, new_comp) - - for comp_path, shape in vals['sources'].items(): - compartment_path = comp_path.split("/")[-1] - new_comp = Compartment(np.zeros(shape)) - new_comp._setup(self, compartment_path) - - setattr(self, compartment_path, new_comp) - self.compartments.append(new_comp.path) - - # self._update_resolver() - - def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None): - vals = self.view(compartment) - - if n is None: - n = vals.shape[2] - if title is None: - title = compartment.name.split("/")[0] + " " + compartment.display_name - - if ylabel is None: - _ylabel = compartment.units - elif ylabel: - _ylabel = ylabel - else: - _ylabel = None - - if xlabel is None: - _xlabel = "Time Steps" - elif xlabel: - _xlabel = xlabel - else: - _xlabel = None - - if ax is None: - _ax = plt - _ax.title(title) - if _ylabel: - _ax.ylabel(_ylabel) - if _xlabel: - _ax.xlabel(_xlabel) - else: - _ax = ax - _ax.set_title(title) - if _ylabel: - _ax.set_ylabel(_ylabel) - if _xlabel: - _ax.set_xlabel(_xlabel) - - if plot_func is None: - for k in range(n): - _ax.plot(vals[:, 0, k]) - else: - plot_func(vals[:, :, 0:n], ax=_ax) diff --git a/ngclearn/components/lava/__init__.py b/ngclearn/components/lava/__init__.py deleted file mode 100644 index 962f843a..00000000 --- a/ngclearn/components/lava/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -## lava-compliant neuronal cells -from .neurons.LIFCell import LIFCell -## lava-compliant synapses -from .synapses.staticSynapse import StaticSynapse -from .synapses.traceSTDPSynapse import TraceSTDPSynapse -from .synapses.hebbianSynapse import HebbianSynapse -## Lava-compliant encoders/traces -from .traces.gatedTrace import GatedTrace - -#monitor -from .monitor import Monitor \ No newline at end of file diff --git a/ngclearn/components/lava/monitor.py b/ngclearn/components/lava/monitor.py deleted file mode 100644 index aaabf8f8..00000000 --- a/ngclearn/components/lava/monitor.py +++ /dev/null @@ -1,32 +0,0 @@ -from ngclearn.components.base_monitor import Base_Monitor - - -class Monitor(Base_Monitor): - """ - A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components - """ - auto_resolve = False - - - @staticmethod - def build_advance(compartments): - @staticmethod - def _advance(**kwargs): - return_vals = [] - for comp in compartments: - new_val = kwargs[comp] - current_store = kwargs[comp + "*store"] - current_store[:-1] = current_store[1:] - current_store[-1] = new_val - return_vals.append(current_store) - return return_vals if len(compartments) > 1 else return_vals[0] - - return _advance - - @staticmethod - def build_advance_state(component): - return super().build_advance_state(component) - - @staticmethod - def build_reset(component): - return super().build_reset(component) diff --git a/ngclearn/components/lava/neurons/LIFCell.py b/ngclearn/components/lava/neurons/LIFCell.py deleted file mode 100644 index e0ba3641..00000000 --- a/ngclearn/components/lava/neurons/LIFCell.py +++ /dev/null @@ -1,177 +0,0 @@ -from ngclearn import numpy as jnp -from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats - -class LIFCell(Component): ## Lava-compliant leaky integrate-and-fire cell - """ - A spiking cell based on (leaky) integrate-and-fire (LIF) neuronal dynamics. - Note that this cell can be readily configured to pure integrate-and-fire - dynamics as needed. Note that dynamics in this Lava-compliant cell are - hard-coded to move according to Euler integration. - - The specific differential equation that characterize this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dv/dt = gamma_d * (v_rest - v) + j * R - | where R is the membrane resistance and v_rest is the resting potential - | gamma_d is voltage decay -- 1 recovers LIF dynamics and 0 recovers IF dynamics - - | --- Cell Input Compartments: (Takes wired-in signals) --- - | j_exc - excitatory electrical input - | j_inh - inhibitory electrical input - | --- Cell Output Compartments: (These signals are generated) --- - | v - membrane potential/voltage state - | s - emitted binary spikes/action potentials - | rfr - (relative) refractory variable state - | thr_theta - homeostatic/adaptive threshold increment state - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - dt: integration time constant (ms) - - tau_m: cell membrane time constant - - thr_theta_init: initialization kernel for threshold increment variable - - resist_m: membrane resistance value (Default: 1) - - thr: base value for adaptive thresholds that govern short-term - plasticity (in milliVolts, or mV) - - v_rest: membrane resting potential (in mV) - - v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, - a neuronal cell's membrane potential will be set to this value - - v_decay: decay factor applied to voltage leak (Default: 1.); setting this - to 0 mV results in pure integrate-and-fire (IF) dynamics - - tau_theta: homeostatic threshold time constant - - theta_plus: physical increment to be applied to any threshold value if - a spike was emitted - - refract_time: relative refractory period time (ms; Default: 1 ms) - - thr_theta0: (DEPRECATED) initial conditions for voltage threshold - """ - - # Define Functions - def __init__(self, name, n_units, dt, tau_m, thr_theta_init=None, resist_m=1., - thr=-52., v_rest=-65., v_reset=-60., v_decay=1., tau_theta=1e7, - theta_plus=0.05, refract_time=5., thr_theta0=None, **kwargs): - super().__init__(name, **kwargs) - - ## Cell dynamics setup - self.dt = dt - self.tau_m = tau_m ## membrane time constant - self.R_m = resist_m ## resistance value - if kwargs.get("R_m") is not None: - warn("The argument `R_m` being used is deprecated.") - self.Rscale = kwargs.get("R_m") - self.v_rest = v_rest # mV - self.v_reset = v_reset # mV (milli-volts) - self.v_decay = v_decay - ## basic asserts to prevent neuronal dynamics breaking... - assert (self.v_decay * self.dt / self.tau_m) <= 1. - assert self.R_m > 0. - self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off) - self.theta_plus = theta_plus ## threshold increment - self.refract_T = refract_time ## refractory period # ms - self.thr = thr ## (fixed) base value for threshold # mV - self.thr_theta_init = thr_theta_init - self.thr_theta0 = thr_theta0 ## initial jittered adaptive threshold values - - ## Component size setup - self.batch_size = 1 - self.n_units = n_units - - ## Compartment setup - restVals = jnp.zeros((self.batch_size, self.n_units)) - self.j_exc = Compartment(restVals) - self.j_inh = Compartment(restVals) - self.v = Compartment(restVals + self.v_rest) - self.s = Compartment(restVals) - self.rfr = Compartment(restVals + self.refract_T) - self.thr_theta = Compartment(None) - - if thr_theta0 is not None: - warn("The argument `thr_theta0` being used is deprecated.") - self._init(thr_theta0) - else: - if self.thr_theta_init is None: - info(self.name, "is using default threshold variable initializer!") - self.thr_theta_init = {"dist": "constant", "value": 0.} - thr_theta0 = initialize_params(None, self.thr_theta_init, (1, self.n_units)) - self._init(thr_theta0) - - def _init(self, thr_theta0): - self.thr_theta.set(thr_theta0) - - @transition(output_compartments=["v", "s", "rfr", "thr_theta"]) - @staticmethod - def advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, - theta_plus, j_exc, j_inh, v, s, rfr, thr_theta): - #j = j * (tau_m/dt) ## scale electrical current - j = j_exc - j_inh ## sum the excitatory and inhibitory input channels - mask = (rfr >= refract_T) * 1. #numpy.greater_equal(rfr, refract_T) * 1. - ## update voltage / membrane potential - ### note: the ODE is a bit differently formulated here than usual - dv_dt = (v_rest - v) * v_decay * (dt/tau_m) + ((j * R_m) * mask) - v = v + dv_dt ### hard-coded Euler integration - ## obtain action potentials/spikes - s = (v > (thr + thr_theta)) * 1. #numpy.greater_equal(v, thr + thr_theta) * 1. - ## update refractory variables - rfr = (rfr + dt) * (1. - s) - ## perform hyper-polarization of neuronal cells - v = v * (1. - s) + s * v_reset - ## update adaptive threshold variables - theta_decay = jnp.exp(-dt/tau_theta) - thr_theta = thr_theta * theta_decay + s * theta_plus - ## update time-of-last-spike - #tols = (1. - s) * tols + (s * t) - return v, s, rfr, thr_theta #, tols - - @transition(output_compartments=["j_exc", "j_inh", "v", "s", "rfr"]) - @staticmethod - def reset(batch_size, n_units, v_rest, refract_T): - restVals = jnp.zeros((batch_size, n_units)) - j_exc = restVals #+ 0 - j_inh = restVals #+ 0 - v = restVals + v_rest - s = restVals #+ 0 - rfr = restVals + refract_T - return j_exc, j_inh, v, s, rfr #, tols - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, - threshold_theta=self.thr_theta.value) - - def load(self, directory, seeded=False, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self._init( data['threshold_theta'] ) - - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/lava/neurons/__init__.py b/ngclearn/components/lava/neurons/__init__.py deleted file mode 100644 index e28ed0f8..00000000 --- a/ngclearn/components/lava/neurons/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .LIFCell import LIFCell diff --git a/ngclearn/components/lava/synapses/__init__.py b/ngclearn/components/lava/synapses/__init__.py deleted file mode 100644 index bd7f9ea3..00000000 --- a/ngclearn/components/lava/synapses/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .staticSynapse import StaticSynapse -from .hebbianSynapse import HebbianSynapse -from .traceSTDPSynapse import TraceSTDPSynapse diff --git a/ngclearn/components/lava/synapses/hebbianSynapse.py b/ngclearn/components/lava/synapses/hebbianSynapse.py deleted file mode 100644 index c06a3792..00000000 --- a/ngclearn/components/lava/synapses/hebbianSynapse.py +++ /dev/null @@ -1,159 +0,0 @@ -from ngclearn import numpy as jnp -from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats - -class HebbianSynapse(Component): ## Lava-compliant Hebbian synapse - """ - A synaptic cable that adjusts its efficacies via a two-factor Hebbian adjustment rule. This is a Lava-compliant - synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent. - - | --- Synapse Input Compartments: (Takes wired-in signals) --- - | inputs - input (pre-synaptic) stimulus - | --- Synaptic Plasticity Input Compartments: (Takes in wired-in signals) --- - | pre - pre-synaptic signal to drive first term of Hebbian update - | post - post-synaptic signal to drive 2nd term of Hebbian update - | eta - global learning rate (unidimensional/scalar value) - | --- Synapse Output Compartments: (These signals are generated) --- - | outputs - transformed (post-synaptic) signal - | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0) - - Args: - name: the string name of this cell - - dt: integration time constant (ms) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - eta: global learning rate - - w_decay: degree to which (L2) synaptic weight decay is applied to the - computed Hebbian adjustment (Default: 0); note that decay is not - applied to any configured biases - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - weights: matrix of synaptic weight values to initialize this synapse - component to - - Rscale: DEPRECATED argument (maps to resist_scale) - """ - - # Define Functions - def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None, - eta=0., w_decay=0., w_bound=1., weights=None, **kwargs): - super().__init__(name, **kwargs) - - ## synaptic plasticity properties and characteristics - self.weight_init = weight_init - self.shape = shape - self.batch_size = 1 - - self.dt = dt - self.Rscale = resist_scale - if kwargs.get("Rscale") is not None: - warn("The argument `Rscale` being used is deprecated.") - self.Rscale = kwargs.get("Rscale") - self.w_bounds = w_bound - self.w_decay = w_decay ## synaptic decay - self.eta0 = eta - - self.inputs = Compartment(None) - self.outputs = Compartment(None) - self.pre = Compartment(None) - self.post = Compartment(None) - self.weights = Compartment(None) - self.eta = Compartment(jnp.ones((1, 1)) * eta) - - if weights is not None: - warn("The argument `weights` being used is deprecated.") - self._init(weights) - else: - assert self.shape is not None ## if using an init, MUST have shape - if self.weight_init is None: - info(self.name, "is using default weight initializer!") - self.weight_init = {"dist": "uniform", "amin": 0.025, - "amax": 0.8} - weights = initialize_params(None, self.weight_init, self.shape) - self._init(weights) - - def _init(self, weights): - self.rows = weights.shape[0] - self.cols = weights.shape[1] - - ## pre-computed empty zero pads - preVals = jnp.zeros((self.batch_size, self.rows)) - postVals = jnp.zeros((self.batch_size, self.cols)) - ## Compartments - self.inputs.set(preVals) - self.outputs.set(postVals) - self.pre.set(preVals) - self.post.set(postVals) - self.weights.set(weights) - - @transition(output_compartments=["outputs", "weights"]) - @staticmethod - def advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights, - pre, post, eta): - outputs = jnp.matmul(inputs, weights) * Rscale - ######################################################################## - ## Run one step of 2-factor Hebbian adaptation online - dW = jnp.matmul(pre.T, post) - #db = jnp.sum(_post, axis=0, keepdims=True) - ## reformulated bounding flag to be linear algebraic - flag = (w_bounds > 0.) * 1. - dW = (dW * (w_bounds - jnp.abs(weights))) * flag + (dW) * (1. - flag) - ## add small amount of synaptic decay - weights = weights + (dW - weights * w_decay) * eta - weights = jnp.clip(weights, 0., w_bounds) - ######################################################################## - return outputs, weights - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "eta"]) - @staticmethod - def reset(batch_size, rows, cols, eta0): - preVals = jnp.zeros((batch_size, rows)) - postVals = jnp.zeros((batch_size, cols)) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - jnp.ones((1,1)) * eta0 - ) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self._init( data['weights'] ) - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/lava/synapses/staticSynapse.py b/ngclearn/components/lava/synapses/staticSynapse.py deleted file mode 100755 index 20f39ebe..00000000 --- a/ngclearn/components/lava/synapses/staticSynapse.py +++ /dev/null @@ -1,122 +0,0 @@ -from ngclearn import numpy as jnp -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info, warn -from ngclearn.components.synapses.hebbian import TraceSTDPSynapse -from ngclearn.utils import tensorstats - -class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse - """ - A static (dense) synaptic cable; no form of synaptic evolution/adaptation is in-built to this component. This a - Lava-compliant version of the static synapse component from the synapses sub-package of components. - - | --- Synapse Input Compartments: (Takes wired-in signals) --- - | inputs - input (pre-synaptic) stimulus - | --- Synapse Output Compartments: .set()ese signals are generated) --- - | outputs - transformed (post-synaptic) signal - | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0) - - Args: - name: the string name of this cell - - dt: integration time constant (ms) - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - Rscale: DEPRECATED argument (maps to resist_scale) - - weights: a provided, externally created weight value matrix that will - be used instead of an auto-init call - """ - - # Define Functions - def __init__(self, name, dt, weight_init=None, shape=None, resist_scale=1., - weights=None, **kwargs): - super().__init__(name, **kwargs) - - ## synaptic plasticity properties and characteristics - self.batch_size = 1 - self.dt = dt - self.Rscale = resist_scale - if kwargs.get("Rscale") is not None: - warn("The argument `Rscale` being used is deprecated.") - self.Rscale = kwargs.get("Rscale") - self.shape = shape - self.weight_init = weight_init - - self.inputs = Compartment(None) - self.outputs = Compartment(None) - self.weights = Compartment(None) - - if weights is not None: - warn("The argument `weights` being used is deprecated.") - self._init(weights) - else: - assert self.shape is not None ## if using an init, MUST have shape - if self.weight_init is None: - info(self.name, "is using default weight initializer!") - self.weight_init = {"dist": "uniform", "amin": 0.025, - "amax": 0.8} - weights = initialize_params(None, self.weight_init, self.shape) - self._init(weights) - - def _init(self, weights): - self.rows = weights.shape[0] - self.cols = weights.shape[1] - ## pre-computed empty zero pads - preVals = jnp.zeros((self.batch_size, self.rows)) - postVals = jnp.zeros((self.batch_size, self.cols)) - ## Compartments - self.inputs.set(preVals) - self.outputs.set(postVals) - self.weights.set(weights) - - @transition(output_compartments=["outputs"]) - @staticmethod - def advance_state(dt, Rscale, inputs, weights): - outputs = jnp.matmul(inputs, weights) * Rscale - return outputs - - @transition(output_compartments=["inputs", "outputs"]) - @staticmethod - def reset(batch_size, rows, cols): - preVals = jnp.zeros((batch_size, rows)) - postVals = jnp.zeros((batch_size, cols)) - return ( - preVals, # inputs - postVals, # outputs - ) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self._init( data['weights'] ) - - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/lava/synapses/traceSTDPSynapse.py b/ngclearn/components/lava/synapses/traceSTDPSynapse.py deleted file mode 100755 index 23a3287d..00000000 --- a/ngclearn/components/lava/synapses/traceSTDPSynapse.py +++ /dev/null @@ -1,181 +0,0 @@ -from ngclearn import numpy as jnp -from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats - -class TraceSTDPSynapse(Component): ## Lava-compliant trace-STDP synapse - """ - A synaptic cable that adjusts its efficacies via trace-based form of spike-timing-dependent plasticity (STDP). - This is a Lava-compliant synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent. - - | --- Synapse Input Compartments: (Takes wired-in signals) --- - | inputs - input (pre-synaptic) stimulus - | --- Synaptic Plasticity Input Compartments: (Takes in wired-in signals) --- - | pre - pre-synaptic spike(s) to drive STDP update - | x_pre - pre-synaptic trace value(s) to drive STDP update - | post - post-synaptic spike(s) to drive STDP update - | x_post - post-synaptic trace value(s) to drive STDP update - | eta - global learning rate (unidimensional/scalar value) - | --- Synapse Output Compartments: (These signals are generated) --- - | outputs - transformed (post-synaptic) signal - | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0) - - Args: - name: the string name of this cell - - dt: integration time constant (ms) - - resist_scale: a fixed scaling factor to apply to synaptic transform - (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b - - weight_init: a kernel to drive initialization of this synaptic cable's values; - typically a tuple with 1st element as a string calling the name of - initialization to use - - shape: tuple specifying shape of this synaptic cable (usually a 2-tuple - with number of inputs by number of outputs) - - Aplus: strength of long-term potentiation (LTP) - - Aminus: strength of long-term depression (LTD) - - eta: global learning rate (default: 1) - - w_decay: degree to which (L2) synaptic weight decay is applied to the - computed Hebbian adjustment (Default: 0); note that decay is not - applied to any configured biases - - w_bound: maximum weight to softly bound this cable's value matrix to; if - set to 0, then no synaptic value bounding will be applied - - preTrace_target: controls degree of pre-synaptic disconnect, i.e., amount of decay - (higher -> lower synaptic values) - - weights: matrix of synaptic weight values to initialize this synapse - component to - - Rscale: DEPRECATED argument (maps to resist_scale) - """ - - # Define Functions - def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None, - Aplus=0.01, Aminus=0.001, eta=1., w_decay=0., w_bound=1., - preTrace_target=0., weights=None, **kwargs): - super().__init__(name, **kwargs) - - ## synaptic plasticity properties and characteristics - self.weight_init = weight_init - self.shape = shape - self.dt = dt - self.Rscale = resist_scale - if kwargs.get("Rscale") is not None: - warn("The argument `Rscale` being used is deprecated.") - self.Rscale = kwargs.get("Rscale") - self.w_bounds = w_bound - self.w_decay = w_decay ## synaptic decay - self.eta0 = eta - self.Aplus = Aplus - self.Aminus = Aminus - self.x_tar = preTrace_target - - ## Component size setup - self.batch_size = 1 - - self.eta = Compartment(jnp.ones((1, 1)) * eta) - - self.inputs = Compartment(None) - self.outputs = Compartment(None) - self.pre = Compartment(None) ## pre-synaptic spike - self.x_pre = Compartment(None) ## pre-synaptic trace - self.post = Compartment(None) ## post-synaptic spike - self.x_post = Compartment(None) ## post-synaptic trace - self.weights = Compartment(None) - - if weights is not None: - warn("The argument `weights` being used is deprecated.") - self._init(weights) - else: - assert self.shape is not None ## if using an init, MUST have shape - if self.weight_init is None: - info(self.name, "is using default weight initializer!") - self.weight_init = {"dist": "uniform", "amin": 0.025, - "amax": 0.8} - weights = initialize_params(None, self.weight_init, self.shape) - self._init(weights) - - def _init(self, weights): - self.rows = weights.shape[0] - self.cols = weights.shape[1] - ## pre-computed empty zero pads - preVals = jnp.zeros((self.batch_size, self.rows)) - postVals = jnp.zeros((self.batch_size, self.cols)) - ## Compartments - self.inputs.set(preVals) - self.outputs.set(postVals) - self.pre.set(preVals) ## pre-synaptic spike - self.x_pre.set(preVals) ## pre-synaptic trace - self.post.set(postVals) ## post-synaptic spike - self.x_post.set(postVals) ## post-synaptic trace - self.weights.set(weights) - - @transition(output_compartments=["outputs", "weights"]) - @staticmethod - def advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar, - inputs, weights, pre, x_pre, post, x_post, eta): - outputs = jnp.matmul(inputs, weights) * Rscale - ######################################################################## - ## Run one step of STDP online - dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus) - dWpre = -jnp.matmul(pre.T, x_post * Aminus) - dW = dWpost + dWpre - ## reformulated bounding flag to be linear algebraic - flag = (w_bounds > 0.) * 1. - dW = (dW * (w_bounds - jnp.abs(weights))) * flag + (dW) * (1. - flag) - ## physically adjust synapses - weights = weights + (dW - weights * w_decay) * eta - #weights = weights + (dW - weights * w_decay) * dt/tau_w ## ODE format - weights = jnp.clip(weights, 0., w_bounds) - ######################################################################## - return outputs, weights - - @transition(output_compartments=["inputs", "outputs", "pre", "post", "x_pre", "x_post", "eta"]) - @staticmethod - def reset(batch_size, rows, cols, eta0): - preVals = jnp.zeros((batch_size, rows)) - postVals = jnp.zeros((batch_size, cols)) - return ( - preVals, # inputs - postVals, # outputs - preVals, # pre - postVals, # post - preVals, # x_pre - postVals, # x_post - jnp.ones((1, 1)) * eta0 - ) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, weights=self.weights.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self._init( data['weights'] ) - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/lava/traces/__init__.py b/ngclearn/components/lava/traces/__init__.py deleted file mode 100755 index 5dc901bf..00000000 --- a/ngclearn/components/lava/traces/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gatedTrace import GatedTrace diff --git a/ngclearn/components/lava/traces/gatedTrace.py b/ngclearn/components/lava/traces/gatedTrace.py deleted file mode 100755 index 941fe061..00000000 --- a/ngclearn/components/lava/traces/gatedTrace.py +++ /dev/null @@ -1,69 +0,0 @@ -from ngclearn import numpy as jnp -from ngcsimlib.logger import info, warn -from ngcsimlib.compilers.process import transition -from ngcsimlib.component import Component -from ngcsimlib.compartment import Compartment -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info -from ngclearn.utils import tensorstats - -class GatedTrace(Component): ## gated/piecewise low-pass filter - """ - A gated/piecewise variable trace (filter). This is a Lava-compliant trace component. - - | --- Cell Input Compartments: (Takes wired-in signals) --- - | inputs - input (takes wired-in external signals) - | --- Cell Output Compartments: (These signals are generated) --- - | trace - traced value signal - - Args: - name: the string name of this operator - - n_units: number of calculating entities or units - - dt: integration time constant (ms) - - tau_tr: trace time constant (in milliseconds, or ms) - """ - - # Define Functions - def __init__(self, name, n_units, dt, tau_tr, **kwargs): - super().__init__(name, **kwargs) - - ## trace control coefficients - self.dt = dt - self.tau_tr = tau_tr ## trace time constant - - ## Layer size setup - self.batch_size = 1 - self.n_units = n_units - - restVals = jnp.zeros((self.batch_size, self.n_units)) - self.inputs = Compartment(restVals) # input compartment - self.trace = Compartment(restVals) - - @transition(output_compartments=["trace"]) - @staticmethod - def advance_state(dt, tau_tr, inputs, trace): - trace = (trace * (1. - dt/tau_tr)) * (1. - inputs) + inputs - return trace - - @transition(output_compartments=["inputs", "trace"]) - @staticmethod - def reset(batch_size, n_units): - restVals = jnp.zeros((batch_size, n_units)) - return restVals, restVals - - def __repr__(self): - comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - maxlen = max(len(c) for c in comps) + 5 - lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - for c in comps: - stats = tensorstats(getattr(self, c).value) - if stats is not None: - line = [f"{k}: {v}" for k, v in stats.items()] - line = ", ".join(line) - else: - line = "None" - lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - return lines diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py deleted file mode 100644 index d3916f7a..00000000 --- a/ngclearn/components/monitor.py +++ /dev/null @@ -1,31 +0,0 @@ -from ngclearn.components.base_monitor import Base_Monitor -#from ngclearn import transition - -class Monitor(Base_Monitor): - """ - A jax implementation of `Base_Monitor`. Designed to be used with all - non-lava ngclearn components - """ - auto_resolve = False - - @staticmethod - def _record_internal(compartments): - @staticmethod - def _record(**kwargs): - return_vals = [] - for comp in compartments: - new_val = kwargs[comp] - current_store = kwargs[comp + "*store"] - current_store = current_store.at[:-1].set(current_store[1:]) - current_store = current_store.at[-1].set(new_val) - return_vals.append(current_store) - return return_vals if len(compartments) > 1 else return_vals[0] - return _record - - @staticmethod - def build_advance_state(component): - return super().build_advance_state(component) - - @staticmethod - def build_reset(component): - return super().build_reset(component) From a03480a11fbd962b3a4dd83e0d9efd4a94b16ae0 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 18:15:26 -0500 Subject: [PATCH 056/121] minor cleanup of inits --- ngclearn/components/input_encoders/__init__.py | 1 + ngclearn/components/neurons/__init__.py | 1 + ngclearn/components/neurons/graded/__init__.py | 1 + 3 files changed, 3 insertions(+) diff --git a/ngclearn/components/input_encoders/__init__.py b/ngclearn/components/input_encoders/__init__.py index b779226e..5d14d2ec 100644 --- a/ngclearn/components/input_encoders/__init__.py +++ b/ngclearn/components/input_encoders/__init__.py @@ -2,3 +2,4 @@ from .poissonCell import PoissonCell from .latencyCell import LatencyCell from .phasorCell import PhasorCell + diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index e7165d7e..f367cd02 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -15,3 +15,4 @@ from .spiking.izhikevichCell import IzhikevichCell from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell from .spiking.RAFCell import RAFCell + diff --git a/ngclearn/components/neurons/graded/__init__.py b/ngclearn/components/neurons/graded/__init__.py index bde64b39..8d723607 100644 --- a/ngclearn/components/neurons/graded/__init__.py +++ b/ngclearn/components/neurons/graded/__init__.py @@ -4,3 +4,4 @@ from .laplacianErrorCell import LaplacianErrorCell from .bernoulliErrorCell import BernoulliErrorCell from .rewardErrorCell import RewardErrorCell + From 5ec052fcbc79a5cb4ba5b3f963e726f7119b8da8 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 18:36:51 -0500 Subject: [PATCH 057/121] refactored regression module to be compliant with v3 --- ngclearn/modules/__init__.py | 4 - ngclearn/modules/regression/__init__.py | 5 - ngclearn/modules/regression/elastic_net.py | 54 +---------- ngclearn/modules/regression/lasso.py | 105 ++++++++------------ ngclearn/modules/regression/ridge.py | 107 +++++++++------------ 5 files changed, 85 insertions(+), 190 deletions(-) diff --git a/ngclearn/modules/__init__.py b/ngclearn/modules/__init__.py index b18f84b7..38866e21 100644 --- a/ngclearn/modules/__init__.py +++ b/ngclearn/modules/__init__.py @@ -2,7 +2,3 @@ from .regression.lasso import Iterative_Lasso from .regression.ridge import Iterative_Ridge - - - - diff --git a/ngclearn/modules/regression/__init__.py b/ngclearn/modules/regression/__init__.py index 064d5303..bc45b6b2 100644 --- a/ngclearn/modules/regression/__init__.py +++ b/ngclearn/modules/regression/__init__.py @@ -2,8 +2,3 @@ from .lasso import Iterative_Lasso from .ridge import Iterative_Ridge - - - - - diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py index faf4fbea..e1b700d3 100644 --- a/ngclearn/modules/regression/elastic_net.py +++ b/ngclearn/modules/regression/elastic_net.py @@ -11,7 +11,7 @@ class Iterative_ElasticNet(): """ A neural circuit implementation of the iterative Elastic Net (L1 and L2) algorithm - using Hebbian learning update rule. + using a Hebbian learning update rule. The circuit implements sparse regression through Hebbian synapses with Elastic Net regularization. @@ -21,8 +21,6 @@ class Iterative_ElasticNet(): | dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1-l1_ratio)/2) | dW/dt = dW + lmbda * dW_reg - - | --- Circuit Components: --- | W - HebbianSynapse for learning regularized dictionary weights | err - GaussianErrorCell for computing prediction errors @@ -104,14 +102,6 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l >> self.W.reset) self.reset = reset - # advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses - # self.err, ## finally, execute error neurons - # compile_key="advance_state") - # evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve") - # reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset") - # # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # self.dynamic() - def batch_set(self, batch_size): self.W.batch_size = batch_size self.err.batch_size = batch_size @@ -121,33 +111,6 @@ def clamp(self, y_scaled, X): self.W.pre.set(X) self.err.target.set(y_scaled) - # def dynamic(self): ## create dynamic commands forself.circuit - # W, err = self.circuit.get_components("W", "err") - # self.self = W - # self.err = err - # - # @Context.dynamicCommand - # def batch_set(batch_size): - # self.W.batch_size = batch_size - # self.err.batch_size = batch_size - # - # @Context.dynamicCommand - # def clamps(y_scaled, X): - # self.W.inputs.set(X) - # self.W.pre.set(X) - # self.err.target.set(y_scaled) - # - # self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve") - # self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance") - # self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset") - # - # @scanner - # def _process(compartment_values, args): - # _t, _dt = args - # compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt) - # return compartment_values, compartment_values[self.W.weights.path] - - def thresholding(self, scale=1.): coef_old = self.coef_ new_coeff = jnp.where(jnp.abs(coef_old) >= self.threshold, coef_old, 0.) @@ -172,18 +135,3 @@ def fit(self, y, X): return self.coef_, self.err.mu.get(), self.err.L.get() - # self.circuit.reset() - # self.circuit.clamps(y_scaled=y, X=X) - # - # for i in range(self.epochs): - # self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)])) - # self.circuit.evolve(t=self.T, dt=self.dt) - # - # self.coef_ = np.array(self.W.weights.value) - # - # return self.coef_, self.err.mu.value, self.err.L.value - - - - - diff --git a/ngclearn/modules/regression/lasso.py b/ngclearn/modules/regression/lasso.py index c0d8c8ef..6db0c57d 100644 --- a/ngclearn/modules/regression/lasso.py +++ b/ngclearn/modules/regression/lasso.py @@ -1,26 +1,19 @@ -import jax -import pandas as pd -from jax import random, jit import numpy as np -from scipy.integrate import solve_ivp -import matplotlib.pyplot as plt -from ngcsimlib.utils import Get_Compartment_Batch -from ngclearn.utils.model_utils import normalize_matrix from ngclearn.utils import weight_distribution as dist -from ngclearn import Context, numpy as jnp -from ngclearn.components import (RateCell, - HebbianSynapse, - GaussianErrorCell, - StaticSynapse) -from ngclearn.utils.model_utils import scanner +from ngclearn import numpy as jnp +from jax import numpy as jnp, random, jit +from ngclearn import Context, MethodProcess +from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse +from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell +from ngcsimlib.global_state import stateManager class Iterative_Lasso(): """ A neural circuit implementation of the iterative Lasso (L1) algorithm - using Hebbian learning update rule. + using a Hebbian learning update rule. - The circuit implements sparse coding through Hebbian synapses with L1 regularization. + The circuit implements sparse coding-like regression through Hebbian synapses with L1 regularization. The specific differential equation that characterizes this model is adding lmbda * sign(W) to the dW (the gradient of loss/energy function): @@ -89,43 +82,32 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W.batch_size = batch_size self.err.batch_size = batch_size # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.err.mu << self.W.outputs - self.W.post << self.err.dmu + self.W.outputs >> self.err.mu + self.err.dmu >> self.W.post # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses - self.err, ## finally, execute error neurons - compile_key="advance_state") - evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve") - reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset") - # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.dynamic() - - def dynamic(self): ## create dynamic commands for self.circuit - W, err = self.circuit.get_components("W", "err") - self.self = W - self.err = err - - @Context.dynamicCommand - def batch_set(batch_size): - self.W.batch_size = batch_size - self.err.batch_size = batch_size - - @Context.dynamicCommand - def clamps(y_scaled, X): - self.W.inputs.set(X) - self.W.pre.set(X) - self.err.target.set(y_scaled) - - self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve") - self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance") - self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset") - - @scanner - def _process(compartment_values, args): - _t, _dt = args - compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt) - return compartment_values, compartment_values[self.W.weights.path] - + + advance = (MethodProcess(name="advance_state") + >> self.W.advance_state + >> self.err.advance_state) + self.advance = advance + + evolve = (MethodProcess(name="evolve") + >> self.W.evolve) + self.evolve = evolve + + reset = (MethodProcess(name="reset") + >> self.err.reset + >> self.W.reset) + self.reset = reset + + def batch_set(self, batch_size): + self.W.batch_size = batch_size + self.err.batch_size = batch_size + + def clamp(self, y_scaled, X): + self.W.inputs.set(X) + self.W.pre.set(X) + self.err.target.set(y_scaled) def thresholding(self, scale=2): coef_old = self.coef_ @@ -136,23 +118,16 @@ def thresholding(self, scale=2): return self.coef_, coef_old - def fit(self, y, X): - - self.circuit.reset() - self.circuit.clamps(y_scaled=y, X=X) + self.reset.run() + self.clamp(y_scaled=y, X=X) for i in range(self.epochs): - self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)])) - self.circuit.evolve(t=self.T, dt=self.dt) - - self.coef_ = np.array(self.W.weights.value) - - return self.coef_, self.err.mu.value, self.err.L.value - - - - + inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt)) + stateManager.state, outputs = self.advance.scan(inputs) + self.evolve.run(t=self.T, dt=self.dt) + self.coef_ = np.array(self.W.weights.get()) + return self.coef_, self.err.mu.get(), self.err.L.get() diff --git a/ngclearn/modules/regression/ridge.py b/ngclearn/modules/regression/ridge.py index b1698aba..2bcf9593 100644 --- a/ngclearn/modules/regression/ridge.py +++ b/ngclearn/modules/regression/ridge.py @@ -1,21 +1,19 @@ -from jax import random, jit import numpy as np from ngclearn.utils import weight_distribution as dist -from ngclearn import Context, numpy as jnp -from ngclearn.components import (RateCell, - HebbianSynapse, - GaussianErrorCell, - StaticSynapse) -from ngclearn.utils.model_utils import scanner - +from ngclearn import numpy as jnp +from jax import numpy as jnp, random, jit +from ngclearn import Context, MethodProcess +from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse +from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell +from ngcsimlib.global_state import stateManager class Iterative_Ridge(): """ A neural circuit implementation of the iterative Ridge (L2) algorithm - using Hebbian learning update rule. + using a Hebbian learning update rule. - The circuit implements sparse regression through Hebbian synapses with L2 regularization. + This circuit implements sparse regression through Hebbian synapses with L2 regularization. The specific differential equation that characterizes this model is adding lmbda * W to the dW (the gradient of loss/energy function): @@ -75,54 +73,43 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l feature_dim = dict_dim with Context(self.name) as self.circuit: - self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr, - sign_value=-1, weight_init=dist.constant(weight_fill), - prior=('ridge', ridge_lmbda), w_bound=0., - optim_type=optim_type, key=subkeys[0]) + self.W = HebbianSynapse( + "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, + weight_init=dist.constant(weight_fill), prior=('ridge', ridge_lmbda), w_bound=0., + optim_type=optim_type, key=subkeys[0] + ) self.err = GaussianErrorCell("err", n_units=sys_dim) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.err.mu << self.W.outputs - self.W.post << self.err.dmu + self.W.outputs >> self.err.mu + self.err.dmu >> self.W.post # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses - self.err, ## finally, execute error neurons - compile_key="advance_state") - evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve") - reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset") - # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - self.dynamic() - - def dynamic(self): ## create dynamic commands forself.circuit - W, err = self.circuit.get_components("W", "err") - self.self = W - self.err = err - - @Context.dynamicCommand - def batch_set(batch_size): - self.W.batch_size = batch_size - self.err.batch_size = batch_size - - @Context.dynamicCommand - def clamps(y_scaled, X): - self.W.inputs.set(X) - self.W.pre.set(X) - self.err.target.set(y_scaled) - - self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve") - self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance") - self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset") - - - @scanner - def _process(compartment_values, args): - _t, _dt = args - compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt) - return compartment_values, compartment_values[self.W.weights.path] + advance = (MethodProcess(name="advance_state") + >> self.W.advance_state + >> self.err.advance_state) + self.advance = advance + + evolve = (MethodProcess(name="evolve") + >> self.W.evolve) + self.evolve = evolve + + reset = (MethodProcess(name="reset") + >> self.err.reset + >> self.W.reset) + self.reset = reset + + def batch_set(self, batch_size): + self.W.batch_size = batch_size + self.err.batch_size = batch_size + + def clamp(self, y_scaled, X): + self.W.inputs.set(X) + self.W.pre.set(X) + self.err.target.set(y_scaled) def thresholding(self, scale=2): coef_old = self.coef_ #self.W.weights.value @@ -135,21 +122,15 @@ def thresholding(self, scale=2): def fit(self, y, X): - self.circuit.reset() - self.circuit.clamps(y_scaled=y, X=X) + self.reset.run() + self.clamp(y_scaled=y, X=X) for i in range(self.epochs): - self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)])) - self.circuit.evolve(t=self.T, dt=self.dt) - - self.coef_ = np.array(self.W.weights.value) - - return self.coef_, self.err.mu.value, self.err.L.value - - - - - + inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt)) + stateManager.state, outputs = self.advance.scan(inputs) + self.evolve.run(t=self.T, dt=self.dt) + self.coef_ = np.array(self.W.weights.get()) + return self.coef_, self.err.mu.get(), self.err.L.get() From ff2a25afdc555ab1bc942b8f79de890783e9d056 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 18:38:02 -0500 Subject: [PATCH 058/121] adjusted sphinx-docs w.r.t. new v3 refactoring --- docs/source/ngclearn.components.rst | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/docs/source/ngclearn.components.rst b/docs/source/ngclearn.components.rst index 05c87f7d..c822f3bf 100644 --- a/docs/source/ngclearn.components.rst +++ b/docs/source/ngclearn.components.rst @@ -15,14 +15,6 @@ Subpackages Submodules ---------- -ngclearn.components.base\_monitor module ----------------------------------------- - -.. automodule:: ngclearn.components.base_monitor - :members: - :undoc-members: - :show-inheritance: - ngclearn.components.jaxComponent module --------------------------------------- @@ -31,14 +23,6 @@ ngclearn.components.jaxComponent module :undoc-members: :show-inheritance: -ngclearn.components.monitor module ----------------------------------- - -.. automodule:: ngclearn.components.monitor - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- From 843d93723a22ff1068fe7208185e5f84f82a0ccb Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Sat, 15 Nov 2025 19:30:05 -0500 Subject: [PATCH 059/121] minor revision to double-exp syn pointing, mods to modeling docs --- docs/modeling/neurons.md | 50 ++++++++++++++++++- docs/modeling/synapses.md | 28 ++++++----- docs/ngclearn_papers.md | 6 ++- ngclearn/components/__init__.py | 2 +- ngclearn/components/synapses/__init__.py | 5 +- .../components/synapses/doubleExpSynapse.py | 2 +- 6 files changed, 72 insertions(+), 21 deletions(-) diff --git a/docs/modeling/neurons.md b/docs/modeling/neurons.md index 4babf8f7..36f7c2b1 100644 --- a/docs/modeling/neurons.md +++ b/docs/modeling/neurons.md @@ -86,6 +86,22 @@ and `dmu` is the first derivative with respect to the mean parameter. :noindex: ``` +#### Bernoulli Error Cell + +This cell is (currently) fixed to be a (factorized) multivariate Bernoulli cell. +Concretely, this cell implements compartments/mechanics to facilitate Bernoulli +log likelihood error calculations. + +```{eval-rst} +.. autoclass:: ngclearn.components.BernoulliErrorCell + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + ## Spiking Neurons These neuronal cells exhibit dynamics that involve emission of discrete action @@ -117,10 +133,42 @@ negative pressure on the membrane potential values at `t`). :noindex: ``` +### The IF (Integrate-and-Fire) Cell + +This cell (the simple "integrator") models dynamics over the voltage `v`. Note that `thr` is used as the membrane potential threshold and no adaptive threshold mechanics are implemented for this cell model. +(This cell is primarily a faster, convenience formulation that omits the leak element of the LIF.) + +```{eval-rst} +.. autoclass:: ngclearn.components.IFCell + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + +### The Winner-Take-All (WTAS) Cell + +This cell models dynamics over the voltage `v` as a simple instantaneous +softmax function of the electrical current input, where only a single +spike, which wins the competition across the group of neuronal units +within this component, emits a pulse/spike. + +```{eval-rst} +.. autoclass:: ngclearn.components.WTASCell + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + ### The LIF (Leaky Integrate-and-Fire) Cell This cell (the "leaky integrator") models dynamics over the voltage `v` -and threshold shift `thrTheta` (a homeostatic variable). Note that `thr` +and threshold shift `thr_theta` (a homeostatic variable). Note that `thr` is used as a baseline level for the membrane potential threshold while `thrTheta` is treated as a form of short-term plasticity (full threshold is: `thr + thrTheta(t)`). diff --git a/docs/modeling/synapses.md b/docs/modeling/synapses.md index 470446e9..6b881f0c 100644 --- a/docs/modeling/synapses.md +++ b/docs/modeling/synapses.md @@ -1,17 +1,7 @@ # Synapses -The synapse is a key building block for connecting/wiring together the various -component cells that one would use for characterizing a biomimetic neural system. -These particular objects are meant to perform, per simulated time step, a -specific type of transformation -- such as a linear transform or a -convolution -- utilizing their underlying synaptic parameters. -Most times, a synaptic cable will be represented by a set of matrices (or filters) -that are used to conduct a projection of an input signal (a value presented to a -pre-synaptic/input compartment) resulting in an output signal (a value that -appears within one of its post-synaptic compartments). Notably, a synapse component is -typically associated with a local plasticity rule, e.g., a Hebbian-type -update, that either is triggered online, i.e., at some or all simulation time -steps, or by integrating a differential equation, e.g., via eligibility traces. +The synapse is a key building block for connecting/wiring together the various component cells that one would use for characterizing a biomimetic neural system. These particular objects are meant to perform, per simulated time step, a specific type of transformation -- such as a linear transform or a convolution -- utilizing their underlying synaptic parameters. Most times, a synaptic cable will be represented by a set of matrices (or filters) that are used to conduct a projection of an input signal (a value presented to a pre-synaptic/input compartment) resulting in an output signal (a value that appears within one of its post-synaptic compartments). There are three general groupings of synaptic components in ngc-learn: 1) non-plastic static synapses (only perform fixed transformations of input signals); 2) non-plastic dynamic synapses (perform time-varying, input-dependent transformations on input signals); and 3) plastic synapses that carry out long-term evolution. +Notably, plastic synapse components are typically associated with a local plasticity rule, e.g., a Hebbian-type update, that either is triggered online, i.e., at some or all simulation time steps, or by integrating a differential equation, e.g., via eligibility traces. ## Non-Plastic Synapse Types @@ -74,6 +64,20 @@ This (chemical) synapse performs a linear transform of its input signals. Note t :noindex: ``` +### Double-Exponential Synapse + +This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through a doubleexpoential / difference of two exponentials kernel. + +```{eval-rst} +.. autoclass:: ngclearn.components.DoubleExpSynapse + :noindex: + + .. automethod:: advance_state + :noindex: + .. automethod:: reset + :noindex: +``` + ### Alpha Synapse This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through a kernel that models more realistic rise and fall times of synaptic conductance.. diff --git a/docs/ngclearn_papers.md b/docs/ngclearn_papers.md index 09998a81..3fa9b1df 100644 --- a/docs/ngclearn_papers.md +++ b/docs/ngclearn_papers.md @@ -17,12 +17,14 @@ from data streams." arXiv preprint arXiv:1908.08655 (2019). a hyperdimensional predictive processing cognitive architecture." Proceedings of the Annual Meeting of the Cognitive Science Society (CogSci), Volume 44 (2022). -5. Ororbia, A., and Kelly, M. Alex. "“Learning using a Hyperdimensional Predictive Processing Cognitive -Architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022). +5. Ororbia, A., and Kelly, M. Alex. "Learning using a hyperdimensional predictive processing cognitive +architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022). 6. Ororbia, A., Mali, A., Kifer, D., & Giles, C. L. "Lifelong neural predictive coding: Learning cumulatively online without forgetting." Thirty-sixth Conference on Neural Information Processing Systems (NeurIPS) (2022). +7. Ororbia, A., Friston, K., Rao, Rajesh P. N. "Meta-representational predictive coding: Biomimetic self-supervised learning." arXiv preprint arXiv:2503.21796 (2025). + Note: Please let us know if your work uses ngc-learn so we can update this page to accurately track ngc-learn's use and include your work in the accumulating body of work in predictive processing and/or brain-inspired computational modeling. diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 9e7b5481..6bce427b 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -39,7 +39,7 @@ from .synapses.hebbian.BCMSynapse import BCMSynapse from .synapses.STPDenseSynapse import STPDenseSynapse from .synapses.exponentialSynapse import ExponentialSynapse -from .synapses.doubleExpSynapse import DoupleExpSynapse +from .synapses.doubleExpSynapse import DoubleExpSynapse from .synapses.alphaSynapse import AlphaSynapse ## point to convolutional component types diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py index 2c9c9f70..95bf3f70 100644 --- a/ngclearn/components/synapses/__init__.py +++ b/ngclearn/components/synapses/__init__.py @@ -1,11 +1,10 @@ from .denseSynapse import DenseSynapse from .staticSynapse import StaticSynapse - ## short-term plasticity components from .STPDenseSynapse import STPDenseSynapse from .exponentialSynapse import ExponentialSynapse -from .doubleExpSynapse import DoupleExpSynapse +from .doubleExpSynapse import DoubleExpSynapse from .alphaSynapse import AlphaSynapse ## dense synaptic components @@ -15,7 +14,6 @@ from .hebbian.eventSTDPSynapse import EventSTDPSynapse from .hebbian.BCMSynapse import BCMSynapse - ## conv/deconv synaptic components from .convolution.convSynapse import ConvSynapse from .convolution.staticConvSynapse import StaticConvSynapse @@ -26,7 +24,6 @@ from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse - ## modulated synaptic components from .modulated.MSTDPETSynapse import MSTDPETSynapse # from .modulated.REINFORCESynapse import REINFORCESynapse diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py index b5a9a3f0..62cce850 100644 --- a/ngclearn/components/synapses/doubleExpSynapse.py +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -6,7 +6,7 @@ from ngcsimlib.compartment import Compartment from ngcsimlib.parser import compilable -class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable +class DoubleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable """ A dynamic double-exponential synaptic cable; this synapse evolves according to difference of two exponentials synaptic conductance dynamics. From c51b83c002d4fdcef8966ecccff916e63f0a4607 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 11:19:04 -0500 Subject: [PATCH 060/121] updated adex tutorial doc to v3 --- docs/tutorials/index.rst | 11 - docs/tutorials/lava/hebbian_learning.md | 312 ------------------------ docs/tutorials/lava/introduction.md | 37 --- docs/tutorials/lava/lava_components.md | 29 --- docs/tutorials/lava/lava_context.md | 103 -------- docs/tutorials/lava/monitors.md | 17 -- docs/tutorials/lava/setup.md | 26 -- docs/tutorials/neurocog/adex_cell.md | 33 +-- 8 files changed, 13 insertions(+), 555 deletions(-) delete mode 100644 docs/tutorials/lava/hebbian_learning.md delete mode 100644 docs/tutorials/lava/introduction.md delete mode 100644 docs/tutorials/lava/lava_components.md delete mode 100644 docs/tutorials/lava/lava_context.md delete mode 100644 docs/tutorials/lava/monitors.md delete mode 100644 docs/tutorials/lava/setup.md diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index f1640834..0efdacbf 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -29,14 +29,3 @@ to configuration). foundations/commands foundations/operations foundations/monitors - -.. toctree:: - :maxdepth: 1 - :caption: III. NGC-Lava: Support for Loihi 2 Transfer - - lava/introduction - lava/setup - lava/lava_components - lava/lava_context - lava/hebbian_learning - lava/monitors diff --git a/docs/tutorials/lava/hebbian_learning.md b/docs/tutorials/lava/hebbian_learning.md deleted file mode 100644 index c45a01b6..00000000 --- a/docs/tutorials/lava/hebbian_learning.md +++ /dev/null @@ -1,312 +0,0 @@ -# Training a Spiking Network On-chip - -In this tutorial we will build generate a simple dataset (consisting of binary -patterns of X's, O's, and T's) and train a model in the Loihi simulator -using Hebbian learning in the form of trace-based spike-timing-dependent -plasticity. - -## Setting up ngc-learn - -The first step of this project consist of setting up the configuration file for -ngc-learn. Create a folder in your project directory root called `json_files` -and -then create a `config.json` configuration inside of that folder. - -Now for this project we will not be loading anything dynamically, so we can -simply add: - -```json -{ - "modules": { - "module_path": null - } -} -``` - -The above configuration will skip the dynamic loading of modules, which is -important for Lava-based model transference and simulation. - -Next, in order to run code with the Loihi simulator, the base version of numpy -needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts -to by default). To change all of ngc-learn over to using the base version of -numpy, simply add the following to your configuration: - -```json -"packages": { - "use_base_numpy": true -} -``` - -Now your project is configured for ngc-lava and Lava usage and we can move on -to data generation. - -## Generating Data - -For this project we will be using three different patterns to train a simple -biophysical spiking neural network; the data will simply consist of binary -image patterns of either an `X`, `O`, and a `T`. To create the file needed to -generate these patterns, create a Python script named `data_generator.py` in -your project root. Next, we will import `numpy` and `random` and -define the following three generator methods: - -```python -from ngclearn import numpy as np - - -def make_X(size): - X = np.zeros((size, size)) - for i in range(0, size): - X[i, i] = np.random.uniform(0.75, 1) - X[i, size - 1 - i] = np.random.uniform(0.75, 1) - return X - - -def make_O(size): - O = np.zeros((size, size)) - for i in range(0, (size // 2) - 1): - O[1 + i, (size // 2) - 1 - i] = np.random.uniform(0.75, 1) - O[1 + i, (size // 2) + i] = np.random.uniform(0.75, 1) - O[(size // 2) + i, 1 + i] = np.random.uniform(0.75, 1) - O[(size // 2) + i, size - 2 - i] = np.random.uniform(0.75, 1) - return O - - -def make_T(size): - T = np.zeros((size, size)) - T[1, 1:size - 1] = np.random.uniform(0.75, 1, (1, size - 2)) - for i in range(2, size - 1): - T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1, - (1, 2)) - return T -``` - -Each of these methods will create a pattern of the desired size and shape. - -## Building the Model - -Found below is all of the imports that will be needed to run the model we desire -in Lava: - -```python -from ngclava import LavaContext -from ngclearn import numpy as np -from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse, Monitor -import ngclearn.utils.viz as viz_utils -import ngclearn.utils.weight_distribution as dist -from data_generator import make_X, make_O, make_T -``` - -To start off building this model, we will define all of the hyperparameters -needed to create the necessary model components: - -```python -# Training Params -epochs = 35 -view_length = 200 -rest_length = 1000 - -# Model Params -n_in = 64 # Input layer size -n_hid = 25 # Hidden layer size -dt = 1. # ms # integration time constant -np.random.seed(42) ## seed the internal numpy calls -``` - -After this we will create the lava context, the components, as well as the -wiring: - -```python -with LavaContext("Model") as model: - z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt, - tau_m=50., v_decay=0., tau_theta=500., - refract_T=0.) ## IF cell - z1e = LIFCell("z1e", n_units=n_hid, - thr_theta_init=dist.uniform(amin=-2, amax=2.), - dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell - z1i = LIFCell("z1i", n_units=n_hid, - thr_theta_init=dist.uniform(amin=-2, amax=2.), - dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45., - theta_plus=0.) ## inhibitory LIF cell - - tr0 = GatedTrace("tr0", n_units=n_in, dt=dt, tau_tr=20.) - tr1 = GatedTrace("tr1", n_units=n_hid, dt=dt, tau_tr=20.) - - W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3), - shape=(n_in, n_hid), dt=dt, Aplus=0.011, - Aminus=0.0011, - preTrace_target=0.055) - W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.), - shape=(n_hid, n_hid), dt=dt) - W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5), - shape=(n_hid, n_hid), dt=dt) - - M = Monitor("M", default_window_length=view_length) - - ## wire z0 to z1e via W1 and z1i to z1e via W1ie - W1.inputs << z0.s - W1ie.inputs << z1i.s - - z1e.j_exc << W1.outputs - z1e.j_inh << W1ie.outputs - - # wire z1e to z1i via W1ie - W1ei.inputs << z1e.s - z1i.j_exc << W1ei.outputs - - # wire cells z0 and z1e to their respective traces - tr0.inputs << z0.s - tr1.inputs << z1e.s - - # wire relevant compartment statistics to synaptic cable W1 (for STDP update) - W1.x_pre << tr0.trace - W1.pre << z0.s - W1.x_post << tr1.trace - W1.post << z1e.s - - # set up monitoring of z1e's spike output - M << z1e.s -``` - -After the components have been set up, we have to "lag out" the synapses that -will cause recurrent (locking) problems when running on the Loihi2. This will -cause each of these synapses to run one time-step behind and fixes many -recurrency -issues (as described [here](lava_context.md)). - -```python - model.set_lag('W1') - model.set_lag('W1ie') - model.set_lag('W1ei') -``` - -Now that the model is all set up, we have to tell the Lava compiler to actually -build all the Lava objects with the following: - -```python - model.rebuild_lava() -``` - -This line will stop the automatic build of components when leaving this -with-block and provides access to all of the Lava components inside of this -with-block. - -Next, we set up two methods, a `clamp` method to set the input data and -`viz` to visualize all of the different receptive fields of our model: - -```python - lz0, lW1 = model.get_lava_components('z0', 'W1') - - - @model.dynamicCommand - def clamp(x): - model.pause() - lz0.j_exc.set(x) - - - @model.dynamicCommand - def viz(): - viz_utils.synapse_plot.visualize([lW1.weights.get()], [(8, 8)], "lava_fields") -``` - -## Running The Model - -Now that everything is set up to build the runtime and start training the model -inside of the Loihi simulator. To set up the runtime we call the following: - -```python - model.set_up_runtime("z0", rest_image=np.zeros((1, 64))) -``` - -This will set up a runtime with `z0` as the root node and also uses a resting -image of all zeros to allow the system to return to its resting state. - -Now the training loop will be as follows: - -```python -with model.runtime: - for i in range(epochs): - print(f"\rStarting Epoch: {i}", end="") - X = np.reshape(make_X(8), (1, 64)) - O = np.reshape(make_O(8), (1, 64)) - T = np.reshape(make_T(8), (1, 64)) - - model.view(X, view_length) - model.rest(rest_length) - - model.view(O, view_length) - model.rest(rest_length) - - model.view(T, view_length) - model.rest(rest_length) - print("\nDone Training") - -``` - -## Evaluating the On-Chip Trained Model - -The code above will work to train the model on a Loihi neuromorphic chip, but, -currently, we do not have a way of viewing how effective the model learned -really is. To set up this evaluation, we can call -the `viz` method defined above to view the receptive fields that our spiking -model has acquired: - -```python - model.viz() -``` - -Running this should produce a set of receptive fields that look like the -following: - -
- -While viewing the receptive fields qualitatively tells us that our spiking -model has trained, we may also want to view the -[raster plots](ngclearn.utils.viz.raster) -- visual depictions of the -underlying spike patterns acquired in the hidden layer of our model -- for each -of our three image patterns (as they are fed into our trained model). To do -this, we will make use of the monitor we defined above in the following manner: - -```python - ## Turning off learning - lW1.eta.set(np.array([0])) - - model.view(np.reshape(make_T(8), (1, 64)), view_length) - model.write_to_ngc() - spikes = M.view(z1e.s) - viz_utils.raster.create_raster_plot(spikes, tag="T", plot_fname="raster_T") - model.rest(rest_length) - print("Done T") - - model.view(np.reshape(make_X(8), (1, 64)), view_length) - model.write_to_ngc() - spikes = M.view(z1e.s) - viz_utils.raster.create_raster_plot(spikes, tag="X", plot_fname="raster_X") - model.rest(rest_length) - print("Done X") - - model.view(np.reshape(make_O(8), (1, 64)), view_length) - model.write_to_ngc() - spikes = M.view(z1e.s) - viz_utils.raster.create_raster_plot(spikes, tag="O", plot_fname="raster_O") - model.rest(rest_length) - print("Done O") -``` - -The above should result in raster plots where the spikes correspond to the -receptive fields of each trained letter pattern. Specifically, you should see -that the top left field is `N0` and the bottom right is `N24`. Your raster plots -should look like the ones below: - -
-
-
- -Finally to save the model to disk, you can call the following: - -```python - model.save_to_json(".", model_name="trained") -``` - -which will save your on-chip trained Loihi model to disk for later use. - - diff --git a/docs/tutorials/lava/introduction.md b/docs/tutorials/lava/introduction.md deleted file mode 100644 index 994e3809..00000000 --- a/docs/tutorials/lava/introduction.md +++ /dev/null @@ -1,37 +0,0 @@ -# Blending ngc-learn and lava-nc - -The subpackage of ngclearn known as ngc-lava is an interfacing layer between -ngclearn's components and contexts and lava-nc's models and processes. In this -package, there is the introduction of the `LavaContext`, a subclass of the default -ngclearn `Context`. This context has all the same functionality as the base -ngclearn context but adds the ability to convert lava compatible components into -their Lava process and model automatically and on-the-fly. This allows for the -development and testing of models inside ngclearn prior to their deployment onto -a Loihi neuromorphic chip without needing to translate between the two models -written across the two different Python libraries. - -## Some Cautionary Notes - -- For the best experience in training models in ngclearn, Python version `>=3.10` - should be used. However, much of lava is written to be used in Python `3.8` and, - because of this, there are some flags and functionality that cannot be used in Lava - components directly. It is for this reason that ngc-learn has several - in-built "lava components", i.e., those in `ngclearn.components.lava` that - are meant to directly interact with ngc-lava; other components (such as those - (`ngclearn.components.neurons` or `ngclearn.components.synapses`) are not likely - to work and, when writing your own custom ngc-lava components, we recommend - that you use those in the `ngclearn.components.lava` subpackage as starting - points to see what design patterns will work with Lava. -- As of right now, all of ngc-lava is built using the Loihi2 configuration and - Loihi1 is not actively supported. Loihi1 might still work but nothing is - guaranteed nor has been tested by the ngc-learn dev team. - -## Table of Contents -1. [Setting up ngc-lava](setup.md): A brief overview of how to set up - ngc-lava -2. [Lava components](lava_components.md): An overview of lava components in ngclearn and - how to make custom ones -3. [Lava Context](lava_context.md): An overview of the Lava context and building - models for Lava -4. [On-Chip Hebbian Learning](hebbian_learning.md): A walkthrough for getting a simple - hebbian learning model setup diff --git a/docs/tutorials/lava/lava_components.md b/docs/tutorials/lava/lava_components.md deleted file mode 100644 index 79411db1..00000000 --- a/docs/tutorials/lava/lava_components.md +++ /dev/null @@ -1,29 +0,0 @@ -# Lava Components - -Inside ngc-learn, there is a wide variety of components with which biophysical -models can be built. Unfortunately, many of those components are not compatible -with Lava and the loihi2. Therefore, ngc-learn supports several in-built -components that are Lava-compliant; many of the components that are compatible -to you can be found in `ngclearn.components.lava`. - -## What Makes an ngc-learn Component Compatible - -For components to be compatible with Lava, there are a few key rules that must -be followed: -- Lava Components can not make use of JAX's random or JAX's `nn` libraries -- Lava Components must import numpy from ngclearn and not JAX (there is a flag - in the configuration file to control JAX's numpy versus base numpy) -- Lava Components cannot take in any runtime arguments to their `advance_state` method -- Lava Components cannot take in any runtime arguments or compartments to their - `reset` method(s) - -## Mapping Methods -- Going from ngc-learn to Lava - -There are two methods that are mapped to their lava processes; these include the -`reset` method and the `advance_state` method. The reset method is just mapped -to the lava components and can be called on them without any issue. The -`advance_state` method is mapped to the `run_spk` method and is called during -the runtime loops in Lava. It is important to note that the methods that are -actually mapped are the pure methods passed into the resolvers that -decorate the ngc-learn `reset` and `advance_state` methods, not the -`reset` and `advance_state` methods themselves. \ No newline at end of file diff --git a/docs/tutorials/lava/lava_context.md b/docs/tutorials/lava/lava_context.md deleted file mode 100644 index cb7888ff..00000000 --- a/docs/tutorials/lava/lava_context.md +++ /dev/null @@ -1,103 +0,0 @@ -# The Lava Context - -The lava context, i.e, the `LavaContext`, serves as the core to ngc-lava as well -as the main workhorse of all of its features. Since it is a subclass of the -default ngc-learn context, we will only be covering the new Lava-specific -features here. - -## Building Lava Components - -The Lava context generally keeps track of two sets of components -- the ngclearn -components and the Lava components. However, due to the nature of the lava -components themselves, they must be built once the model is fixed and cannot be -built on-the-fly. Due to this fact, the building of the lava components must -be triggered before they can be used. Nevertheless, there are a few ways to trigger the -building of the Lava components. It is important to note that only the latest set -of components can be used for methods like clamping and running. This will -affect all dynamically compiled methods. - -### Events that Trigger a Rebuild - -- When a `LavaContext` is first constructed via: `with LavaContext("model") as model:` - leaving the context block will trigger a rebuild -- Calling `with model.updater:` will rebuild the lava components upon leaving the - with-block -- Calling `model.rebuild_lava()` will rebuild the lava components even if it is - still inside a with-block. However, by default, it will stop the with-block - from recompiling upon exiting as doing so would overwrite the previously built - model components. - -### Events That Will Not Trigger a Rebuild - -Simply calling `with model` will not trigger a rebuild upon exiting since this is -where additional dynamic method can be defined as well as reference sub-models -while not triggering a complete rebuild of the Lava components each time. - -## The Runtime - -Inside of Lava, there is an internal runtime that is controlling the simulator -for the loihi2. This runtime must be started in order to act upon Lava components, -such as clamping values to their compartments as well as probing information -about the model. To help simplify this, the `LavaContext` comes with a built-in -runtime manager. To gain access to the ngc-lava runtime manager, first call -`model.set_up_runtime()`. Note that the `set_up_runtime` method takes two -arguments. The first is the root Lava component name to be used to start the -runtime -- this is how Lava knows what component it will need to simulate. The -second argument is the "rest" image -- the "rest" image is used to allow the -dynamical system that is your model to return to its reset state while -receiving no input (this is akin to allowing a biophysical neural system to relax -to its resting potential state). This can be left as `None` and doing so will -skip this functionality. Note that this method does not actually start -the runtime, it just configures everything. It is important to observe that a -clamp method fitting the signature `clamp(x) -> None` needs to be defined in -order to use certain runtime methods as defined below. - -### Runtime Methods - -- `with model.runtime`: The lava runtime will exist for the duration of this - with block. -- `model.start_runtime()`: This starts the runtime without the management of - automatically stopping it later. -- `model.pause()`: Pauses the runtime, allowing for values to be read and set. -- `model.stop()`: Stops the runtime, runtimes can not be restarted once they are - stopped. -- `model.run(t)`: Runs the runtime, for `t` time steps. Will automatically pause - upon completion. -- `model.view(x, t)`: First calls `model.clamp(x)` and then runs the runtime for - `t` steps. Will automatically pause upon completion. -- `model.rest(t)`: First calls `model.clamp(rest_image)` and then runs the - runtime for `t` steps. Will automatically pause upon completion. If a reset - image was not supplied, this runtime method will not be available. - -## Additional Utility Methods - -### Using Lags with: `set_lag(component_name, status=True)` - -In Lava, it is easy to lock your system if there is recurrence in your model. -The Lava context allows for you to temporally "lag" the values emitted by -specific components, delaying their executation with respect to the previous -time-step. - -By default, the process pattern for a mapped Lava component is: -`Receive values -> Process values -> Emit values` - -A lagged Lava component will follow the pattern: -`Emit values -> Receive values -> Process Values` - -Example: -> There is a model that has the wiring pattern of `Z0 -> W1 -> Z1 -> W1` -> Here we can see that in order for Z1 to emit values it relies on the values -> emitted by W1. But W1 also relies on values emitted from Z1. So if we lag -> W1 it will emit last timesteps value at the start of the loop and then wait -> for the new values meaning that the value emitted by W1 will be delayed by a -> timestep, but it will no longer lock Z1 from running. - -### `write_to_ngc()` - -This method is designed to copy the current state of the Lava model into the -ngc-learn model. This will do a one-to-one mapping of all of thecomponents and -their values from Lava to ngclearn. It is important to point out that this must -be done inside of a runtime. This is critical for saving since, in order to save -an on-chip-trained model, it must first be written back to ngc-lava/learn -and then to disk. By default, this is called by `model.save_to_json` if called -inside a runtime. diff --git a/docs/tutorials/lava/monitors.md b/docs/tutorials/lava/monitors.md deleted file mode 100644 index b5d82243..00000000 --- a/docs/tutorials/lava/monitors.md +++ /dev/null @@ -1,17 +0,0 @@ -# Monitors - -While lava does have its own version of monitors, ngclearn offers an -in-built version for convenience. It is -recommended that you use the ngclearn monitors as they have expanded -functionality and are designed to interact with the Lava components well. For an -overview of/details on how monitors work please see -[this](../foundations/monitors.md). The -only difference is that Lava has its own monitor found -in `ngclearn.components.lava`. - -## Sharp Edges and Bits - -- Due to the fact that a Lava component of the monitor must be built, it has to - be defined inside the `LavaContext`. -- To view the values found within the monitor via the `view()` and `get_path()` - methods, `model.write_to_ngc()` must be called to refresh the values. \ No newline at end of file diff --git a/docs/tutorials/lava/setup.md b/docs/tutorials/lava/setup.md deleted file mode 100644 index 542b479e..00000000 --- a/docs/tutorials/lava/setup.md +++ /dev/null @@ -1,26 +0,0 @@ -# Setting Up ngc-lava - -Setting up ngc-lava is fairly straightforward. The only part that takes some -time is the setting up of the lava environment itself. - -## Installation and Setup Steps - -1. To set up and use ngc-lava, first download lava-nc - found [here](https://lava-nc.org/lava/notebooks/in_depth/tutorial01_installing_lava.html). -2. Install ngc-learn via pip `pip install ngc-learn` -3. Clone ngc-lava and add it as a project source - ```bash - git clone https://github.com/NACLab/ngc-lava.git - pip install -e ngc-lava - ``` -4. To compile for lava, Jax must be turned off; to do this, set the flag - `packages/use_base_numpy` to `true` in the ngc-learn - `config.json` file. If you do not have a `config.json` file written, the - script below will make one for you and add the needed Lava configuration flag: - - ```bash - mkdir json_files - touch json_files/config.json - echo "{\n \"packages\": {\n \"use_base_numpy\": true \n }\n}" > json_files/config.json - ``` -5. You are now set up and ready to use ngc-lava. \ No newline at end of file diff --git a/docs/tutorials/neurocog/adex_cell.md b/docs/tutorials/neurocog/adex_cell.md index 4a685488..dcd2e9e6 100755 --- a/docs/tutorials/neurocog/adex_cell.md +++ b/docs/tutorials/neurocog/adex_cell.md @@ -22,9 +22,7 @@ AdEx cell amounts to the following: from jax import numpy as jnp, random, jit import numpy as np -from ngclearn.utils.model_utils import scanner -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.adExCell import AdExCell @@ -46,20 +44,15 @@ with Context("Model") as model: intrinsic_mem_thr=-55., v_thr=5., v_rest=-72., v_reset=-75., a=0.1, b=0.75, v0=v0, w0=w0, integration_type="euler", key=subkeys[0] ) - ## create and compile core simulation commands - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance_proc") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset_proc") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) + +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` In effect, the AdEx two-dimensional differential equation system [1]-[2] offers @@ -109,19 +102,19 @@ i_app = 19. ## electrical current to inject into AdEx cell data = jnp.asarray([[i_app]], dtype=jnp.float32) time_span = [] -model.reset() +reset_process.run() t = 0. for ts in range(T): x_t = data ## pass in t and dt and run step forward of simulation - model.clamp(x_t) - model.advance(t=t, dt=dt) + clamp(x_t) + advance_process.run(t=t, dt=dt) # t = t + dt ## naively extract simple statistics at time ts and print them to I/O - v = cell.v.value - w = cell.w.value - s = cell.s.value + v = cell.v.get() + w = cell.w.get() + s = cell.s.get() curr_in.append(data) mem_rec.append(v) recov_rec.append(w) From 043b0a8fd3a66407560466f80516febdfc616c15 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 13:04:45 -0500 Subject: [PATCH 061/121] revised adex and error-cell neurocog tutorials --- docs/tutorials/neurocog/adex_cell.md | 38 +++++++-------------------- docs/tutorials/neurocog/error_cell.md | 35 +++++++++++------------- 2 files changed, 25 insertions(+), 48 deletions(-) diff --git a/docs/tutorials/neurocog/adex_cell.md b/docs/tutorials/neurocog/adex_cell.md index dcd2e9e6..f4ddc79a 100755 --- a/docs/tutorials/neurocog/adex_cell.md +++ b/docs/tutorials/neurocog/adex_cell.md @@ -108,7 +108,7 @@ for ts in range(T): x_t = data ## pass in t and dt and run step forward of simulation clamp(x_t) - advance_process.run(t=t, dt=dt) # + advance_process.run(t=t, dt=dt) # run one step of dynamics t = t + dt ## naively extract simple statistics at time ts and print them to I/O @@ -143,26 +143,27 @@ recov_rec = np.squeeze(np.asarray(recov_rec)) spk_rec = np.squeeze(np.asarray(spk_rec)) # Plot the AdEx cell trajectory -cell_tag = "RS" n_plots = 1 fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5)) ax_ptr = ax -ax_ptr.set(xlabel='Time', ylabel='Voltage (v)', - title="AdEx ({}) Voltage Dynamics".format(cell_tag)) +ax_ptr.set( + xlabel='Time', ylabel='Voltage (v)', title="AdEx Voltage Dynamics" +) v = ax_ptr.plot(time_span, mem_rec, color='C0') ax_ptr.legend([v[0]],['v']) plt.tight_layout() -plt.savefig("{0}".format("adex_v_plot.jpg".format(cell_tag.lower()))) +plt.savefig("{0}".format("adex_v_plot.jpg")) fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5)) ax_ptr = ax -ax_ptr.set(xlabel='Time', ylabel='Recovery (w)', - title="AdEx ({}) Recovery Dynamics".format(cell_tag)) +ax_ptr.set( + xlabel='Time', ylabel='Recovery (w)', title="AdEx Recovery Dynamics" +) w = ax_ptr.plot(time_span, recov_rec, color='C1', alpha=.5) ax_ptr.legend([w[0]],['w']) plt.tight_layout() -plt.savefig("{0}".format("adex_w_plot.jpg".format(cell_tag.lower()))) +plt.savefig("{0}".format("adex_w_plot.jpg")) plt.close() ``` @@ -187,27 +188,6 @@ however, one could configure it to use the midpoint method for integration by setting its argument `integration_type = rk2` in cases where more accuracy in the dynamics is needed (at the cost of additional computational time). -## Optional: Setting Up The Components with a JSON Configuration - -While you are not required to create a JSON configuration file for ngc-learn, -to get rid of the warning that ngc-learn will throw at the start of your -program's execution (indicating that you do not have a configuration set up yet), -all you need to do is create a sub-directory for your JSON configuration -inside of your project code's directory, i.e., `json_files/modules.json`. -Inside the JSON file, you would write the following: - -```json -[ - {"absolute_path": "ngclearn.components", - "attributes": [ - {"name": "AdExCell"}] - }, - {"absolute_path": "ngcsimlib.operations", - "attributes": [ - {"name": "overwrite"}] - } -] -``` ## References diff --git a/docs/tutorials/neurocog/error_cell.md b/docs/tutorials/neurocog/error_cell.md index 04368d5d..b7fbdb11 100644 --- a/docs/tutorials/neurocog/error_cell.md +++ b/docs/tutorials/neurocog/error_cell.md @@ -60,8 +60,8 @@ The code you would write amounts to the below: ```python from jax import numpy as jnp, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess + +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell @@ -71,32 +71,29 @@ T = 5 ## number time steps to simulate with Context("Model") as model: cell = GaussianErrorCell("z0", n_units=3) - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance_proc") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset_proc") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - - @Context.dynamicCommand - def clamp(x, y): - ## error cells have two key input compartments; a "mu" and a "target" - cell.mu.set(x) - cell.target.set(y) +## set up non-compiled utility commands +def clamp(x, y): + ## error cells have two key input compartments; a "mu" and a "target" + cell.mu.set(x) + cell.target.set(y) + guess = jnp.asarray([[-1., 1., 1.]], jnp.float32) ## the produced guess or prediction answer = jnp.asarray([[1., -1., 1.]], jnp.float32) ## what we wish the guess had been -model.reset() +reset_process.run() for ts in range(T): - model.clamp(guess, answer) - model.advance(t=ts * 1., dt=dt) + clamp(guess, answer) + advance_process.run(t=ts * 1., dt=dt) ## extract compartment values of interest - dmu = cell.dmu.value - dtarget = cell.dtarget.value - loss = cell.L.value + dmu = cell.dmu.get() + dtarget = cell.dtarget.get() + loss = cell.L.get() ## print compartment values to I/O print("{} | dmu: {} dtarget: {} loss: {} ".format(ts, dmu, dtarget, loss)) ``` From 588e3f55f1022e65beab82c3cd696b721d6c5dc6 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 14:02:01 -0500 Subject: [PATCH 062/121] fixed minor issues in input-encoders, further revisions to docs for v3 --- docs/tutorials/neurocog/hebbian.md | 21 ++++++------ docs/tutorials/neurocog/input_cells.md | 34 ++++++++----------- .../input_encoders/bernoulliCell.py | 16 +-------- .../components/input_encoders/latencyCell.py | 8 ++--- .../components/input_encoders/phasorCell.py | 10 ------ .../components/input_encoders/poissonCell.py | 8 +++-- .../input_encoders/test_poissonCell.py | 2 +- 7 files changed, 36 insertions(+), 63 deletions(-) diff --git a/docs/tutorials/neurocog/hebbian.md b/docs/tutorials/neurocog/hebbian.md index 8e67754c..6c6afbed 100644 --- a/docs/tutorials/neurocog/hebbian.md +++ b/docs/tutorials/neurocog/hebbian.md @@ -21,30 +21,29 @@ Specifically, we will zoom in on two particular code snippets from below: ```python -Wab = HebbianSynapse(name="Wab", shape=(1, 1), eta=1., signVal=-1., - wInit=("constant", 1., None), w_bound=0., key=subkeys[3]) +Wab = HebbianSynapse( + name="Wab", shape=(1, 1), eta=1., signVal=-1., wInit=("constant", 1., None), w_bound=0., key=subkeys[3] +) # wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab` -Wab.inputs << a.zF +a.zF >> Wab.inputs # wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b` -b.j << Wab.outputs +Wab.outputs >> b.j # wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab` -Wab.pre << a.zF +a.zF >> Wab.pre # wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab` -Wab.post << b.zF +b.zF >> Wab.post ``` as well as (a bit later in the model construction code): ```python -evolve_process = (JaxProcess() +evolve_process = (MethodProcess() >> a.evolve) -circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve") -advance_process = (JaxProcess() +advance_process = (MethodProcess() >> a.advance_state) -circuit.wrap_and_add_command(jit(advance_process.pure), name="advance") ``` Notice that beyond wiring component `a`'s values into the synapse `Wab`'s input compartment @@ -54,7 +53,7 @@ post-synaptic compartment `Wab.post`. These compartments are specifically used in `Wab`'s `evolve` call and are not strictly required to be exactly the same as its input and output compartments. Note that, if one wanted `pre` and `post` to be exactly identical to `inputs` and `outputs`, one would simply need -to write `Wab.pre << Wab.inputs` and `Wab.post << Wab.outputs` in place +to write `Wab.inputs >> Wab.pre` and `Wab.outputs >> Wab.post` in place of the pre- and post-synaptic compartment calls above. The above snippets highlight two key aspects of functionality that a synapse diff --git a/docs/tutorials/neurocog/input_cells.md b/docs/tutorials/neurocog/input_cells.md index 1a58adac..c39c6ca7 100644 --- a/docs/tutorials/neurocog/input_cells.md +++ b/docs/tutorials/neurocog/input_cells.md @@ -39,8 +39,7 @@ spike train over $100$ steps in time as follows: ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess from ngclearn.utils.viz.raster import create_raster_plot ## import model-specific mechanisms @@ -56,27 +55,24 @@ T = 100 ## number time steps to simulate with Context("Model") as model: cell = BernoulliCell("z0", n_units=10, key=subkeys[0]) - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance_proc") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset_proc") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - - @Context.dynamicCommand - def clamp(x): - cell.inputs.set(x) +def clamp(x): + cell.inputs.set(x) + probs = jnp.asarray([[0.8, 0.2, 0., 0.55, 0.9, 0, 0.15, 0., 0.6, 0.77]], dtype=jnp.float32) spikes = [] -model.reset() +reset_process.run() for ts in range(T): - model.clamp(probs) - model.advance(t=ts * 1., dt=dt) + clamp(probs) + advance_process.run(t=ts * 1., dt=dt) - s_t = cell.outputs.value + s_t = cell.outputs.get() spikes.append(s_t) spikes = jnp.concatenate(spikes, axis=0) create_raster_plot(spikes, plot_fname="input_cell_raster.jpg") @@ -121,7 +117,7 @@ and by replacing the line that has the `BernoulliCell` call with the following line instead: ```python -cell = PoissonCell("z0", n_units=10, max_freq=63.75, key=subkeys[0]) +cell = PoissonCell("z0", n_units=10, target_freq=63.75, key=subkeys[0]) ``` Running the code with the two above small modifications will @@ -149,12 +145,12 @@ mu = 0. probs = jnp.asarray([[1.]],dtype=jnp.float32) for _ in range(n_trials): spikes = [] - model.reset() + reset_process.run() for ts in range(T): - model.clamp(probs) - model.advance(t=ts*1., dt=dt) + clamp(probs) + advance_process.run(t=ts * 1., dt=dt) - s_t = cell.outputs.value + s_t = cell.outputs.get() spikes.append(s_t) count = jnp.sum(jnp.concatenate(spikes, axis=0)) mu += count diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 1a5c6dca..52441430 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -26,7 +26,7 @@ class BernoulliCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None): + def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs): super().__init__(name=name, key=key) ## Layer Size Setup @@ -80,20 +80,6 @@ def help(cls): ## component help function "hyperparameters": hyperparams} return info - # def __repr__(self): - # comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] - # maxlen = max(len(c) for c in comps) + 5 - # lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" - # for c in comps: - # stats = tensorstats(getattr(self, c).value) - # if stats is not None: - # line = [f"{k}: {v}" for k, v in stats.items()] - # line = ", ".join(line) - # else: - # line = "None" - # lines += f" {f'({c})'.ljust(maxlen)}{line}\n" - # return lines - if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index 82b16e52..30832afe 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -144,11 +144,9 @@ class LatencyCell(JaxComponent): """ def __init__( - self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, - first_spike_time: float = 0., linearize: bool = False, - normalize: bool = False, clip_spikes: bool = False, - num_steps: float = 1., batch_size: int = 1, - key: Union[jax.Array, None] = None + self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, first_spike_time: float = 0., + linearize: bool = False, normalize: bool = False, clip_spikes: bool = False, num_steps: float = 1., + batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs ): super().__init__(name=name, key=key) diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index ccfbb15d..0b0a11ad 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -137,16 +137,6 @@ def reset(self): self.angles.set(restVals) self.key.set(key) - - def save(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - jnp.savez(file_name, key=self.key.value) - - def load(self, directory, **kwargs): - file_name = directory + "/" + self.name + ".npz" - data = jnp.load(file_name) - self.key.set(data['key']) - @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 47869f1a..de6b2a97 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -3,6 +3,7 @@ import jax from typing import Union +from ngcsimlib import deprecate_args from ngcsimlib.parser import compilable from ngcsimlib.compartment import Compartment @@ -29,8 +30,11 @@ class PoissonCell(JaxComponent): batch_size: batch size dimension of this cell (Default: 1) """ - def __init__(self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1, - key: Union[jax.Array, None] = None): + @deprecate_args(max_freq="target_freq") + def __init__( + self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1, + key: Union[jax.Array, None] = None, **kwargs + ): super().__init__(name=name, key=key) ## Constrained Bernoulli meta-parameters diff --git a/tests/components/input_encoders/test_poissonCell.py b/tests/components/input_encoders/test_poissonCell.py index f21f062a..fd29a13b 100644 --- a/tests/components/input_encoders/test_poissonCell.py +++ b/tests/components/input_encoders/test_poissonCell.py @@ -1,7 +1,7 @@ from jax import numpy as jnp, random, jit import numpy as np np.random.seed(42) -from ngclearn.components import PoissonCell +from ngclearn.components.input_encoders.poissonCell import PoissonCell from numpy.testing import assert_array_equal from ngclearn import MethodProcess, Context From 7c56b47c50ab3ad575d89b14c9c2b85a2e953150 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 14:55:29 -0500 Subject: [PATCH 063/121] revised dyn/chem-syn neurocog doc, cleaned up dynamic syn --- docs/tutorials/neurocog/dynamic_synapses.md | 106 ++++++++---------- .../components/input_encoders/phasorCell.py | 3 +- ngclearn/components/synapses/alphaSynapse.py | 16 --- .../components/synapses/doubleExpSynapse.py | 18 +-- .../components/synapses/exponentialSynapse.py | 16 --- 5 files changed, 52 insertions(+), 107 deletions(-) diff --git a/docs/tutorials/neurocog/dynamic_synapses.md b/docs/tutorials/neurocog/dynamic_synapses.md index bc708264..a921a43e 100644 --- a/docs/tutorials/neurocog/dynamic_synapses.md +++ b/docs/tutorials/neurocog/dynamic_synapses.md @@ -3,7 +3,7 @@ In this lesson, we will study dynamic synapses, or synaptic cable components in ngc-learn that evolve on fast time-scales in response to their pre-synaptic inputs. These types of chemical synapse components are useful for modeling time-varying -conductance which ultimately drives eletrical current input into neuronal units +conductance which ultimately drives electrical current input into neuronal units (such as spiking cells). Here, we will learn how to build three important types of dynamic synapses in ngc-learn -- the exponential, the alpha, and the double-exponential synapse -- and visualize the time-course of their resulting conductances. In addition, we will then @@ -24,17 +24,14 @@ value matrices we might initially employ (as in synapse components such as the Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse), the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place the code you will write within. -For the first part of this lesson, we will import all three dynamic synpapse models and compare their behavior. +For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior. This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics): ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoupleExpSynapse - -from ngcsimlib.compilers.process import Process -from ngcsimlib.context import Context -import ngclearn.utils.weight_distribution as dist +from ngclearn import Context, MethodProcess +from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoubleExpSynapse +from ngclearn.utils.distribution_generator import DistributionGenerator dkey = random.PRNGKey(1234) ## creating seeding keys for synapses @@ -46,29 +43,27 @@ T = 8. # ms ## total duration time with Context("dual_syn_system") as ctx: Wexp = ExponentialSynapse( ## exponential dynamic synapse name="Wexp", shape=(1, 1), tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1., - weight_init=dist.constant(value=1.), key=subkeys[0] + weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0] ) Walpha = AlphaSynapse( ## alpha dynamic synapse name="Walpha", shape=(1, 1), tau_decay=1., g_syn_bar=1., syn_rest=0., resist_scale=1., - weight_init=dist.constant(value=1.), key=subkeys[0] + weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0] ) - Wexp2 = DoupleExpSynapse( + Wexp2 = DoubleExpSynapse( name="Wexp2", shape=(1, 1), tau_rise=1., tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1., - weight_init=dist.constant(value=1.), key=subkeys[0] + weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0] ) ## set up basic simulation process calls - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> Wexp.advance_state >> Walpha.advance_state >> Wexp2.advance_state) - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> Wexp.reset >> Walpha.reset >> Wexp2.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ``` where we notice in the above we have instantiated three different kinds of chemical synapse components @@ -90,7 +85,7 @@ $$ $$ where the conductance (for a post-synaptic unit) output of this synapse is driven by a sum over all of its incoming -pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an expoential kernel (i.e., a low-pass filter). +pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an exponential kernel (i.e., a low-pass filter). On the other hand, for the alpha synapse, the dynamics adhere to the following coupled set of ODEs: $$ @@ -100,7 +95,7 @@ $$ where $h_{\text{syn}}(t)$ is an intermediate variable that operates in service of driving the conductance variable $g_{\text{syn}}(t)$ itself. The double-exponential (or difference of exponentials) synapse model looks similar to the alpha synapse except that the -rise and fall/decay of its condutance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$, +rise and fall/decay of its conductance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$, as follows: $$ @@ -128,7 +123,7 @@ time_span = [] g = [] ga = [] gexp2 = [] -ctx.reset() +reset_process.run() Tsteps = int(T/dt) + 1 for t in range(Tsteps): s_t = jnp.zeros((1, 1)) @@ -136,21 +131,23 @@ for t in range(Tsteps): s_t = jnp.ones((1, 1)) Wexp.inputs.set(s_t) Walpha.inputs.set(s_t) - Wexp.v.set(Wexp.v.value * 0) + Wexp.v.set(Wexp.v.get() * 0) Wexp2.inputs.set(s_t) - Walpha.v.set(Walpha.v.value * 0) - Wexp2.v.set(Wexp2.v.value * 0) - ctx.run(t=t * dt, dt=dt) - - print(f"\r g = {Wexp.g_syn.value} ga = {Walpha.g_syn.value} gexp2 = {Wexp2.g_syn.value}", end="") - g.append(Wexp.g_syn.value) - ga.append(Walpha.g_syn.value) + Walpha.v.set(Walpha.v.get() * 0) + Wexp2.v.set(Wexp2.v.get() * 0) + advance_process.run(t=t * dt, dt=dt) + + print(f"\r g = {Wexp.g_syn.get()} ga = {Walpha.g_syn.get()} gexp2 = {Wexp2.g_syn.get()}", end="") + g.append(Wexp.g_syn.get()) + ga.append(Walpha.g_syn.get()) + gexp2.append(Wexp2.g_syn.get()) time_span.append(t) #* dt) print() g = jnp.squeeze(jnp.concatenate(g, axis=1)) g = g/jnp.amax(g) ga = jnp.squeeze(jnp.concatenate(ga, axis=1)) ga = ga/jnp.amax(ga) +gexp2 = jnp.squeeze(jnp.concatenate(gexp2, axis=1)) gexp2 = gexp2/jnp.amax(gexp2) ``` @@ -195,6 +192,9 @@ ax.grid(which="major") fig.savefig("alpha_syn.jpg") plt.close() +## ---- plot the double-exponential synapse conductance time-course ---- +fig, ax = plt.subplots() + gvals = ax.plot(time_span, gexp2, '-', color='tab:blue') #plt.xticks(time_span, time_labs) ax.set_xticks(time_ticks, time_labs) @@ -207,7 +207,7 @@ plt.close() ``` which should produce and save three plots to disk. You can then compare and contrast the plots of the -expoential, alpha synapse, and double-exponential conductance trajectories: +exponential, alpha synapse, and double-exponential conductance trajectories: ```{eval-rst} .. table:: @@ -222,7 +222,7 @@ expoential, alpha synapse, and double-exponential conductance trajectories: Note that the alpha synapse (right figure) would produce a more realistic fit to recorded synaptic currents (as it attempts to model the rise and fall of current in a less simplified manner) at the cost of extra compute, given it uses two ODEs to -emulate condutance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure). +emulate conductance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure). ## Excitatory-Inhibitory Driven Dynamics @@ -243,13 +243,10 @@ We will specifically model the excitatory and inhibitory conductance changes usi ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context +from ngclearn import Context, MethodProcess +from ngclearn.operations import Summation from ngclearn.components import ExponentialSynapse, PoissonCell, LIFCell -from ngclearn.operations import summation - -from ngcsimlib.compilers.process import Process -from ngcsimlib.context import Context -import ngclearn.utils.weight_distribution as dist +from ngclearn.utils.distribution_generator import DistributionGenerator ## create seeding keys dkey = random.PRNGKey(1234) @@ -287,39 +284,36 @@ with Context("ei_snn") as ctx: pre_inh = PoissonCell("pre_inh", n_units=n_inh, target_freq=inh_freq, key=subkeys[1]) ## pre-syn inhibitory group Wexc = ExponentialSynapse( ## dynamic synapse between excitatory group and LIF name="Wexc", shape=(n_exc,1), tau_decay=tau_syn_exc, g_syn_bar=g_e_bar, syn_rest=E_rest_exc, resist_scale=1./g_L, - weight_init=dist.constant(value=1.), key=subkeys[2] + weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2] ) Winh = ExponentialSynapse( ## dynamic synapse between inhibitory group and LIF name="Winh", shape=(n_inh, 1), tau_decay=tau_syn_inh, g_syn_bar=g_i_bar, syn_rest=E_rest_inh, resist_scale=1./g_L, - weight_init=dist.constant(value=1.), key=subkeys[2] + weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2] ) post_exc = LIFCell( ## post-syn LIF cell "post_exc", n_units=1, tau_m=tau_m, resist_m=1., thr=v_thr, v_rest=v_rest, conduct_leak=1., v_reset=-75., tau_theta=0., theta_plus=0., refract_time=2., key=subkeys[3] ) - Wexc.inputs << pre_exc.outputs - Winh.inputs << pre_inh.outputs - Wexc.v << post_exc.v ## couple voltage to exc synapse - Winh.v << post_exc.v ## couple voltage to inh synapse - post_exc.j << summation(Wexc.i_syn, Winh.i_syn) ## sum together excitatory & inhibitory pressures + pre_exc.outputs >> Wexc.inputs + pre_inh.outputs >> Winh.inputs + post_exc.v >> Wexc.v ## couple voltage to exc synapse + post_exc.v >> Winh.v ## couple voltage to inh synapse + Summation(Wexc.i_syn, Winh.i_syn) >> post_exc.j ## sum together excitatory & inhibitory pressures - advance_process = (Process("advance_proc") + advance_process = (MethodProcess("advance_proc") >> pre_exc.advance_state >> pre_inh.advance_state >> Wexc.advance_state >> Winh.advance_state >> post_exc.advance_state) - # ctx.wrap_and_add_command(advance_process.pure, name="run") - ctx.wrap_and_add_command(jit(advance_process.pure), name="run") - reset_process = (Process("reset_proc") + reset_process = (MethodProcess("reset_proc") >> pre_exc.reset >> pre_inh.reset >> Wexc.reset >> Winh.reset >> post_exc.reset) - ctx.wrap_and_add_command(jit(reset_process.pure), name="reset") ``` ### Examining the Simple Spiking Circuit's Behavior @@ -331,18 +325,18 @@ volts = [] time_span = [] spikes = [] -ctx.reset() +reset_process.run() pre_exc.inputs.set(jnp.ones((1, n_exc))) pre_inh.inputs.set(jnp.ones((1, n_inh))) -post_exc.v.set(post_exc.v.value * 0 - 65.) ## initial condition for LIF is -65 mV -volts.append(post_exc.v.value) +post_exc.v.set(post_exc.v.get() * 0 - 65.) ## initial condition for LIF is -65 mV +volts.append(post_exc.v.get()) time_span.append(0.) Tsteps = int(T/dt) + 1 for t in range(1, Tsteps): - ctx.run(t=t * dt, dt=dt) - print(f"\r v {post_exc.v.value}", end="") - volts.append(post_exc.v.value) - spikes.append(post_exc.s.value) + advance_process.run(t=t * dt, dt=dt) + print(f"\r v {post_exc.v.get()}", end="") + volts.append(post_exc.v.get()) + spikes.append(post_exc.s.get()) time_span.append(t) #* dt) print() volts = jnp.squeeze(jnp.concatenate(volts, axis=1)) @@ -384,9 +378,7 @@ ax.grid() fig.savefig("ei_circuit_dynamics.jpg") ``` -which should produce a figure depicting dynamics similar to the one below. Black tick -marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's -voltage threshold. +which should produce a figure depicting dynamics similar to the one below. Black tick marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's voltage threshold. ```{eval-rst} diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 0b0a11ad..594e3b9d 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -32,7 +32,8 @@ class PhasorCell(JaxComponent): """ def __init__( - self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs): + self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs + ): super().__init__(name, **kwargs) ## Phasor meta-parameters diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py index fc529b3f..4470af68 100644 --- a/ngclearn/components/synapses/alphaSynapse.py +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -1,6 +1,4 @@ from jax import random, numpy as jnp, jit -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info from ngclearn.components.synapses import DenseSynapse from ngcsimlib.compartment import Compartment @@ -115,20 +113,6 @@ def reset(self): self.h_syn.set(postVals) self.v.set(postVals) - # def save(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # if self.bias_init != None: - # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - # else: - # jnp.savez(file_name, weights=self.weights.value) - # - # def load(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # data = jnp.load(file_name) - # self.weights.set(data['weights']) - # if "biases" in data.keys(): - # self.biases.set(data['biases']) - @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py index 62cce850..ca1fdcdd 100644 --- a/ngclearn/components/synapses/doubleExpSynapse.py +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -1,6 +1,4 @@ from jax import random, numpy as jnp, jit -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info from ngclearn.components.synapses import DenseSynapse from ngcsimlib.compartment import Compartment @@ -85,7 +83,7 @@ def __init__( self.weights.set(self.weights.get() * 0 + 1.) @compilable - def advance_state(self, t, dt): #dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v + def advance_state(self, t, dt): s = self.inputs.get() #A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay)) A = 1. ## FIXME: scale factor to use? @@ -121,20 +119,6 @@ def reset(self): self.h_syn.set(postVals) self.v.set(postVals) - # def save(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # if self.bias_init != None: - # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - # else: - # jnp.savez(file_name, weights=self.weights.value) - # - # def load(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # data = jnp.load(file_name) - # self.weights.set(data['weights']) - # if "biases" in data.keys(): - # self.biases.set(data['biases']) - @classmethod def help(cls): ## component help function properties = { diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py index d8ba9b5f..e0ee3a6e 100644 --- a/ngclearn/components/synapses/exponentialSynapse.py +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -1,6 +1,4 @@ from jax import random, numpy as jnp, jit -from ngclearn.utils.weight_distribution import initialize_params -from ngcsimlib.logger import info from ngclearn.components.synapses import DenseSynapse from ngcsimlib.compartment import Compartment @@ -107,20 +105,6 @@ def reset(self): self.g_syn.set(postVals) self.v.set(postVals) - # def save(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # if self.bias_init != None: - # jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value) - # else: - # jnp.savez(file_name, weights=self.weights.value) - # - # def load(self, directory, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # data = jnp.load(file_name) - # self.weights.set(data['weights']) - # if "biases" in data.keys(): - # self.biases.set(data['biases']) - @classmethod def help(cls): ## component help function properties = { From 9373d8d2e4d6308c4833b66c871bfa4f0f9fbd2f Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 15:46:09 -0500 Subject: [PATCH 064/121] revised fn and hh-cell neurocog docs, added some refs to distribution generator --- docs/tutorials/neurocog/dynamic_synapses.md | 2 +- .../neurocog/fitzhugh_nagumo_cell.md | 105 ++++-------------- .../tutorials/neurocog/hodgkin_huxley_cell.md | 94 +++++----------- ngclearn/utils/distribution_generator.py | 7 ++ 4 files changed, 57 insertions(+), 151 deletions(-) diff --git a/docs/tutorials/neurocog/dynamic_synapses.md b/docs/tutorials/neurocog/dynamic_synapses.md index a921a43e..d4e9b902 100644 --- a/docs/tutorials/neurocog/dynamic_synapses.md +++ b/docs/tutorials/neurocog/dynamic_synapses.md @@ -22,7 +22,7 @@ value matrices we might initially employ (as in synapse components such as the [DenseSynapse](ngclearn.components.synapses.denseSynapse)). Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse), -the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place +the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_dynamic_synapses.py` to place the code you will write within. For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior. This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics): diff --git a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md index 6575fc06..832b6ab1 100644 --- a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md +++ b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md @@ -17,8 +17,7 @@ single component system made up of the Fitzhugh-Nagumo (`F-N`) cell. from jax import numpy as jnp, random, jit import numpy as np -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell @@ -40,42 +39,21 @@ with Context("Model") as model: gamma=gamma, v0=v0, w0=w0, integration_type="euler") ## create and compile core simulation commands - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` -In effect, the FitzHugh–Nagumo `F-N` two-dimensional differential -equation system (developed by [1] and [2]) is -a useful simplification of the more intricate Hodgkin–Huxley (H-H) squid axon -model, attempting to extract some of the benefits of its more detailed modeling -of the spiking cellular activation and deactivation dynamics (specifically -attempting to isolate the properties related to sodium/potassium ion flow -from cellular properties of excitation and propagation). Notably, the `F-N` -cell models membrane potential `v` with a cubic function (which facilitates -self-excitation through positive feedback) in tandem with a recovery variable `w` -that provides a slower form of negative feedback. The linear dynamics that govern -`w` are controlled by (dimensionless) coefficients `alpha` and `beta`, which -control its shift and scale, respectively (another factor `gamma` is introduced -in our implementation, which divides the cubic term in the voltage dynamics, but -generally this can usually be set to either a value of `1` or `3` as in [1]). -The value `tau_w` controls the time constant for the recovery variable (and, -technically, ngc-learn implements `tau_m` to control the membrane potential, -but this is default set to `1` since [1] and [2] typically only use a time -constant for the recovery variable). - -The initial conditions for the voltage (i.e., `v0`) and the recovery (i.e., `w0`) -have been set to particular interesting values above for the demonstration -purposes of this tutorial but, by default, are `0` in the `F-N` cell component. +In effect, the FitzHugh–Nagumo `F-N` two-dimensional differential equation system (developed by [1] and [2]) is a useful simplification of the more intricate Hodgkin–Huxley (H-H) squid axon model, attempting to extract some of the benefits of its more detailed modeling of the spiking cellular activation and deactivation dynamics (specifically attempting to isolate the properties related to sodium/potassium ion flow from cellular properties of excitation and propagation). Notably, the `F-N` cell models membrane potential `v` with a cubic function (which facilitates self-excitation through positive feedback) in tandem with a recovery variable `w` that provides a slower form of negative feedback. The linear dynamics that govern `w` are controlled by (dimensionless) coefficients `alpha` and `beta`, which control its shift and scale, respectively (another factor `gamma` is introduced in our implementation, which divides the cubic term in the voltage dynamics, but generally this can usually be set to either a value of `1` or `3` as in [1]). +The value `tau_w` controls the time constant for the recovery variable (and, technically, ngc-learn implements `tau_m` to control the membrane potential, but this is default set to `1` since [1] and [2] typically only use a time constant for the recovery variable). + +The initial conditions for the voltage (i.e., `v0`) and the recovery (i.e., `w0`) have been set to particular interesting values above for the demonstration purposes of this tutorial but, by default, are `0` in the `F-N` cell component. Formally, the core dynamics of the `F-N` can be written out as follows: @@ -84,24 +62,13 @@ $$ \tau_w \frac{\partial \mathbf{w}_t}{\partial t} &= \mathbf{v}_t + a - b\mathbf{w}_t $$ -where $a$ and $b$ are factors that drive the recovery variable's dynamics -(shift and scaling, respectively), $R$ is the membrane resistance, $\tau_m$ is the -membrane time constant, and $\tau_w$ is the recovery time constant ($g$ is a -dividing constant meant to dampen the effects of the cubic term, but is generally -set to $g = 1$ to adhere to [1] and [2]) +where $a$ and $b$ are factors that drive the recovery variable's dynamics (shift and scaling, respectively), $R$ is the membrane resistance, $\tau_m$ is the membrane time constant, and $\tau_w$ is the recovery time constant ($g$ is a dividing constant meant to dampen the effects of the cubic term, but is generally set to $g = 1$ to adhere to [1] and [2]) ### Simulating a FitzHugh–Nagumo Neuronal Cell -Given that we have a single-cell dynamical system set up as above, we can next -write some code for visualizing how the `F-N` node's membrane potential and -coupled recovery variable evolve with time (specifically over a period of about -`200` milliseconds). We will, much as we did with the leaky integrators in -prior tutorials, inject an electrical current `j` into the `F-N` cell (this time -just a constant current value of `0.23` amperes) and observe how the cell -produces action potentials. -Specifically, we can plot the neuron's voltage `v` and recovery variable `w` -as follows: +Given that we have a single-cell dynamical system set up as above, we can next write some code for visualizing how the `F-N` node's membrane potential and coupled recovery variable evolve with time (specifically over a period of about `200` milliseconds). We will, much as we did with the leaky integrators in prior tutorials, inject an electrical current `j` into the `F-N` cell (this time just a constant current value of `0.23` amperes) and observe how the cell produces action potentials. +Specifically, we can plot the neuron's voltage `v` and recovery variable `w` as follows: ```python curr_in = [] @@ -119,26 +86,26 @@ time_span = np.linspace(0, 200, num=T) dt = time_span[1] - time_span[0] # ~ 0.13342228152101404 ms time_span = [] -model.reset() +reset_process.run() t = 0. for ts in range(T): x_t = data ## pass in t and dt and run step forward of simulation - model.clamp(x_t) - model.advance(t=t, dt=dt) + clamp(x_t) + advance_process.run(t=t, dt=dt) t = t + dt ## naively extract simple statistics at time ts and print them to I/O - v = cell.v.value - w = cell.w.value - s = cell.s.value + v = cell.v.get() + w = cell.w.get() + s = cell.s.get() curr_in.append(data) mem_rec.append(v) recov_rec.append(w) spk_rec.append(s) ## print stats to I/O (overriding previous print-outs to reduce clutter) print("\r {}: s {} ; v {} ; w {}".format(ts, s, v, w), end="") - time_span.append((ts)*dt) + time_span.append(ts * dt) print() import matplotlib #.pyplot as plt @@ -169,38 +136,12 @@ plt.tight_layout() plt.savefig("{0}".format("fncell_plot.jpg")) ``` -You should get a plot that depicts the evolution of the voltage and recovery, -i.e., saved as `fncell_plot.jpg` locally to disk, like the one below: +You should get a plot that depicts the evolution of the voltage and recovery, i.e., saved as `fncell_plot.jpg` locally to disk, like the one below: -A useful note is that the `F-N` above used Euler integration to step through its -dynamics (this is the default/base routine for all cell components in ngc-learn); -however, one could configure it to use the midpoint method for integration -by setting its argument `integration_type = rk2` in cases where more -accuracy in the dynamics is needed (at the cost of additional computational time). - -## Optional: Setting Up The Components with a JSON Configuration - -While you are not required to create a JSON configuration file for ngc-learn, -to get rid of the warning that ngc-learn will throw at the start of your -program's execution (indicating that you do not have a configuration set up yet), -all you need to do is create a sub-directory for your JSON configuration -inside of your project code's directory, i.e., `json_files/modules.json`. -Inside the JSON file, you would write the following: - -```json -[ - {"absolute_path": "ngclearn.components", - "attributes": [ - {"name": "FitzHughNagumoCell"}] - }, - {"absolute_path": "ngcsimlib.operations", - "attributes": [ - {"name": "overwrite"}] - } -] -``` +A useful note is that the `F-N` above used Euler integration to step through its dynamics (this is the default/base routine for all cell components in ngc-learn); however, one could configure it to use the midpoint method for integration by setting its argument `integration_type = rk2` in cases where more accuracy in the dynamics is needed (at the cost of additional computational time). + ## References diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md index d47f23bf..38228233 100755 --- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md +++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md @@ -1,16 +1,12 @@ # Lecture 2E: The Hodgkin-Huxley Cell -In this tutorial, we will study/setup one of the most important biophysical -neuronal models in computational neuroscience -- the Hodgkin-Huxley (H-H) spiking -cell model. +In this tutorial, we will study/setup one of the most important and sophisticated biophysical neuronal models in computational neuroscience -- the Hodgkin-Huxley (H-H) spiking cell model. ## Using and Probing the H-H Cell -Go ahead and make a new folder for this study and create a Python script, -i.e., `run_hhcell.py`, to write your code for this part of the tutorial. +Go ahead and make a new folder for this study and create a Python script, i.e., `run_hhcell.py`, to write your code for this part of the tutorial. -Now let's set up the controller for this lesson's simulation and construct a -single component system made up of an H-H cell. +Now let's set up the controller for this lesson's simulation and construct a single component system made up of an H-H cell. ### Instantiating the H-H Neuronal Cell @@ -22,9 +18,7 @@ H-H cell amounts to the following: from jax import numpy as jnp, random, jit import numpy as np -from ngclearn.utils.model_utils import scanner -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell @@ -52,18 +46,15 @@ with Context("Model") as model: ) ## create and compile core simulation commands - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` Notably, the H-H model is a four-dimensional differential equation system, invented in 1952 @@ -88,15 +79,12 @@ $$ \frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t $$ -where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biopphysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to. +where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biophysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to. Note that, in ngc-learn's implementation of the H-H cell model, most of the core coefficients have been generally set according to Hodgkin and Huxley's 1952 work but can be configured by the experimenter to obtain different kinds of behavior/dynamics. ### Simulating the H-H Neuronal Cell -To see how the H-H cell works, we next write some code for visualizing how -the node's membrane potential and core related gates/channels evolve with time -(over a period of about `200` milliseconds). We will inject a square input pulse current -into our H-H cell (specifically into its `j` compartment) and observe how the cell behaves in response. +To see how the H-H cell works, we next write some code for visualizing how the node's membrane potential and core related gates/channels evolve with time (over a period of about `200` milliseconds). We will inject a square input pulse current into our H-H cell (specifically into its `j` compartment) and observe how the cell behaves in response. Specifically, we simulate the injection of this kind of current via the code below: ```python @@ -112,17 +100,17 @@ v = [] n = [] m = [] h = [] -model.reset() +reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - model.clamp(x_t) - model.run(t=ts * dt, dt=dt) - outs.append(a.s.value) - n.append(cell.n.value[0, 0]) - m.append(cell.m.value[0, 0]) - h.append(cell.h.value[0, 0]) - v.append(cell.v.value[0, 0]) - print(f"\r {ts} v = {cell.v.value}", end="") + clamp(x_t) + advance_process.run(t=ts * dt, dt=dt) + outs.append(cell.s.get()) + n.append(cell.n.get()[0, 0]) + m.append(cell.m.get()[0, 0]) + h.append(cell.h.get()[0, 0]) + v.append(cell.v.get()[0, 0]) + print(f"\r {ts} v = {cell.v.get()}", end="") time_span.append(ts*dt) outs = jnp.concatenate(outs, axis=1) v = jnp.array(v) @@ -130,8 +118,7 @@ time_span = jnp.array(time_span) outs = jnp.squeeze(outs) ``` -and we can plot the dynamics of the neuron's voltage `v` and its three gate/channel -variables, `h`, `m`, and `n`, with the following: +and we can plot the dynamics of the neuron's voltage `v` and its three gate/channel variables, `h`, `m`, and `n`, with the following: ```python import matplotlib.pyplot as plt @@ -161,9 +148,7 @@ plt.savefig("{0}".format("hh_plot.jpg")) plt.close() ``` -You should get a compound plot that depict the evolution of the H-H cell's voltage -and channel/gate variables, i.e., saved as `hh_plot.jpg` locally to -disk, like the one below: +You should get a compound plot that depict the evolution of the H-H cell's voltage and channel/gate variables, i.e., saved as `hh_plot.jpg` locally to disk, like the one below: ```{eval-rst} .. table:: @@ -176,38 +161,11 @@ disk, like the one below: +--------------------------------------------------------+ ``` -A useful note is that the H-H cell above used Euler integration to step through its -dynamics (this is the default/base routine for all cell components in ngc-learn). -However, one could configure the cell to use the midpoint method for integration -by setting its argument `integration_type = rk2` or the Runge-Kutta fourth-order -routine via `integration_type=rk4` for cases where, at the cost of increased -compute time, more accurate dynamics are possible. - -## Optional: Setting Up The Components with a JSON Configuration - -While you are not required to create a JSON configuration file for ngc-learn, -to get rid of the warning that ngc-learn will throw at the start of your -program's execution (indicating that you do not have a configuration set up yet), -all you need to do is create a sub-directory for your JSON configuration -inside of your project code's directory, i.e., `json_files/modules.json`. -Inside the JSON file, you would write the following: - -```json -[ - {"absolute_path": "ngclearn.components", - "attributes": [ - {"name": "HodgkinHuxleyCell"}] - }, - {"absolute_path": "ngcsimlib.operations", - "attributes": [ - {"name": "overwrite"}] - } -] -``` +A useful note is that the H-H cell above used Euler integration to step through its dynamics (this is the default/base routine for all cell components in ngc-learn). +However, one could configure the cell to use the midpoint method for integration by setting its argument `integration_type = rk2` or the Runge-Kutta fourth-order routine via `integration_type=rk4` for cases where, at the cost of increased compute time, more accurate dynamics are possible. + ## References -[1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description -of membrane current and its application to conduction and excitation in nerve." -The Journal of physiology 117.4 (1952): 500. +[1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. diff --git a/ngclearn/utils/distribution_generator.py b/ngclearn/utils/distribution_generator.py index 15845334..7d86d012 100644 --- a/ngclearn/utils/distribution_generator.py +++ b/ngclearn/utils/distribution_generator.py @@ -173,6 +173,10 @@ def fan_in_uniform( The values are sampled from a uniform distribution in the range [-limit, limit], where limit = sqrt(1 / fan_in), and fan_in is inferred from the shape. + | Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural + | networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. + | JMLR Workshop and Conference Proceedings, 2010. + Args: **params: extra distribution parameters @@ -233,6 +237,9 @@ def fan_in_gaussian( The values are sampled from a normal distribution with mean 0 and stddev = sqrt(1 / fan_in), where fan_in is inferred from the shape. + | He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet + | classification." Proceedings of the IEEE international conference on computer vision. 2015. + Args: **params: extra distribution parameters From 08c788a87b7a2ccc7a9c9909f5b03f7334c72f65 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 15:57:11 -0500 Subject: [PATCH 065/121] revised integration and izh-cell neurocog docs --- docs/tutorials/neurocog/integration.md | 69 ++----------- docs/tutorials/neurocog/izhikevich_cell.md | 112 +++++---------------- 2 files changed, 32 insertions(+), 149 deletions(-) diff --git a/docs/tutorials/neurocog/integration.md b/docs/tutorials/neurocog/integration.md index a42dea7c..95320c6f 100644 --- a/docs/tutorials/neurocog/integration.md +++ b/docs/tutorials/neurocog/integration.md @@ -1,32 +1,14 @@ # Numerical Integration -In constructing one's own biophysical models, particularly those of phenomena -that change with time, ngc-learn offers useful flexible tools for numerical -integration that facilitate an easier time in constructing your own components -that play well with the library's simulation backend. Knowing how things work -beyond Euler integration -- the base/default form of integration often employed -by ngc-learn -- might be useful for constructing and simulating dynamics more -accurately (often at the cost of additional computational time). +In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers useful flexible tools for numerical integration that facilitate an easier time in constructing your own components that play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default form of integration often employed by ngc-learn -- might be useful for constructing and simulating dynamics more accurately (often at the cost of additional computational time). ## Euler Integration -Euler integration is very simple (and fast) way of using the ordinary differential -equations you typically define for the cellular dynamics of various components -in ngc-learn (which typically get called in any component's `AdvanceState()` -command). +Euler integration is very simple (and fast) way of using the ordinary differential equations you typically define for the cellular dynamics of various components in ngc-learn (which typically get called in any component's `advance_state()` command). -While utilizing the numerical integrator will depend on your component's design -and the (biophysical) elements you wish to model, let's observe ngc-learn's -base backend utilities (its integration backend `ngclearn.utils.diffeq`) in -the context of numerically integrating a simple -differential equation; specifically the autonomous (linear) ordinary differential equation (ODE): -$\frac{\partial y(t)}{\partial t} = y(t)$. The analytic -solution to this equation is also simple -- it is $y(t) = e^{t}$. +While utilizing the numerical integrator will depend on your component's design and the (biophysical) elements you wish to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the context of numerically integrating a simple differential equation; specifically the autonomous (linear) ordinary differential equation (ODE): $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is also simple -- it is $y(t) = e^{t}$. -If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ -in a rather simple format[^1], you can write the following code to examine how -Euler integration approximates the analytical solution (in this example, we -examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`) +If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you can write the following code to examine how Euler integration approximates the analytical solution (in this example, we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`) ```python from jax import numpy as jnp, random, jit, nn @@ -89,41 +71,13 @@ which should yield you a plot like the one below: -Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's -Euler integrator and typically, when constructing your biophysical models, you -will need to think about this constant in the context of your simulation time-scale -and what you intend to model. Note that, in many biophysical component cells, -you will have an integration time constant of some form, i.e., a $\tau$, that you -can control, allowing you to fix your `dt` to your simulated time-scale -(say to a value like `dt = 1` millisecond) while tuning/altering your -time constant $\tau$ (since the differential equation will be weighted -by $\frac{\Delta t}{\tau}$). +Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's Euler integrator and typically, when constructing your biophysical models, you will need to think about this constant in the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells, you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix your `dt` to your simulated time-scale (say to a value like `dt = 1` millisecond) while tuning/altering your time constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$). ## Higher-Order Forms of (Explicit) Integration -Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond -the Euler method, such as a second order Runge-Kutta (RK-2) method (also known as -the midpoint method) and 4th-order Runge-Kutta (RK-4) method or an error-predictor method such as Heun's method -(also known as the trapezoid method). These forms of integration might be useful particularly -if a cell or plastic synaptic component you might be writing follows dynamics -that are more nonlinear or biophysically complex (requiring a higher degree -of simulation accuracy). For instance, ngc-learn's in-built cell components, -particularly those of higher biophysical complexity -- like the -[Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the -[FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- -contain argument flags for switching their simulation steps to use RK-2. - -To illustrate the value of higher-order numerical integration methods, let us -examine a simple polynomial equation (thus nonlinear) that is further -non-autonomous, i.e., it is a function of the time variable $t$ itself. A -possible set of dynamics in this case might be: -$\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which -has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ ( -where we will set $C = 1$). You can write code like below, importing from -`ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), -the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare -how these methods approximate the nonlinear dynamics inherent to our -constructed $\frac{\partial y(t)}{\partial t}$ ODE below: +Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond the Euler method, such as a second order Runge-Kutta (RK-2) method (also known as the midpoint method) and 4th-order Runge-Kutta (RK-4) method or an error-predictor method such as Heun's method (also known as the trapezoid method). These forms of integration might be useful particularly if a cell or plastic synaptic component you might be writing follows dynamics that are more nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2. + +To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial equation (thus nonlinear) that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code like below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE below: ```python from jax import numpy as jnp, random, jit, nn @@ -194,12 +148,7 @@ which should yield you a plot like the one below: -As you might observe, RK-4 give the best approximation of the solution. In addition, -when the integration step size is held constant, Euler integration -does quite poorly over just a few steps while RK-2 and Heun's method do much better -at approximating the analytical equation. In the end, the type of numerical integration method employed can -matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy -for more nonlinear dynamics like in our example above. +As you might observe, RK-4 give the best approximation of the solution. In addition, when the integration step size is held constant, Euler integration does quite poorly over just a few steps while RK-2 and Heun's method do much better at approximating the analytical equation. In the end, the type of numerical integration method employed can matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like in our example above. [^1]: The format expected by ngc-learn's backend is that the differential equation provides a functional API/form like so: for instance `dy/dt = diff_eqn(t, y(t), params)`, diff --git a/docs/tutorials/neurocog/izhikevich_cell.md b/docs/tutorials/neurocog/izhikevich_cell.md index 6d1449a6..f6b20f60 100644 --- a/docs/tutorials/neurocog/izhikevich_cell.md +++ b/docs/tutorials/neurocog/izhikevich_cell.md @@ -19,8 +19,7 @@ single component system made up of the Izhikevich (`IZH`) cell. from jax import numpy as jnp, random, jit import numpy as np -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell @@ -44,38 +43,19 @@ with Context("Model") as model: integration_type="euler", v0=v0, w0=w0, key=subkeys[0]) ## create and compile core simulation commands - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` -The Izhikevich `IZH`, much like the FitzHugh–Nagumo cell covered in -[a different lesson](../neurocog/fitzhugh_nagumo_cell.md), is a two-dimensional -differential equation system (developed in [1]) that attempts to (approximately) -model spiking cellular activation and deactivation dynamics. Notably, the `IZH` -cell models membrane potential `v` (using a squared term) jointly with a -recovery variable `w` (which is meant to provide a slower form of negative feedback). -In his model, Izhikevich introduced four important control factors/coefficients, -the choices of values for each changes the behavior of the neuronal model and -thus recovering dynamics of different classes of neurons found in the brain. -Several of these control factors have been renamed and/or mapped to more -explicit descriptors in ngc-learn (for example, Izhikevich's original factor -`a` has been mapped to `a = 1/tau_w` allowing the user to define the time -constant for the recovery variable much in the same manner as the -FitzHugh–Nagumo cell). Also like the FitzHugh–Nagumo cell, the Izhikevich model -contains configurable initial conditions for its voltage (i.e., `v0`) and -recovery values (i.e., `w0`), which we see have been set to interesting values -for the purposes of this lesson (these are actually the default values of -the Izhikevich component, i.e., `v0=-65` and `w0=-14`). +The Izhikevich `IZH`, much like the FitzHugh–Nagumo cell covered in [a different lesson](../neurocog/fitzhugh_nagumo_cell.md), is a two-dimensional differential equation system (developed in [1]) that attempts to (approximately) model spiking cellular activation and deactivation dynamics. Notably, the `IZH` cell models membrane potential `v` (using a squared term) jointly with a recovery variable `w` (which is meant to provide a slower form of negative feedback). +In his model, Izhikevich introduced four important control factors/coefficients, the choices of values for each will change the behavior of the neuronal model and thus recovering dynamics of different classes of neurons found in the brain. Several of these control factors have been renamed and/or mapped to more explicit descriptors in ngc-learn (for example, Izhikevich's original factor `a` has been mapped to `a = 1/tau_w` allowing the user to define the time constant for the recovery variable much in the same manner as the FitzHugh–Nagumo cell). Also like the FitzHugh–Nagumo cell, the Izhikevich model contains configurable initial conditions for its voltage (i.e., `v0`) and recovery values (i.e., `w0`), which we see have been set to interesting values for the purposes of this lesson (these are actually the default values of the Izhikevich component, i.e., `v0=-65` and `w0=-14`). Formally, the core dynamics of the `IZH` can be written out as follows: @@ -84,22 +64,12 @@ $$ \tau_w \frac{\partial \mathbf{w}_t}{\partial t} &= b \mathbf{v}_t - \mathbf{w}_t $$ -where $b$ is the coupling factor, $R$ is the membrane resistance, $\tau_m$ is the -membrane time constant, and $\tau_w$ is the recovery time constant (technically, -$\tau_m = 1$, $R = 1$, and $\tau_w = 1/a$ to get to the perspective originally -put forth in [1]). +where $b$ is the coupling factor, $R$ is the membrane resistance, $\tau_m$ is the membrane time constant, and $\tau_w$ is the recovery time constant (technically, $\tau_m = 1$, $R = 1$, and $\tau_w = 1/a$ to get to the perspective originally put forth in [1]). ### Simulating a Izhikevich Neuronal Cell -Given the single-cell dynamical system we set up above, we finally write -some code that uses and visualizes the flow of the `IZH` cell's membrane -potential and coupled recovery variable (specifically over a period of about -`200` milliseconds). We will, much as we did with the leaky integrators in -prior tutorials, inject an electrical current `j` into the `IZH` cell -- this -time with a constant current value of `10` amperes -- and observe how the cell -produces action potentials. -Specifically, we can plot the `IZH` neuron's voltage `v` and recovery variable `w` -in the following manner: +Given the single-cell dynamical system we set up above, we finally write some code that uses and visualizes the flow of the `IZH` cell's membrane potential and coupled recovery variable (specifically over a period of about `200` milliseconds). We will, much as we did with the leaky integrators in prior tutorials, inject an electrical current `j` into the `IZH` cell -- this time with a constant current value of `10` amperes -- and observe how the cell produces action potentials. +Specifically, we can plot the `IZH` neuron's voltage `v` and recovery variable `w` in the following manner: ```python curr_in = [] @@ -114,19 +84,19 @@ i_app = 10. # 0.23 ## electrical current to inject into F-N cell data = jnp.asarray([[i_app]], dtype=jnp.float32) time_span = [] -model.reset() +reset_process.run() t = 0. for ts in range(T): x_t = data ## pass in t and dt and run step forward of simulation - model.clamp(x_t) - model.advance(t=t, dt=dt) + clamp(x_t) + advance_process.run(t=t, dt=dt) t = t + dt ## naively extract simple statistics at time ts and print them to I/O - v = cell.v.value - w = cell.w.value - s = cell.s.value + v = cell.v.get() + w = cell.w.get() + s = cell.s.get() curr_in.append(data) mem_rec.append(v) recov_rec.append(w) @@ -164,22 +134,12 @@ plt.tight_layout() plt.savefig("{0}".format("izhcell_plot.jpg")) ``` -You should get a plot that depicts the evolution of the voltage and recovery of -the Izhikevich cell, i.e., saved as `izhcell_plot.jpg` locally to disk, much -like the one below: +You should get a plot that depicts the evolution of the voltage and recovery of the Izhikevich cell, i.e., saved as `izhcell_plot.jpg` locally to disk, much like the one below: -The plot above, which you can modify slightly yourself to include the neuronal -type tag "RS" like we do, actually depicts the dynamics for a specific type of spiking -neuron called the "regular spiking" (RS) neuron (also the default configuration -for ngc-learn's neuronal cell implementation), which is only one of several -kinds of neurons you can emulate with Izhikevich's dynamics implemented in -ngc-learn. Try modifying the exposed Izhikevich cell hyper-parameters above -and setting them to particular values (such as those noted in the -component's documentation) to recreate other possible neuron types. For -example, to obtain a "fast spiking" (FS) neuronal cell, all you would need to -do is modify the recovery variable's time constant like so: +The plot above, which you can modify slightly yourself to include the neuronal type tag "RS" like we do, actually depicts the dynamics for a specific type of spiking neuron called the "regular spiking" (RS) neuron (also the default configuration for ngc-learn's neuronal cell implementation), which is only one of several kinds of neurons you can emulate with Izhikevich's dynamics implemented in +ngc-learn. Try modifying the exposed Izhikevich cell hyper-parameters above and setting them to particular values (such as those noted in the component's documentation) to recreate other possible neuron types. For example, to obtain a "fast spiking" (FS) neuronal cell, all you would need to do is modify the recovery variable's time constant like so: ```python ## FS cell configuration values @@ -189,15 +149,11 @@ w_reset = 8. ## ngc-learn default coupling_factor = 0.2 ## ngc-learn default ``` -to obtain a voltage/recovery dynamics plot like so (if you also modify the -plot title of the plotting code accordingly): +to obtain a voltage/recovery dynamics plot like so (if you also modify the plot title of the plotting code accordingly): -Three other well-known classes of neural behaviors are possible to easily simulate -under the following hyper-parameter configurations (which produce the array -of three plots similar to those shown near the bottom of this lesson), -by simplifying modifying hyper-parameters according to the following: +Three other well-known classes of neural behaviors are possible to easily simulate under the following hyper-parameter configurations (which produce the array of three plots similar to those shown near the bottom of this lesson), by simplifying modifying hyper-parameters according to the following: 1. Chattering (CH) neurons: ```python @@ -222,8 +178,7 @@ w_reset = 2. coupling_factor = 0.25 ``` -The above three hyper-parameter settings produce, from top-to-bottom, the -plots shown below (from left-to-right): +The above three hyper-parameter settings produce, from top-to-bottom, the plots shown below (from left-to-right): ```{eval-rst} @@ -237,27 +192,6 @@ plots shown below (from left-to-right): +-------------------------------------------------------+-------------------------------------------------------+--------------------------------------------------------+ ``` -## Optional: Setting Up The Components with a JSON Configuration - -While you are not required to create a JSON configuration file for ngc-learn, -to get rid of the warning that ngc-learn will throw at the start of your -program's execution (indicating that you do not have a configuration set up yet), -all you need to do is create a sub-directory for your JSON configuration -inside of your project code's directory, i.e., `json_files/modules.json`. -Inside the JSON file, you would write the following: - -```json -[ - {"absolute_path": "ngclearn.components", - "attributes": [ - {"name": "IzhikevichCell"}] - }, - {"absolute_path": "ngcsimlib.operations", - "attributes": [ - {"name": "overwrite"}] - } -] -``` ## References From 13cfbd476cbf26be60698f00ca760ed037f161e6 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Mon, 17 Nov 2025 20:33:43 -0500 Subject: [PATCH 066/121] revised izh-cell, cleaned-up fn-cell, and revised lif neurocog docs --- .../neurocog/fitzhugh_nagumo_cell.md | 5 +- docs/tutorials/neurocog/izhikevich_cell.md | 12 ++- docs/tutorials/neurocog/lif.md | 97 ++++++------------- 3 files changed, 37 insertions(+), 77 deletions(-) diff --git a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md index 832b6ab1..5faed24b 100644 --- a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md +++ b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md @@ -35,8 +35,9 @@ w0 = -0.16983366 ## initial recovery value (for reset condition) ## create simple system with only one F-N cell with Context("Model") as model: - cell = FitzhughNagumoCell("z0", n_units=1, tau_w=tau_w, alpha=alpha, beta=beta, - gamma=gamma, v0=v0, w0=w0, integration_type="euler") + cell = FitzhughNagumoCell( + "z0", n_units=1, tau_w=tau_w, alpha=alpha, beta=beta, gamma=gamma, v0=v0, w0=w0, integration_type="euler" + ) ## create and compile core simulation commands advance_process = (MethodProcess("advance") diff --git a/docs/tutorials/neurocog/izhikevich_cell.md b/docs/tutorials/neurocog/izhikevich_cell.md index f6b20f60..bdbdc742 100644 --- a/docs/tutorials/neurocog/izhikevich_cell.md +++ b/docs/tutorials/neurocog/izhikevich_cell.md @@ -38,9 +38,10 @@ coupling_factor = 0.2 ## create simple system with only one Izh Cell with Context("Model") as model: - cell = IzhikevichCell("z0", n_units=1, tau_w=tau_w, v_reset=v_reset, - w_reset=w_reset, coupling_factor=coupling_factor, - integration_type="euler", v0=v0, w0=w0, key=subkeys[0]) + cell = IzhikevichCell( + "z0", n_units=1, tau_w=tau_w, v_reset=v_reset, w_reset=w_reset, coupling_factor=coupling_factor, + integration_type="euler", v0=v0, w0=w0, key=subkeys[0] + ) ## create and compile core simulation commands advance_process = (MethodProcess("advance") @@ -123,8 +124,9 @@ n_plots = 1 fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5)) ax_ptr = ax -ax_ptr.set(xlabel='Time', ylabel='Voltage (v), Recovery (w)', - title="Izhikevich (RS) Voltage/Recovery Dynamics") +ax_ptr.set( + xlabel='Time', ylabel='Voltage (v), Recovery (w)', title=f"Izhikevich ({cell_tag}) Voltage/Recovery Dynamics" +) v = ax_ptr.plot(time_span, mem_rec, color='C0') w = ax_ptr.plot(time_span, recov_rec, color='C1', alpha=.5) diff --git a/docs/tutorials/neurocog/lif.md b/docs/tutorials/neurocog/lif.md index 48485da9..82a08030 100755 --- a/docs/tutorials/neurocog/lif.md +++ b/docs/tutorials/neurocog/lif.md @@ -1,31 +1,15 @@ # Lecture 2B: The Leaky Integrate-and-Fire Cell -The leaky integrate-and-fire (LIF) cell component in ngc-learn is a stepping -stone towards working with more biophysical intricate cell components when crafting -your neuronal circuit models. This -[cell](ngclearn.components.neurons.spiking.LIFCell) is markedly different from the -[simplified LIF](ngclearn.components.neurons.spiking.sLIFCell) in both its -implemented dynamics as well as what modeling routines that it offers, including -the fact that it does not offer implicit fixed lateral inhibition like the -`SLIF` does (one would need to explicitly model the lateral inhibition as a -separate population of `LIF` cells, as we do in the -[Diehl and Cook model museum spiking network](../../museum/snn_dc.md)). Furthermore, -using this neuronal cell is a useful transition to using the more complicated and -biophysically more accurate neuronal models such as the -[adaptive exponential integrator cell](ngclearn.components.neurons.spiking.adExCell) -or the -[Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell). +The leaky integrate-and-fire (LIF) cell component in ngc-learn is a stepping stone towards working with more biophysical intricate cell components when crafting your neuronal circuit models. This [cell](ngclearn.components.neurons.spiking.LIFCell) is markedly different from the [simplified LIF](ngclearn.components.neurons.spiking.sLIFCell) in both its implemented dynamics as well as what modeling routines that it offers, including the fact that it does not offer implicit fixed lateral inhibition like the `SLIF` does (one would need to explicitly model the lateral inhibition as a separate population of `LIF` cells, as we do in the [Diehl and Cook model museum spiking network](../../museum/snn_dc.md)). Furthermore, using this neuronal cell is a useful transition to using the more complicated and biophysically more accurate neuronal models such as the [adaptive exponential integrator cell](ngclearn.components.neurons.spiking.adExCell) or the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell). ## Instantiating the LIF Neuronal Cell -To implement a single-component dynamical system made up of a single LIF -cell, you would write code akin to the following: +To implement a single-component dynamical system made up of a single LIF cell, you would write code akin to the following: ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.LIFCell import LIFCell from ngclearn.utils.viz.spike_plot import plot_spiking_neuron @@ -42,35 +26,27 @@ tau_m = 100. ## create simple system with only one AdEx with Context("Model") as model: - cell = LIFCell("z0", n_units=1, tau_m=tau_m, resist_m=tau_m/dt, thr=V_thr, - v_rest=V_rest, v_reset=-60., tau_theta=300., theta_plus=0.05, - refract_time=2., key=subkeys[0]) + cell = LIFCell( + "z0", n_units=1, tau_m=tau_m, resist_m=tau_m/dt, thr=V_thr, v_rest=V_rest, v_reset=-60., tau_theta=300., + theta_plus=0.05, refract_time=2., key=subkeys[0] + ) ## create and compile core simulation commands - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` ## Simulating the LIF on Stepped Constant Electrical Current -Given our single-LIF dynamical system above, let us write some code to use -our `LIF` node and visualize the resultant spiking pattern super-imposed -over its membrane (voltage) potential by feeding -into it a step current, where the electrical current `j` starts at $0$ then -switches to $0.3$ at $t = 10$ ms (much as we did for the `SLIF` component -in the previous lesson). We craft the simulation portion of our code like so: - +Given our single-LIF dynamical system above, let us write some code to use our `LIF` node and visualize the resultant spiking pattern super-imposed over its membrane (voltage) potential by feeding into it a step current, where the electrical current `j` starts at $0$ then switches to $0.3$ at $t = 10$ ms (much as we did for the `SLIF` component in the previous lesson). We craft the simulation portion of our code like so: ```python # create a synthetic electrical step current @@ -80,14 +56,14 @@ curr_in = [] mem_rec = [] spk_rec = [] -model.reset() +reset_process.run() for ts in range(current.shape[1]): j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts - model.clamp(j_t) - model.advance(t=ts*1., dt=dt) + clamp(j_t) + advance_process.run(t=ts*1., dt=dt) ## naively extract simple statistics at time ts and print them to I/O - v = cell.v.value - s = cell.s.value + v = cell.v.get() + s = cell.s.get() curr_in.append(j_t) mem_rec.append(v) spk_rec.append(s) @@ -95,43 +71,24 @@ for ts in range(current.shape[1]): print() ``` -Then, we can plot the input current, the neuron's voltage `v`, and its output -spikes as follows: +Then, we can plot the input current, the neuron's voltage `v`, and its output spikes as follows: ```python import numpy as np curr_in = np.squeeze(np.asarray(curr_in)) mem_rec = np.squeeze(np.asarray(mem_rec)) spk_rec = np.squeeze(np.asarray(spk_rec)) -plot_spiking_neuron(curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=V_rest-1., - max_mem_val=V_thr+2., spike_loc=V_thr, spike_spr=0.5, title="LIF-Node: Constant Electrical Input", fname="lif_plot.jpg") +plot_spiking_neuron( + curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=V_rest-1., max_mem_val=V_thr+2., spike_loc=V_thr, + spike_spr=0.5, title="LIF-Node: Constant Electrical Input", fname="lif_plot.jpg" +) ``` which should produce the following plot (saved to disk): -As we might observe, the LIF operates very differently from the SLIF, notably -that its dynamics live in the different space of values (one aspect of the -SLIF is that its dynamics are effectively normalized/configured to live -a non-negative membrane potential number space), specifically values that -are a bit better aligned with those observed in experimental neuroscience. -While more biophysically more accurate, the `LIF` typically involves consideration -of multiple additional hyper-parameters/simulation coefficients, including -the resting membrane potential value `v_rest` and the reset membrane value -`v_reset` (upon occurrence of a spike/emitted action potential); the `SLIF`, -in contrast, assumed a `v_reset = v_reset = 0.`. Note that the `LIF`'s -`tau_theta` and `theta_plus` coefficients govern its particular adaptive threshold, -which is a particular increment variable (one per cell in the `LIF` component) -that gets adjusted according to its own dynamics and added to the fixed constant -threshold `thr`, i.e., the threshold that a cell's membrane potential must -exceed for a spike to be emitted. - -The `LIF` cell component is particularly useful when more flexibility is required/ -desired in setting up neuronal dynamics, particularly when attempting to match -various mathematical models that have been proposed in computational neuroscience. -This benefit comes at the greater cost of additional tuning and experimental planning, -whereas the `SLIF` can be a useful go-to initial spiking cell for building certain spiking -models such as those proposed in machine intelligence research (we demonstrate -one such use-case in the context of the -[feedback alignment-trained spiking network](../../museum/snn_bfa.md) that we offer in the model museum). +As we might observe, the LIF operates very differently from the SLIF, notably that its dynamics live in the different space of values (one aspect of the SLIF is that its dynamics are effectively normalized/configured to live a non-negative membrane potential number space), specifically values that are a bit better aligned with those observed in experimental neuroscience. While more biophysically more accurate, the `LIF` typically involves consideration of multiple additional hyper-parameters/simulation coefficients, including +the resting membrane potential value `v_rest` and the reset membrane value `v_reset` (upon occurrence of a spike/emitted action potential); the `SLIF`, in contrast, assumed a `v_reset = v_reset = 0.`. Note that the `LIF`'s `tau_theta` and `theta_plus` coefficients govern its particular adaptive threshold, which is a particular increment variable (one per cell in the `LIF` component) that gets adjusted according to its own dynamics and added to the fixed constant threshold `thr`, i.e., the threshold that a cell's membrane potential must exceed for a spike to be emitted. + +The `LIF` cell component is particularly useful when more flexibility is required/desired in setting up neuronal dynamics, particularly when attempting to match various mathematical models that have been proposed in computational neuroscience. This benefit comes at the greater cost of additional tuning and experimental planning, whereas the `SLIF` can be a useful go-to initial spiking cell for building certain spiking models such as those proposed in machine intelligence research (we demonstrate one such use-case in the context of the [feedback alignment-trained spiking network](../../museum/snn_bfa.md) that we offer in the model museum). From f52c86aa4520963c4937f5d4faeaf84d99ede97d Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 18 Nov 2025 12:26:08 -0500 Subject: [PATCH 067/121] revised metrics/plotting neurocog docs --- docs/tutorials/neurocog/metrics.md | 42 +++++++---------------------- docs/tutorials/neurocog/plotting.md | 24 ++++------------- 2 files changed, 14 insertions(+), 52 deletions(-) diff --git a/docs/tutorials/neurocog/metrics.md b/docs/tutorials/neurocog/metrics.md index aea77da6..e872e23c 100644 --- a/docs/tutorials/neurocog/metrics.md +++ b/docs/tutorials/neurocog/metrics.md @@ -1,26 +1,11 @@ # Metrics and Measurement Functions -Inside of `ngclearn.utils.metric_utils`, ngc-learn offers metrics and measurement -utility functions that can be quite useful when building neurocognitive models using -ngc-learn's node-and-cables system for specific tasks. While this utilities -sub-module will not always contain every possible function you might need, -given that measurements are often dependent on the task the experimenter wants -to conduct, there are several commonly-used ones drawn from machine intelligence -and computational neuroscience that are (jit-i-fied) in-built to ngc-learn you -can readily use. -In this small lesson, we will briefly examine two examples of importing such -functions and examine what they do. +Inside of `ngclearn.utils.metric_utils`, ngc-learn offers metrics and measurement utility functions that can be quite useful when building neurocognitive models using ngc-learn's node-and-cables system for specific tasks. While this utilities sub-module will not always contain every possible function you might need, given that measurements are often dependent on the task the experimenter wants to conduct, there are several commonly-used ones drawn from machine intelligence and computational neuroscience that are (jit-i-fied) in-built to ngc-learn you can readily use. +In this small lesson, we will briefly examine two examples of importing such functions and examine what they do. ## Measuring Task-Level Quantities -For many tasks that you might be interested in, a useful measurement -is the performance of the model in some supervised learning context. For example, -you might want to measure a model's accuracy on a classification task. To do so, -assuming we have some model outputs extracted from a model that you have constructed -elsewhere -- say a matrix of scores `Y_scores` -- and a target set of predictions -that you are testing against -- such as `Y_labels` (in one-hot binary encoded form ) --- then you can write some code to compute the accuracy, mean squared error (MSE), -and categorical log likelihood (Cat-NLL), like so: +For many tasks that you might be interested in, a useful measurement is the performance of the model in some supervised learning context. For example, you might want to measure a model's accuracy on a classification task. To do so, assuming we have some model outputs extracted from a model that you have constructed elsewhere -- say a matrix of scores `Y_scores` -- and a target set of predictions that you are testing against -- such as `Y_labels` (in one-hot binary encoded form ) -- then you can write some code to compute the accuracy, mean squared error (MSE), and categorical log likelihood (Cat-NLL), like so: ```python from jax import numpy as jnp @@ -55,24 +40,18 @@ and you should obtain the following in I/O like so: > Cat-NLL = 4.003 ``` -Notice that we imported the utility function `softmax` from -`ngclearn.utils.model_utils` to convert our raw theoretical model scores to -probability values so that using `measure_CatNLL()` makes sense (as this -assumes the model scores are normalized probability values). +Notice that we imported the utility function `softmax` from `ngclearn.utils.model_utils` to convert our raw theoretical model scores to +probability values so that using `measure_CatNLL()` makes sense (as this assumes the model scores are normalized probability values). ## Measuring Some Model Statistics -In some cases, you might be interested in measuring certain statistics -related to aspects of a model that you construct. For example, you might have -collected a (binary) spike train produced by one of the internal neuronal layers -of your ngc-learn-simulated spiking neural network and want to compute the -firing rates and Fano factors associated with each neuron. Doing so with -ngc-learn utility functions would entail writing something like: +In some cases, you might be interested in measuring certain statistics related to properties of a model that you construct. For example, you might have collected a (binary) spike train produced by one of the internal neuronal layers of your ngc-learn-simulated spiking neural network and want to compute the firing rates and Fano factors associated with each neuron. Doing so with ngc-learn utility functions would entail writing something like: ```python from jax import numpy as jnp from ngclearn.utils.metric_utils import measure_fanoFactor, measure_firingRate +## let's create a fake synthetic spike train for 3 neurons (one per column) spikes = jnp.asarray([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], @@ -92,6 +71,7 @@ spikes = jnp.asarray([[0., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=jnp.float32) +## measure the firing rates and Fano factors of the 3 neurons fr = measure_firingRate(spikes, preserve_batch=True) fano = measure_fanoFactor(spikes, preserve_batch=True) @@ -106,8 +86,4 @@ which should result in the following to be printed to I/O: > Fano Factor = [[0.8888888 0.77777773 0.55555546]] ``` -The Fano factor is a useful secondary statistic for characterizing the -variable of a neuronal spike train -- as we see in the measurement above, -the first and second neurons have a higher Fano factor (given they are -more irregular in their spiking patterns) whereas the third neuron is far more -regular in its spiking pattern and thus has a lower Fano factor. +The Fano factor is a useful secondary statistic for characterizing the variability of a neuronal spike train -- as we see in the measurement above, the first and second neurons have a higher Fano factor (given they are more irregular in their spiking patterns) whereas the third neuron is far more regular in its spiking pattern and thus has a lower Fano factor. diff --git a/docs/tutorials/neurocog/plotting.md b/docs/tutorials/neurocog/plotting.md index b105b06f..f48845b2 100644 --- a/docs/tutorials/neurocog/plotting.md +++ b/docs/tutorials/neurocog/plotting.md @@ -1,24 +1,12 @@ # Plotting and Visualization -While writing one's own custom task-specific matplotlib visualization code -might be needed for specific experimental setups, there are several useful tools -already in-built to ngc-learn, organized under the package sub-directory -`ngclearn.utils.viz`, including utilities for generating raster plots and -synaptic receptive field views (useful for biophysical models such as spiking -neural networks) as well as t-SNE plots of model latent codes. While the other -lesson/tutorials demonstrate some of these useful routines (e.g., raster plots -for spiking neuronal cells), in this small lesson, we will demonstrate how to -produce a t-SNE plot using ngc-learn's in-built tool. +While writing one's own custom task-specific matplotlib visualization code might be needed for specific experimental setups, there are several useful tools already in-built to ngc-learn, organized under the package sub-directory `ngclearn.utils.viz`, including utilities for generating raster plots and synaptic receptive field views (useful for biophysical models such as spiking neural networks) as well as t-SNE plots of model latent codes. While the other lesson/tutorials demonstrate some of these useful routines (e.g., raster plots for spiking neuronal cells), in this small lesson, we will demonstrate how to produce a t-SNE plot using ngc-learn's in-built tool. ## Generating a t-SNE Plot -Let's say you have a labeled five-dimensional (5D) dataset -- which we will -synthesize artificially in this lesson from an "unobserved" trio of multivariate -Gaussians -- and wanted to visualize these "model outputs" and their -corresponding labels in 2D via ngc-learn's in-built t-SNE. +Let's say you have a labeled five-dimensional (5D) dataset -- which we will artificially synthesize in this lesson from an "unobserved" trio of multivariate Gaussians -- and that you wanted to visualize these "model outputs" and their corresponding labels in 2D via ngc-learn's in-built t-SNE. -The following bit of Python code will do this for you (including the artificial -data generator): +The following bit of Python code will do this for you (including setting up the data generator): ```python from jax import numpy as jnp, random @@ -26,7 +14,7 @@ from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents dkey = random.PRNGKey(1234) -def gen_data(dkey, N): ## artificial data generator (or proxy model) +def gen_data(dkey, N): ## data generator (or proxy stochastic data generating process) mu1 = jnp.asarray([[2.1, 3.2, 0.6, -4., -2.]]) cov1 = jnp.eye(5) * 0.78 mu2 = jnp.asarray([[-1.8, 0.2, -0.1, 1.99, 1.56]]) @@ -59,6 +47,4 @@ which should produce a plot, i.e., `codes.jpg`, similar to the one below: -In this example scenario, we see that we can successfully map the 5D model output -data to a plottable 2D space, facilitating some level of downstream qualitative -interpretation of the model. +In this example scenario, we see that we can successfully map the 5D model output data to a plottable 2D space, facilitating some level of downstream qualitative interpretation of the model. From 276cb8906e8e0070b1d2a3f36bc3ab13604ba553 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Tue, 18 Nov 2025 17:39:04 -0500 Subject: [PATCH 068/121] revised mod/reward-stdp neurocog doc --- docs/tutorials/neurocog/mod_stdp.md | 272 +++++++++------------------- 1 file changed, 84 insertions(+), 188 deletions(-) diff --git a/docs/tutorials/neurocog/mod_stdp.md b/docs/tutorials/neurocog/mod_stdp.md index 3a76de37..18809649 100755 --- a/docs/tutorials/neurocog/mod_stdp.md +++ b/docs/tutorials/neurocog/mod_stdp.md @@ -1,52 +1,28 @@ # Lecture 4D: Reward-Modulated Spike-Timing-Dependent Plasticity -In this lesson, we will build on the notions of spike-timing-dependent -plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct -an important form of biological credit assignment in spiking neural networks -known as reward-modulated STDP (sometimes abbreviated to R-STDP). Specifically, -we will simulate and plot the underlying plasticity dynamics associated with -this form of change in synaptic efficacy, specifically studying two in-built -schemes of STDP: modulated STDP (MSTDP) and modulated STDP with eligibility -traces (MSTDP-ET). +In this lesson, we will build on the notions of spike-timing-dependent plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct an important form of biological credit assignment in spiking neural networks known as reward-modulated STDP (sometimes abbreviated to R-STDP). Specifically, we will simulate and plot the underlying plasticity dynamics associated with this form of change in synaptic efficacy, specifically studying two in-built schemes of STDP: modulated STDP (MSTDP) and modulated STDP with eligibility traces (MSTDP-ET). ## Probing Modulated STDP and Eligibility Traces -Go ahead and make a new folder for this study and create a Python script, -i.e., `run_reward_stdp.py`, to write your code for this part of the tutorial. +Go ahead and make a new folder for this study and create a Python script, i.e., `run_reward_stdp.py`, to write your code for this part of the tutorial. -Much as we did in the STDP lesson, we will build a 3-component dynamical system --- two spiking neurons (represented by traces) that are connected with a single -synapse -- but, this time, we will simulate three variations of this system in -parallel. Each one of these variants will evolve its single synapse according -to a different condition of STDP: +Much as we did in the STDP lesson, we will build a 3-component dynamical system -- two spiking neurons (represented by traces) that are connected with a single synapse -- but, this time, we will simulate three variations of this system in parallel. Each one of these variants will evolve its single synapse according to a different condition of STDP: 1. the first one will change its synapse's strength in accordance with trace-based STDP; 2. the second one will change its synapse's strength via modulated STDP (MSTDP); and, 3. the third and final one will change its synapse's strength via modulated STDP equipped with an eligibility trace (MSTDP-ET). -The second and third model above will make use of ngc-learn's in-built -[MSTDPETSynapse](ngclearn.components.synapses.modulated.MSTDPETSynapse), which -is an STDP cable component that sub-classes the `TraceSTDPSynapse` cable component -and will offer the additional machinery we will need to carry out modulated -forms of STDP. -All three of these variant STDP-evolved systems will make use of the same set -of variable traces (the `VarTrace` object introduced in the previous STDP lesson), -and we will control the spike trains by providing a specific set of pre-synaptic -spike times and a corresponding set of post-synaptic spike times ( both in -milliseconds). Furthermore, we will insert a convenience cell in-built -to ngc-learn called the `RewardErrorCell`, which is generally use to produce -what is known in neuroscience literature as "reward prediction error" (RPE). - -Writing the above three parallel single synapse systems, including meta-parameters -and the required compiled simulation and dynamic commands, can be done as follows: +The second and third model above will make use of ngc-learn's in-built [MSTDPETSynapse](ngclearn.components.synapses.modulated.MSTDPETSynapse), which is an STDP cable component that sub-classes the `TraceSTDPSynapse` cable component and will offer the additional machinery we will need to carry out modulated forms of STDP. +All three of these variant STDP-evolved systems will make use of the same set of variable traces (the `VarTrace` object introduced in the previous STDP lesson), and we will control the spike trains by providing a specific set of pre-synaptic spike times and a corresponding set of post-synaptic spike times (both in milliseconds). Furthermore, we will insert a convenience cell in-built to ngc-learn called the `RewardErrorCell`, which is generally use to produce what is known in neuroscience literature as "reward prediction error" (RPE). + +Writing the above three parallel single synapse systems, including meta-parameters and the required compiled simulation and dynamic commands, can be done as follows: ```python from jax import numpy as jnp, random, jit from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms -from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse, - RewardErrorCell, VarTrace) -import ngclearn.utils.weight_distribution as dist +from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse, RewardErrorCell, VarTrace) +from ngclearn.utils.distribution_generator import DistributionGenerator ## create seeding keys (JAX-style) dkey = random.PRNGKey(231) @@ -55,142 +31,95 @@ dkey, *subkeys = random.split(dkey, 2) dt = 1. # ms # integration time constant T_max = 200 ## number time steps to simulate tau_pre = tau_post = 20. # ms -tau_elg = 25. +tau_elg = 25. # ms ## eligibility trace time constant Aplus = Aminus = 1. ## in ngc-learn, Aplus/Aminus are magnitudes (signs are handled internally) gamma = 0.2 gamma_0 = 0.2/tau_elg with Context("Model") as model: W_stdp = TraceSTDPSynapse( ## reward-STDP (RSTDP) - "W1_stdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus, - weight_init=dist.constant(value=0.2), key=subkeys[0]) + "W1_stdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus, + weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0] + ) W_mstdp = MSTDPETSynapse( ## reward-STDP (RSTDP) - "W1_rstdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus, - tau_elg=0., weight_init=dist.constant(value=0.2), key=subkeys[0]) + "W1_rstdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus, tau_elg=0., + weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0] + ) W_mstdpet = MSTDPETSynapse( ## reward-STDP w/ eligibility traces - "W_mstdpet", shape=(1, 1), eta=gamma_0, A_plus=Aplus, A_minus=Aminus, - tau_elg=tau_elg, weight_init=dist.constant(value=0.2), key=subkeys[0]) + "W_mstdpet", shape=(1, 1), eta=gamma_0, A_plus=Aplus, A_minus=Aminus, tau_elg=tau_elg, + weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0] + ) ## set up pre- and -post synaptic trace variables tr0 = VarTrace("tr0", n_units=1, tau_tr=tau_pre, a_delta=Aplus) tr1 = VarTrace("tr1", n_units=1, tau_tr=tau_post, a_delta=Aminus) rpe = RewardErrorCell("r", n_units=1, alpha=0.) - evolve_process = (JaxProcess() + evolve_process = (MethodProcess("evolve") >> W_stdp.evolve >> W_mstdp.evolve >> W_mstdpet.evolve) - model.wrap_and_add_command(jit(evolve_process.pure), name="evolve") - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> tr0.advance_state >> tr1.advance_state >> rpe.advance_state >> W_stdp.advance_state >> W_mstdp.advance_state >> W_mstdpet.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> W_stdp.reset >> W_mstdp.reset >> W_mstdpet.reset >> rpe.reset >> tr0.reset - >> tr1.reset - ) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - - @Context.dynamicCommand - def clamp_spikes(f_j, f_i): - tr0.inputs.set(f_j) - tr1.inputs.set(f_i) - - @Context.dynamicCommand - def clamp_stdp_stats(f_j, f_i, trace_j, trace_i): - W_stdp.preSpike.set(f_j) - W_stdp.postSpike.set(f_i) - W_stdp.preTrace.set(trace_j) - W_stdp.postTrace.set(trace_i) - - @Context.dynamicCommand - def clamp_mstdp_stats(f_j, f_i, trace_j, trace_i, reward): - W_mstdp.preSpike.set(f_j) - W_mstdp.postSpike.set(f_i) - W_mstdp.preTrace.set(trace_j) - W_mstdp.postTrace.set(trace_i) - W_mstdp.modulator.set(reward) - - @Context.dynamicCommand - def clamp_mstdpet_stats(f_j, f_i, trace_j, trace_i, reward): - W_mstdpet.preSpike.set(f_j) - W_mstdpet.postSpike.set(f_i) - W_mstdpet.preTrace.set(trace_j) - W_mstdpet.postTrace.set(trace_i) - W_mstdpet.modulator.set(reward) + >> tr1.reset) + +## set up some utility functions for the model context +def clamp_spikes(f_j, f_i): + tr0.inputs.set(f_j) + tr1.inputs.set(f_i) + +def clamp_stdp_stats(f_j, f_i, trace_j, trace_i): + W_stdp.preSpike.set(f_j) + W_stdp.postSpike.set(f_i) + W_stdp.preTrace.set(trace_j) + W_stdp.postTrace.set(trace_i) + +def clamp_mstdp_stats(f_j, f_i, trace_j, trace_i, reward): + W_mstdp.preSpike.set(f_j) + W_mstdp.postSpike.set(f_i) + W_mstdp.preTrace.set(trace_j) + W_mstdp.postTrace.set(trace_i) + W_mstdp.modulator.set(reward) + +def clamp_mstdpet_stats(f_j, f_i, trace_j, trace_i, reward): + W_mstdpet.preSpike.set(f_j) + W_mstdpet.postSpike.set(f_i) + W_mstdpet.preTrace.set(trace_j) + W_mstdpet.postTrace.set(trace_i) + W_mstdpet.modulator.set(reward) ``` -Given our three parallel models constructed above, we ready to write some code -to use our simulation setup. Before we do, however, notice that we have -configured the simulation -time `T_max` to be `200` milliseconds (ms), the integration time constant -`dt` to be `1` ms, and the time constant for both our pre-synaptic and -post-synaptic spiking neuron traces to be `20` ms. Two final points to notice -about the models we have constructed above are: -1. the RPE cell `rpe` has been configured to only output a given clamped - reward signal via `alpha = 0`; this, according to the internal design of the - `RewardErrorCell`, just effectively shuts off the moving average prediciton - of reward signals that the cell encounters over time (in most practical - cases, you will not want this set to zero as we are often interested in - the difference between a reward prediction and a target reward value); -2. for the third model, the MSTDP-ET model, we have configured an eligibility - trace to be used by setting the eligibility time constant `tau_elg` to be - be non-zero, i.e., it was set to `25` ms. An eligibility trace, in the context - STDP/Hebbian synaptic updates, simply another set of dynamics (i.e., another - ordinary differential equation) that we maintain as STDP synaptic updates - are computed. - -With respect to the second point made about eligibility traces, formally, we note -that under MSTDP-ET, instead of computing a trace-based STDP update at -each and every single time step `t` and updating the synapses immediately, -we first aggregate each STDP into another variable (the eligibility) according -to the following ODE: +Given our three parallel models constructed above, we ready to write some code to use our simulation setup. Before we do, however, notice that we have configured the simulation time `T_max` to be `200` milliseconds (ms), the integration time constant `dt` to be `1` ms, and the time constant for both our pre-synaptic and post-synaptic spiking neuron traces to be `20` ms. Two final points to notice about the models that we have constructed above are: +1. the RPE cell `rpe` has been configured to only output a given clamped reward signal via `alpha = 0`; this, according to the internal design of the `RewardErrorCell`, just effectively shuts off the moving average prediction of reward signals that the cell encounters over time (in most practical cases, you will not want this set to zero as we are often interested in the difference between a reward prediction and a target reward value); +2. for the third model, the MSTDP-ET model, we have configured an eligibility trace to be used by setting the eligibility time constant `tau_elg` to be non-zero, i.e., it was set to `25` ms. An eligibility trace, in the context STDP/Hebbian synaptic updates, simply another set of dynamics (i.e., another ordinary differential equation) that we maintain as STDP synaptic updates are computed. + +With respect to the second point made above about eligibility traces, formally, we note that: under MSTDP-ET, instead of computing a trace-based STDP update at each and every single time step `t` and updating the synapses immediately, we first aggregate each STDP update into another variable (the "eligibility") according to the following ODE: $$ -\tau_{elg} \frac{\partial \mathbf{E}_{ij}}{\partial t} = -\mathbf{E}_{ij} + -\beta \frac{\partial \mathbf{W}_{ij}}{\partial t} +\tau_{elg} \frac{\partial \mathbf{E}_{ij}}{\partial t} = -\mathbf{E}_{ij} + \beta \frac{\partial \mathbf{W}_{ij}}{\partial t} $$ -where $i$ denotes the index of the post-synpatic spiking neuron (which emits -a spike we label as $f_i$) and $j$ denotes the index of the pre-synaptic -spiking neuron (which emits a spike we label as $f_j$), $\mathbf{W}_{ij}$ is -the synapse that connects neuron $j$ to $i$, $\mathbf{E}_{ij}$ is the eligibility -trace we maintain for synapse $\mathbf{W}_{ij}$, and $\beta$ is control factor -(typically set to one) for scaling the magnitude of the STDP update's effect. -Finally, note that $\frac{\partial \mathbf{W}_{ij}}{\partial t}$ is the actual -synaptic update produced by our trace-based STDP at time $t$. - -Given the idea of the eligibility trace above, and how our RPE cell has been -configured, we can write down simply what kind of synaptic update ] -$\Delta \mathbf{W}_{ij}(t)$ that each of -our three dynamical systems will yield once we simulate them. -1. Trace-based STDP will produce an update to the synapse according to the combined - products of a paired pre-synaptic trace and post-synaptic spike (long-term - potentiation) and a paired pre-synaptic spike and post-synaptic trace - (long-term depression), i.e, - $\Delta \mathbf{W}_{ij}(t) = \gamma \frac{\partial \mathbf{W}_{ij}}{\partial t}$; -2. MSTDP -- the second/middle model with the `MSTDPETSynapse` with its - `tau_elg = 0` -- will produce a modulated update to the synapse at each - time step as follows: - $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \frac{\partial \mathbf{W}_{ij}}{\partial t}$; -3. MSTDP-ET -- the third and final model that uses an eligiblity trace -- will - produce a modulated update to the synapse at each time step via: - $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \mathbf{E}_{ij}(t)$. -Note that $r(t)$ is the reward administered at each time step `t` and $\gamma$ -is just an additional dampening factor to control how much of the STDP update -is applied at each time step (i.e., a global learning rate). - -Armed with our knowledge of the plasticity dynamics above, we next write down -what we want our model simulations to do: +where $i$ denotes the index of the post-synaptic spiking neuron (which emits a spike we label as $f_i$) and $j$ denotes the index of the pre-synaptic spiking neuron (which emits a spike we label as $f_j$), $\mathbf{W}_{ij}$ is the synapse that connects neuron $j$ to $i$, $\mathbf{E}_{ij}$ is the eligibility trace we maintain for synapse $\mathbf{W}_{ij}$, and $\beta$ is control factor (typically set to one) for scaling the magnitude of the STDP update's effect. +Finally, note that $\frac{\partial \mathbf{W}_{ij}}{\partial t}$ is the actual synaptic update produced by our trace-based STDP at time $t$. + +Given the idea of the eligibility trace explained above, as well as how our RPE cell has been configured, we can write down simply what kind of synaptic update $\Delta \mathbf{W}_{ij}(t)$ that each of our three dynamical systems will yield once we simulate them. +1. Trace-based STDP will produce an update to the synapse according to the combined products of a paired pre-synaptic trace and post-synaptic spike (long-term potentiation) and a paired pre-synaptic spike and post-synaptic trace (long-term depression), i.e, $\Delta \mathbf{W}_{ij}(t) = \gamma \frac{\partial \mathbf{W}_{ij}}{\partial t}$; +2. MSTDP -- the second/middle model with the `MSTDPETSynapse` with its `tau_elg = 0` -- will produce a modulated update to the synapse at each time step as follows: $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \frac{\partial \mathbf{W}_{ij}}{\partial t}$; +3. MSTDP-ET -- the third and final model that uses an eligibility trace -- will produce a modulated update to the synapse at each time step via: $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \mathbf{E}_{ij}(t)$. Note that $r(t)$ is the reward administered at each time step `t` and $\gamma$ is just an additional dampening factor to control how much of the STDP update is applied at each time step (i.e., a global learning rate). + +Armed with our knowledge of the plasticity dynamics above, we next write down what we want our model simulations to do: ```python # synthetic spike times of pre and post synaptic neurons @@ -208,7 +137,7 @@ elg_vals = [] W_stdp_vals = [] W_mstdp_vals = [] W_mstdpet_vals = [] -model.reset() +reset_process.run() for i in range(T_max): f_j = jnp.zeros((1, 1)) ## pre-syn spike if (i * dt) in spike_times_pre: @@ -222,33 +151,29 @@ for i in range(T_max): reward = -reward rpe.reward.set(reward) - model.clamp_spikes(f_j, f_i) ## clamp pre/post spikes to traces - model.advance(t=i * dt, dt=dt) + clamp_spikes(f_j, f_i) ## clamp pre/post spikes to traces + advance_process.run(t=i * dt, dt=dt) - model.clamp_stdp_stats(f_j, f_i, tr0.trace.value, tr1.trace.value) - model.clamp_mstdp_stats( - f_j, f_i, tr0.trace.value, tr1.trace.value, rpe.reward.value) - model.clamp_mstdpet_stats( - f_j, f_i, tr0.trace.value, tr1.trace.value, rpe.reward.value) - model.evolve(t=i * dt, dt=dt) + clamp_stdp_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get()) + clamp_mstdp_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get(), rpe.reward.get()) + clamp_mstdpet_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get(), rpe.reward.get()) + evolve_process.run(t=i * dt, dt=dt) ## record statistics for plotting pre_spikes.append(jnp.squeeze(f_j)) post_spikes.append(jnp.squeeze(f_i)) r_vals.append(jnp.squeeze(reward)) - tr0_vals.append(jnp.squeeze(tr0.trace.value)) - tr1_vals.append(jnp.squeeze(-tr1.trace.value)) - dWstdp_vals.append(jnp.squeeze(W_stdp.dWeights.value)) - elg_vals.append(jnp.squeeze(W_mstdpet.eligibility.value)) - W_stdp_vals.append(jnp.squeeze(W_stdp.weights.value)) - W_mstdp_vals.append(jnp.squeeze(W_mstdp.weights.value)) - W_mstdpet_vals.append(jnp.squeeze(W_mstdpet.weights.value)) + tr0_vals.append(jnp.squeeze(tr0.trace.get())) + tr1_vals.append(jnp.squeeze(-tr1.trace.get())) + dWstdp_vals.append(jnp.squeeze(W_stdp.dWeights.get())) + elg_vals.append(jnp.squeeze(W_mstdpet.eligibility.get())) + W_stdp_vals.append(jnp.squeeze(W_stdp.weights.get())) + W_mstdp_vals.append(jnp.squeeze(W_mstdp.weights.get())) + W_mstdpet_vals.append(jnp.squeeze(W_mstdpet.weights.get())) t_vals.append(i * dt) ``` -which will run all three models simultaneously for `200` simulated milliseconds -and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode -(reproducing some key results in [1]. First, we will plot the resulting synaptic magnitude over time, like so: +which will run all three of our models simultaneously for `200` simulated milliseconds and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode (reproducing some key results in [1]. First, we will plot the resulting synaptic magnitude over time, like so: ```python import matplotlib.pyplot as plt @@ -282,23 +207,13 @@ ax3.grid() fig1.savefig("modstdp_syn_dynamics.jpg") ``` -which should produce a plot like the one below: +which should produce a plot like the one below: -Notice, first, that the middle plot for MSTDP (the red curve in the middle plot) -essentially mimics the update produced by STDP for the first `100` ms and then -flips (becomes a mirror image) of the STDP trajectory; this is due to the fact -that, as you can see in the code you wrote earlier for the spike train simulation, -the reward signal changes sign after `100` ms and since MSTDP is effectively -the product of the reward and an STDP synaptic update the sign of the synaptic -change will flip as well. Finally, notice that the MSTDP-ET yields a -smoothened change in synaptic efficacy (the blue curve in the bottom plot); -this is due to the eligibility trace leakily integrating the STDP updates -over time (and ultimately multiplying the trace by the reward at time `t`). +Notice, first, that the middle plot for MSTDP (the red curve in the middle plot) essentially mimics the update produced by STDP for the first `100` ms and then flips (becomes a mirror image) of the STDP trajectory; this is due to the fact that, as you can see in the code you wrote earlier for the spike train simulation, the reward signal changes sign after `100` ms and since MSTDP is effectively the product of the reward and an STDP synaptic update the sign of the synaptic change will flip as well. Finally, notice that the MSTDP-ET yields a smoothened change in synaptic efficacy (the blue curve in the bottom plot); this is due to the eligibility trace leakily integrating the STDP updates over time (and ultimately multiplying the trace by the reward at time `t`). -We will then plot the dynamics of important compartments the drive the operation -of the various STDP models with the following code block: +We will then plot the dynamics of important compartments the drive the operation of the various STDP models with the following code block: ```python ## create STDP synaptic dynamics plots (figure 1) @@ -352,27 +267,8 @@ which should yield the following component dynamics plot: -This plot usefully breaks down the plasticity dynamics of all three STDP models -into the core component dynamics. The top two plots illustrate the emissions -of spikes over time (the pre-synaptic spike plot followed by the post-synaptic -spike plot) while underneath these -- the third plot -- is a visualization -of these spikes respective traces multiplied by their corresponding sign -that is used in STDP, i.e., the blue pre-synaptic curve is positive as it -represents synaptic potentiation over time (pre occurs before post) while the -orange post-synaptic curve is negative as it represents synaptic depression -over time (post occurs after pre). The teal curve in the fourth plot -illustrates what kind of updates that typical trace-based STDP would produce, -in the absence of a reward signal, whereas the yellow-ish/goldenrod curve -underneath shows the eligibilty trace that smoothens out the more pulse-like -adjustments that STDP yields. In the very bottom plot, we see the red piecewise -function that characterizes our reward signal -- for the first `100` ms it -is simply one whereas for the last `200` ms it is negative one. In general, -one will not likely have access to a clean dense reward in most control -problems, i.e., the reward signal is typically sparse, which will mean that -modulated STDP updates will only occur when the signal is non-zero; this is -the advantage that MSTDP-ET offers over MSTDP as the synaptic change -dynamics persist (yet decay) in between reward presentation times and thus -MSTDP-ET will be more effective in cases when the reward signal is delayed. +This plot usefully breaks down the plasticity dynamics of all three STDP models into the core component dynamics. The top two plots illustrate the emissions of spikes over time (the pre-synaptic spike plot followed by the post-synaptic spike plot) while underneath these -- the third plot -- is a visualization of these spikes respective traces multiplied by their corresponding sign +that is used in STDP, i.e., the blue pre-synaptic curve is positive as it represents synaptic potentiation over time (pre occurs before post) while the orange post-synaptic curve is negative as it represents synaptic depression over time (post occurs after pre). The teal curve in the fourth plot illustrates what kind of updates that typical trace-based STDP would produce, in the absence of a reward signal, whereas the yellow-ish/goldenrod curve underneath shows the eligibility trace that smoothens out the more pulse-like adjustments that STDP yields. In the very bottom plot, we see the red piecewise function that characterizes our reward signal -- for the first `100` ms it is simply one whereas for the last `200` ms it is negative one. In general, one will not likely have access to a clean dense reward in most control problems, i.e., the reward signal is typically sparse, which will mean that modulated STDP updates will only occur when the signal is non-zero; this is the advantage that MSTDP-ET offers over MSTDP as the synaptic change dynamics persist (yet decay) in between reward presentation times and, thus, MSTDP-ET will be more effective in cases when the reward signal is delayed. ## References From ab6c716771176ba3d068d8b8984949035bfda07f Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 17:28:29 -0500 Subject: [PATCH 069/121] revised stp-syn neurocog doc and updated stp-syn to use proper initializer --- docs/tutorials/neurocog/mod_stdp.md | 2 +- docs/tutorials/neurocog/rate_cell.md | 97 ++++------ .../neurocog/short_term_plasticity.md | 171 +++++------------- .../components/synapses/STPDenseSynapse.py | 7 +- 4 files changed, 90 insertions(+), 187 deletions(-) diff --git a/docs/tutorials/neurocog/mod_stdp.md b/docs/tutorials/neurocog/mod_stdp.md index 18809649..f705f935 100755 --- a/docs/tutorials/neurocog/mod_stdp.md +++ b/docs/tutorials/neurocog/mod_stdp.md @@ -18,7 +18,7 @@ Writing the above three parallel single synapse systems, including meta-paramete ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context + from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse, RewardErrorCell, VarTrace) diff --git a/docs/tutorials/neurocog/rate_cell.md b/docs/tutorials/neurocog/rate_cell.md index 56afb789..f6554116 100644 --- a/docs/tutorials/neurocog/rate_cell.md +++ b/docs/tutorials/neurocog/rate_cell.md @@ -1,7 +1,6 @@ # Lecture 3A: The Rate Cell Model -Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks, various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g. electrical -current differential equations in spiking networks. +Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks, various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g. electrical current differential equations in spiking networks. In this tutorial, we will study one of ngc-learn's workhorse in-built graded cell components, the rate cell ([RateCell](ngclearn.components.neurons.graded.rateCell)). @@ -9,15 +8,12 @@ In this tutorial, we will study one of ngc-learn's workhorse in-built graded cel ### Instantiating the Rate Cell -Let's go ahead and set up the controller for this lesson's simulation, -where we will a dynamical system with only a single component, -specifically the rate-cell (RateCell). Let's start with the file's header -(or import statements): +Let's go ahead and set up the controller for this lesson's simulation, where we will a dynamical system with only a single component, specifically the rate-cell (RateCell). Let's start with the file's header (or import statements): ```python from jax import numpy as jnp, random, jit -from ngclearn.utils import JaxProcess -from ngcsimlib.context import Context + +from ngclearn import Context, MethodProcess ## import model-specific elements from ngclearn.components.neurons.graded.rateCell import RateCell ``` @@ -36,91 +32,67 @@ gamma = 1. with Context("Model") as model: ## model/simulation definition ## instantiate components (like cells) - cell = RateCell("z0", n_units=1, tau_m=tau_m, act_fx=act_fx, - prior=("gaussian", gamma), integration_type="euler", key=subkeys[0]) + cell = RateCell( + "z0", n_units=1, tau_m=tau_m, act_fx=act_fx, prior=("gaussian", gamma), integration_type="euler", + key=subkeys[0] + ) ## instantiate desired core commands that drive the simulation - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## instantiate some non-jitted dynamic utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## instantiate utility commands +def clamp(x): + cell.j.set(x) ``` -A notable argument to the rate-cell, beyond some of its differential equation -constants (`tau_m` and `gamma`), is its activation function choice (default is -the `identity`), which we have chosen to be a discrete pulse emitting function -known as the `unit_threshold` (which outputs a value of one for any input that -exceeds the threshold of one and zero for anything else). +A notable argument to the rate-cell, beyond some of its differential equation constants (`tau_m` and `gamma`), is its activation function choice (default is the `identity`), which we have chosen to be a discrete pulse emitting function known as the `unit_threshold` (which outputs a value of one for any input that exceeds the threshold of one and zero for anything else). -Mathematically, under the hood, a rate-cell evolves according to the -ordinary differential equation (ODE): +Mathematically, under the hood, a rate-cell evolves according to the ordinary differential equation (ODE): $$ \tau_m \frac{\partial \mathbf{z}}{\partial t} = -\gamma \text{prior}\big(\mathbf{z}\big) + (\mathbf{x} + \mathbf{x}_{td}) $$ -where $\mathbf{x}$ is external input signal and $\mathbf{x}_{td}$ (default -value is zero) is an optional additional input pressure signal (`td` stands for "top-down", -its name motivated by predictive coding literature). -A good way to understand this equation is in the context of two examples: -1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the -total electrical input into the cell from multiple injections produced -by transmission across synapses ($\mathbf{x}_{td} = 0$)) and the $\text{prior}$ -is set to `gaussian` ($\gamma = 1$), yielding the equation -$\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for -a simple model of synaptic conductance, and -2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections -(or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$ -is set to be the sum of (top-down) pressures produced by an "upper" layer/group -such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In -this example, $0 \leq \gamma \leq 1$ and $\text{prior}$ could be set to one -of any kind of kurtotic distribution to induce a soft form of sparsity in -the dynamics, e.g., such as "cauchy" for the Cauchy distribution. +where $\mathbf{x}$ is external input signal and $\mathbf{x}_{td}$ (default value is zero) is an optional additional input pressure signal (`td` stands for "top-down", its name motivated by predictive coding literature). +A good way to understand this equation is in the context of two examples: +1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the total electrical input into the cell from multiple injections produced by transmission across synapses ($\mathbf{x}_{td} = 0$)) and the $\text{prior}$ is set to `gaussian` ($\gamma = 1$), yielding the equation $\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for a simple model of synaptic conductance, and +2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections (or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$ is set to be the sum of (top-down) pressures produced by an "upper" layer/group such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In this example, $0 \leq \gamma \leq 1$ and $\text{prior}$ could be set to one of any kind of kurtotic distribution to induce a soft form of sparsity in the dynamics, e.g., such as "cauchy" for the Cauchy distribution. ### Simulating a Rate Cell -Given our single rate-cell dynamical system above, let us write some code to use -our `Rate` node and visualize its dynamics by feeding -into it a pulse current (a piecewise input function that is an alternating -sequence of intervals of where nothing is input and others where a non-zero -value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically, -we can plot the input current, the neuron's linear rate activity `z` and its -nonlinear activity `phi(z)` as follows: +Given our single rate-cell dynamical system above, let us write some code to use our `Rate` node and visualize its dynamics by feeding into it a pulse current (a piecewise input function that is an alternating sequence of intervals of where nothing is input and others where a non-zero value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically, we can plot the input current, the neuron's linear rate activity `z` and its nonlinear activity `phi(z)` as follows: ```python # create a synthetic electrical pulse current -current = jnp.concatenate((jnp.zeros((1,10)), - jnp.ones((1,50)) * 1.006, - jnp.zeros((1,50)), - jnp.ones((1,50)) * 1.006, - jnp.zeros((1,50))), axis=1) +current = jnp.concatenate( + (jnp.zeros((1,10)), + jnp.ones((1,50)) * 1.006, + jnp.zeros((1,50)), + jnp.ones((1,50)) * 1.006, + jnp.zeros((1,50))), axis=1 +) lin_out = [] nonlin_out = [] t_values = [] -model.reset() +reset_process.run() t = 0. for ts in range(current.shape[1]): j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts - model.clamp(j_t) - model.advance(t=ts*1., dt=dt) + clamp(j_t) + advance_process.run(t=ts*1., dt=dt) t_values.append(t) - t += dt + t += dt ## advance time forward by dt milliseconds ## naively extract simple statistics at time ts and print them to I/O - linear_z = cell.z.value - nonlinear_z = cell.zF.value + linear_z = cell.z.get() + nonlinear_z = cell.zF.get() lin_out.append(linear_z) nonlin_out.append(nonlinear_z) print("\r {}: s {} ; v {}".format(ts, linear_z, nonlinear_z), end="") @@ -148,10 +120,11 @@ ax.grid() fig.savefig("rate_cell_integration.jpg") ``` -which should yield a dynamics plot similar to the one below: +which should yield a dynamics plot similar to the one below: + [^1]: [Error neurons](ngclearn.components.neurons.graded.gaussianErrorCell) produce this kind of "top-down" value, which is technically the first derivative diff --git a/docs/tutorials/neurocog/short_term_plasticity.md b/docs/tutorials/neurocog/short_term_plasticity.md index b225f3c5..c669c5bb 100755 --- a/docs/tutorials/neurocog/short_term_plasticity.md +++ b/docs/tutorials/neurocog/short_term_plasticity.md @@ -1,69 +1,28 @@ # Lecture 4E: Short-Term Plasticity -In this lesson, we will study how short-term plasticity (STP) [1] dynamics --- where synaptic efficacy is cast in terms of the history of presynaptic activity -- -using ngc-learn's in-built `STPDenseSynapse`. -Specifically, we will study how a dynamic synapse may be constructed and -examine what short-term depression (STD) and short-term facilitation -(STF) dominated configurations of an STP synapse look like. +In this lesson, we will study how short-term plasticity (STP) [1] dynamics -- where synaptic efficacy is cast in terms of the history of presynaptic activity -- using ngc-learn's in-built `STPDenseSynapse`. Specifically, we will study how a dynamic synapse may be constructed and examine what short-term depression (STD) and short-term facilitation (STF) dominated configurations of an STP synapse look like. ## Probing Short-Term Plasticity -Go ahead and make a new folder for this study and create a Python script, -i.e., `run_shortterm_plasticity.py`, to write your code for this part of the -tutorial. +Go ahead and make a new folder for this study and create a Python script, i.e., `run_shortterm_plasticity.py`, to write your code for this part of the tutorial. -We will write a 3-component dynamical system that connects a Poisson input -encoding cell to a leaky integrate-and-fire (LIF) cell via a single dynamic -synapse that evolves according to STP. We will first write our -simulation of this dynamic synapse from the perspective of STF-dominated -dynamics, plotting out the results under two different Poisson spike trains -with different spiking frequencies. Then, we will modify our simulation -to emulate dynamics from an STD-dominated perspective. +We will write a 3-component dynamical system that connects a Poisson input encoding cell to a leaky integrate-and-fire (LIF) cell via a single dynamic synapse that evolves according to STP. We will first write our simulation of this dynamic synapse from the perspective of STF-dominated dynamics, plotting out the results under two different Poisson spike trains with different spiking frequencies. Then, we will modify our simulation to emulate dynamics from an STD-dominated perspective. ### Starting with Facilitation-Dominated Dynamics -One experimental goal with using a "dynamic synapse" [1] is often to computationally -model the fact that synaptic efficacy (strength/conductance magnitude) is -not a fixed quantity -- even in cases where long-term adaptation/learning is -absent -- and instead a time-varying property that depends on a fixed -quantity of biophysical resources. Specifically, biological neuronal networks, -synaptic signaling (or communication of information across synaptic connection -pathways) consumes some quantity of neurotransmitters -- STF results from an -influx of calcium into an axon terminal of a pre-synaptic neuron (after -emission of a spike pulse) whereas STD occurs after a depletion of -neurotransmitters that is consumed by the act of synaptic signaling at the axon -terminal of a pre-synaptic neuron. Studies of cortical neuronal regions have -empirically found that some areas are STD-dominated, STF-dominated, or exhibit -some mixture of the two. - -Ultimately, the above means that, in the context of spiking cells, when a -pre-synaptic neuron emits a pulse, this act will affect the relative magnitude -of the synapse's efficacy. In some cases, this will result in an increase -(facilitation) and, in others, this will result in a decrease (depression) -that lasts over a short period of time (several hundreds to thousands of -milliseconds in many instances). -As a result of considering synapses to have a dynamic nature to them, both over -short and long time-scales, plasticity can now be thought of as a stimulus and -resource-dependent quantity, reflecting an important biophysical aspect that -affects how neuronal systems adapt and generalize given different kinds of -sensory stimuli. - -Writing our STP dynamic synapse can be done by importing -[STPDenseSynapse](ngclearn.components.synapses.STPDenseSynapse) -from ngc-learn's in-built components and using it to wire the output -spike compartment of the `PoissonCell` to the input electrical current -compartment of the `LIFCell`. This can be done as follows (using the -meta-parameters we provide in the code block below to ensure -STF-dominated dynamics): +One experimental goal with using a "dynamic synapse" [1] is often to computationally model the fact that synaptic efficacy (strength/conductance magnitude) is not a fixed quantity -- even in cases where long-term adaptation/learning is absent -- and instead a time-varying property that depends on a fixed quantity of biophysical resources. Specifically, biological neuronal networks, synaptic signaling (or communication of information across synaptic connection pathways) consumes some quantity of neurotransmitters -- STF results from an influx of calcium into an axon terminal of a pre-synaptic neuron (after emission of a spike pulse) whereas STD occurs after a depletion of neurotransmitters that is consumed by the act of synaptic signaling at the axon terminal of a pre-synaptic neuron. Studies of cortical neuronal regions have empirically found that some areas are STD-dominated, STF-dominated, or exhibit some mixture of the two. + +Ultimately, the above means that, in the context of spiking cells, when a pre-synaptic neuron emits a pulse, this act will affect the relative magnitude of the synapse's efficacy. In some cases, this will result in an increase (facilitation) and, in others, this will result in a decrease (depression) that lasts over a short period of time (several hundreds to thousands of milliseconds in many instances). As a result of considering synapses to have a dynamic nature to them, both over short and long time-scales, plasticity can now be thought of as a stimulus and resource-dependent quantity, reflecting an important biophysical aspect that affects how neuronal systems adapt and generalize given different kinds of sensory stimuli. + +Writing our STP dynamic synapse can be done by importing [STPDenseSynapse](ngclearn.components.synapses.STPDenseSynapse) from ngc-learn's in-built components and using it to wire the output spike compartment of the `PoissonCell` to the input electrical current compartment of the `LIFCell`. This can be done as follows (using the meta-parameters we provide in the code block below to ensure STF-dominated dynamics): ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess + +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components import PoissonCell, STPDenseSynapse, LIFCell -import ngclearn.utils.weight_distribution as dist +from ngclearn.utils.distribution_generator import DistributionGenerator ## create seeding keys (JAX-style) dkey = random.PRNGKey(231) @@ -88,49 +47,40 @@ plot_fname = "{}Hz_stp_{}.jpg".format(firing_rate_e, tag) with Context("Model") as model: W = STPDenseSynapse( - "W", shape=(1, 1), weight_init=dist.constant(value=2.5), - resources_init=dist.constant(value=Rval), tau_f=tau_f, tau_d=tau_d, - key=subkeys[0] + "W", shape=(1, 1), weight_init=DistributionGenerator.constant(value=2.5), + resources_init=DistributionGenerator.constant(value=Rval), tau_f=tau_f, tau_d=tau_d, key=subkeys[0] ) z0 = PoissonCell("z0", n_units=1, target_freq=firing_rate_e, key=subkeys[0]) z1 = LIFCell( - "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60., - v_reset=-70., thr=-50., tau_theta=0., theta_plus=0., refract_time=0. + "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60., v_reset=-70., thr=-50., tau_theta=0., + theta_plus=0., refract_time=0. ) - W.inputs << z0.outputs ## z0 -> W - z1.j << W.outputs ## W -> z1 + z0.outputs >> W.inputs ## z0 -> W + W.outputs >> z1.j ## W -> z1 - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> z0.advance_state >> W.advance_state >> z1.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> z0.reset >> z1.reset >> W.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - @Context.dynamicCommand - def clamp(obs): - z0.inputs.set(obs) +## set up some utility functions for the model context +def clamp(obs): + z0.inputs.set(obs) ``` -Notice that the `STPDenseSynapse` has two important time constants to configure; -`tau_f` ($\tau_f$), the facilitation time constant, and `tau_d` ($\tau_d$), the -depression time constant. In effect, it is these two constants that you will -want to set to obtain different desired behavior from this in-built dynamic -synapse: +Notice that the `STPDenseSynapse` has two important time constants to configure; `tau_f` ($\tau_f$), the facilitation time constant, and `tau_d` ($\tau_d$), the depression time constant. In effect, it is these two constants that you will want to set to obtain different desired behavior from this in-built dynamic synapse: 1. setting $\tau_f > \tau_d$ will result in STF-dominated behavior; whereas 2. setting $\tau_f < \tau_d$ will produce STD-dominated behavior. -Note that setting $\tau_d = 0$ will result in short-term depression being turned off -completely (and $\tau_f = 0$ disables STF). +Note that setting $\tau_d = 0$ will result in short-term depression being turned off completely (and $\tau_f = 0$ disables STF). -Formally, given the time constants above the dynamics of the `STPDenseSynapse` -operate according to the following coupled ordinary differential equations (ODEs): +Formally, given the time constants above the dynamics of the `STPDenseSynapse` operate according to the following coupled ordinary differential equations (ODEs): $$ \tau_f \frac{\partial u_j(t)}{\partial t} &= -u_j(t) + N_R \big(1 - u_j(t)\big) s_j(t) \\ @@ -144,25 +94,11 @@ W^{dyn}_{ij}(t + \Delta t) = \Big( W^{max}_{ij} u_j(t + \Delta t) x_j(t) s_j(t) + W^{dyn}_{ij} (1 - s_j(t)) $$ -where $N_R$ represents an increment produced by a pre-synaptic spike $\mathbf{s}_j(t)$ -(and in essence, the neurotransmitter resources available to yield facilitation), -$W^{max}_{ij}$ denotes the absolute synaptic efficacy (or maximum response -amplitude of this synapse in the case of a complete release of all -neurotransmitters; $x_j(t) = u_j(t) = 1$) of the connection between pre-synaptic -neuron $j$ and post-synaptic neuron $i$, and $W^{dyn}_{ij}(t)$ is the value -of the dynamic synapse's efficacy at time `t`. -$\mathbf{x}_j$ is a variable (which lies in the range of $[0,1]$) that indicates -the fraction of (neurotransmitter) resources available after a depletion of the -neurotransmitter resource pool. $\mathbf{u}_j$, on the hand, -represents the neurotransmitter "release probability", or the fraction of available -resources ready for the dynamic synapse's use. +where $N_R$ represents an increment produced by a pre-synaptic spike $\mathbf{s}_j(t)$ (and in essence, the neurotransmitter resources available to yield facilitation), $W^{max}_{ij}$ denotes the absolute synaptic efficacy (or maximum response amplitude of this synapse in the case of a complete release of all neurotransmitters; $x_j(t) = u_j(t) = 1$) of the connection between pre-synaptic neuron $j$ and post-synaptic neuron $i$, and $W^{dyn}_{ij}(t)$ is the value of the dynamic synapse's efficacy at time `t`. $\mathbf{x}_j$ is a variable (which lies in the range of $[0,1]$) that indicates the fraction of (neurotransmitter) resources available after a depletion of the neurotransmitter resource pool. $\mathbf{u}_j$, on the hand, represents the neurotransmitter "release probability", or the fraction of available resources ready for the dynamic synapse's use. ### Simulating and Visualizing STF -Now that we understand the basics of how an ngc-learn STP synapse works, we can next -try it out on a simple pre-synaptic Poisson spike train. Writing out the -simulated input Poisson spike train and our STP model's processing of this -data can be done as follows: +Now that we understand the basics of how an ngc-learn STP synapse works, we can next try it out on a simple pre-synaptic Poisson spike train. Writing out the simulated input Poisson spike train and our STP model's processing of this data can be done as follows: ```python t_vals = [] @@ -170,26 +106,27 @@ u_vals = [] x_vals = [] W_vals = [] num_z1_spikes = 0. -model.reset() +reset_process.run() obs = jnp.asarray([[1.]]) ts = 1. ptr = 0 # spike time pointer for i in range(T_max): - model.clamp(obs) - model.advance(t=dt * ts, dt=dt) - u = jnp.squeeze(W.u.value) - x = jnp.squeeze(W.x.value) - Wexc = jnp.squeeze(W.Wdyn.value) - s0 = jnp.squeeze(W.inputs.value) - s1 = jnp.squeeze(z1.s.value) + clamp(obs) + advance_process.run(t=dt * ts, dt=dt) + u = jnp.squeeze(W.u.get()) + x = jnp.squeeze(W.x.get()) + Wexc = jnp.squeeze(W.Wdyn.get()) + s0 = jnp.squeeze(W.inputs.get()) + s1 = jnp.squeeze(z1.s.get()) num_z1_spikes = s1 + num_z1_spikes u_vals.append(u) x_vals.append(x) W_vals.append(Wexc) t_vals.append(ts) - print("{}| u: {} x: {} W: {} pre: {} post {}".format(ts, u, x, Wexc, s0, s1)) + print("\r{}| u: {} x: {} W: {} pre: {} post {}".format(ts, u, x, Wexc, s0, s1), end="") ts += dt ptr += 1 +print() print("Number of z1 spikes = ",num_z1_spikes) u_vals = jnp.squeeze(jnp.asarray(u_vals)) @@ -197,8 +134,7 @@ x_vals = jnp.squeeze(jnp.asarray(x_vals)) t_vals = jnp.squeeze(jnp.asarray(t_vals)) ``` -We may then plot out the result of the STF-dominated dynamics we -simulate above with the following code: +We may then plot out the result of the STF-dominated dynamics we simulate above with the following code: ```python import matplotlib.pyplot as plt @@ -235,13 +171,13 @@ ax2.grid() fig1.savefig(plot_fname) ``` -Under the `2` Hertz Poisson spike train set up above, the plotting -code should produce (and save to disk) the following: +Under the `2` Hertz Poisson spike train set up above, the plotting code should produce (and save to disk) the following: -Note that, if you change the frequency of the input Poisson spike train to `20` -Hertz instead, like so: +where we also observe that about `3` spikes/pulses were emitted by the post-synaptic neuron over the course of this +simulation. +Note that, if you change the frequency of the input Poisson spike train to `20` Hertz instead, like so: ```python firing_rate_e = 20 ## Hz (of Poisson input train) @@ -251,16 +187,13 @@ and re-run your simulation script, you should obtain the following: -Notice that increasing the frequency in which the pre-synaptic spikes occur -results in more volatile dynamics with respect to the effective synaptic -efficacy over time. +where we further observe that about `68` spikes/pulses were emitted by the post-synaptic neuron over the course of this +simulation. +In general, notice that increasing the frequency in which the pre-synaptic spikes occur results in more volatile dynamics with respect to the effective synaptic efficacy over time. ### Depression-Dominated Dynamics -With your code above, it's simple to reconfigure the model to emulate -the opposite of STF dominated dynamics, i.e., short-term depression (STD) -dominated dynamics. -Modify your meta-parameter values like so: +With the code you have written code above, it's simple to reconfigure the model to emulate the opposite of STF dominated dynamics, i.e., short-term depression (STD) dominated dynamics. To do so, you will need to modify your meta-parameter values like so: ```python firing_rate_e = 2 ## Hz (of Poisson input train) @@ -274,17 +207,13 @@ and re-run your script to obtain an output akin to the following: -Now, modify your meta-parameters one last time to use a higher-frequency -input spike train, i.e., `firing_rate_e = 20 ## Hz`, to obtain a plot similar -to the one below: +which, after running the script, will print out that the post-synaptic neuron spiked about `3` times. Now, modify your meta-parameters one last time to use a higher-frequency input spike train, i.e., `firing_rate_e = 20 ## Hz`, to obtain a plot similar to the one below: -You have now successfully simulated a dynamic synapse in ngc-learn across -several different Poisson input train frequencies under both STF and -STD-dominated regimes. In more complex biophysical models, it could prove useful -to consider combining STP with other forms of long-term experience-dependent -forms of synaptic adaptation, such as spike-timing-dependent plasticity. +where the script will further print out that the post-synaptic neuron spiked only a single time. + +You have now successfully simulated a dynamic synapse in ngc-learn across several different Poisson input train frequencies under both STF- and STD-dominated regimes. In more complex biophysical models, it could prove useful to consider combining STP with other forms of long-term experience-dependent forms of synaptic adaptation, such as [spike-timing-dependent plasticity](stdp.md). ## References diff --git a/ngclearn/components/synapses/STPDenseSynapse.py b/ngclearn/components/synapses/STPDenseSynapse.py index c523ec74..ff3aed4a 100755 --- a/ngclearn/components/synapses/STPDenseSynapse.py +++ b/ngclearn/components/synapses/STPDenseSynapse.py @@ -1,7 +1,7 @@ from jax import random, numpy as jnp, jit -from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info +from ngclearn.utils.distribution_generator import DistributionGenerator from ngclearn.components.synapses import DenseSynapse from ngcsimlib.compartment import Compartment from ngcsimlib.parser import compilable @@ -72,9 +72,10 @@ def __init__( self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values if self.resources_init is None: info(self.name, "is using default resources value initializer!") - self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15 + #self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15 + self.resources_init = DistributionGenerator.uniform(low=0.125, high=0.175) self.resources = Compartment( - initialize_params(subkeys[2], self.resources_init, shape) + self.resources_init(shape, subkeys[2]) #initialize_params(subkeys[2], self.resources_init, shape) ) ## matrix U - synaptic resources matrix @compilable From f332546e40f9991acab7ec3911e57b1ec46d062c Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 17:57:30 -0500 Subject: [PATCH 070/121] revised elements of utils to comply with docs --- .../ngclearn.utils.feature_dictionaries.rst | 21 +++++++++++++++++++ docs/source/ngclearn.utils.masks.rst | 21 +++++++++++++++++++ docs/source/ngclearn.utils.rst | 2 ++ ngclearn/utils/diffeq/ode_utils.py | 13 ++++++------ ngclearn/utils/diffeq/odes.py | 13 ++++++++++++ .../utils/feature_dictionaries/__init__.py | 0 .../feature_dictionaries/polynomialLibrary.py | 13 +++++------- ngclearn/utils/masks/__init__.py | 0 ngclearn/utils/patch.py | 12 +++++------ ngclearn/utils/patch_utils.py | 16 ++++++++------ 10 files changed, 85 insertions(+), 26 deletions(-) create mode 100644 docs/source/ngclearn.utils.feature_dictionaries.rst create mode 100644 docs/source/ngclearn.utils.masks.rst create mode 100644 ngclearn/utils/feature_dictionaries/__init__.py create mode 100644 ngclearn/utils/masks/__init__.py diff --git a/docs/source/ngclearn.utils.feature_dictionaries.rst b/docs/source/ngclearn.utils.feature_dictionaries.rst new file mode 100644 index 00000000..bc9daa64 --- /dev/null +++ b/docs/source/ngclearn.utils.feature_dictionaries.rst @@ -0,0 +1,21 @@ +ngclearn.utils.feature\_dictionaries package +============================================ + +Submodules +---------- + +ngclearn.utils.feature\_dictionaries.polynomialLibrary module +------------------------------------------------------------- + +.. automodule:: ngclearn.utils.feature_dictionaries.polynomialLibrary + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: ngclearn.utils.feature_dictionaries + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/ngclearn.utils.masks.rst b/docs/source/ngclearn.utils.masks.rst new file mode 100644 index 00000000..17721150 --- /dev/null +++ b/docs/source/ngclearn.utils.masks.rst @@ -0,0 +1,21 @@ +ngclearn.utils.masks package +============================ + +Submodules +---------- + +ngclearn.utils.masks.multiblock2d module +---------------------------------------- + +.. automodule:: ngclearn.utils.masks.multiblock2d + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: ngclearn.utils.masks + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/ngclearn.utils.rst b/docs/source/ngclearn.utils.rst index 82816960..ac63b363 100644 --- a/docs/source/ngclearn.utils.rst +++ b/docs/source/ngclearn.utils.rst @@ -10,6 +10,8 @@ Subpackages ngclearn.utils.analysis ngclearn.utils.density ngclearn.utils.diffeq + ngclearn.utils.feature_dictionaries + ngclearn.utils.masks ngclearn.utils.optim ngclearn.utils.viz diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py index 30ddb2d4..55a70ace 100755 --- a/ngclearn/utils/diffeq/ode_utils.py +++ b/ngclearn/utils/diffeq/ode_utils.py @@ -1,12 +1,13 @@ """ Routines and co-routines for ngc-learn's differential equation integration backend. -Currently supported back-end forms of integration in ngc-learn include: -0) Euler integration (RK-1); -1) Midpoint method (RK-2); -2) Heun's method (error-corrector RK-2); -3) Ralston's method (error-corrector RK-2); -4) 4th-order Runge-Kutta method (RK-4); +| Currently supported back-end forms of integration in ngc-learn include: +| 0) Euler integration (RK-1); +| 1) Midpoint method (RK-2); +| 2) Heun's method (error-corrector RK-2); +| 3) Ralston's method (error-corrector RK-2); +| 4) 4th-order Runge-Kutta method (RK-4); + """ from jax import numpy as jnp, random, jit #, nn diff --git a/ngclearn/utils/diffeq/odes.py b/ngclearn/utils/diffeq/odes.py index b37e2408..733f1082 100644 --- a/ngclearn/utils/diffeq/odes.py +++ b/ngclearn/utils/diffeq/odes.py @@ -1,3 +1,16 @@ +""" +In-built dynamical systems built on differential equations. Note that these systems are designed such that they +directly operzte with ngc-learn's ODE integration backend. + +| Currently in-built dynamical systems include: +| 0) A continuous linear 2D system; +| 1) A continuous cubic 2D system; +| 2) A Lorenz attractor system; +| 3) A continuous linear 3D system; +| 4) A continuous oscillator system. + +""" + import jax.numpy as jnp def linear_2D(t, x, params): diff --git a/ngclearn/utils/feature_dictionaries/__init__.py b/ngclearn/utils/feature_dictionaries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ngclearn/utils/feature_dictionaries/polynomialLibrary.py b/ngclearn/utils/feature_dictionaries/polynomialLibrary.py index f50686ef..8dea87eb 100644 --- a/ngclearn/utils/feature_dictionaries/polynomialLibrary.py +++ b/ngclearn/utils/feature_dictionaries/polynomialLibrary.py @@ -1,19 +1,16 @@ -import jax.numpy as jnp -from jax import jit, random -import jax.numpy as jnp +from jax import jit, random, numpy as jnp from typing import List, Tuple, Union from dataclasses import dataclass - - @dataclass class PolynomialLibrary: """ A class for creating polynomial feature libraries in 1D, 2D, or 3D. - Attributes: - poly_order (int): Maximum order of polynomial terms - include_bias (bool): Whether to include the bias term in the output + Args: + poly_order (int): Maximum order of polynomial terms (Attribute) + + include_bias (bool): Whether to include the bias term in the output (Attribute) """ poly_order: int = None diff --git a/ngclearn/utils/masks/__init__.py b/ngclearn/utils/masks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ngclearn/utils/patch.py b/ngclearn/utils/patch.py index c5dd14f2..94ad9af6 100644 --- a/ngclearn/utils/patch.py +++ b/ngclearn/utils/patch.py @@ -93,9 +93,9 @@ def target(self, img: jnp.ndarray): plt.show() - -gen = PatchGenerator(patch_width=5, patch_height=5, horizontal_alignment='center', horizontal_stride=1) - -test_img = jnp.zeros((32, 32)) - -gen.target(test_img) +## testing code +# gen = PatchGenerator(patch_width=5, patch_height=5, horizontal_alignment='center', horizontal_stride=1) +# +# test_img = jnp.zeros((32, 32)) +# +# gen.target(test_img) diff --git a/ngclearn/utils/patch_utils.py b/ngclearn/utils/patch_utils.py index f3116e84..a558e82e 100755 --- a/ngclearn/utils/patch_utils.py +++ b/ngclearn/utils/patch_utils.py @@ -39,17 +39,19 @@ class Create_Patches: Args: img: jax array of size (H, W) - patched: (height_patch, width_patch) - overlap: (height_overlap, width_overlap) - add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap) - create_patches: creates small patches out of the image based on the provided attributes. + patch_shape: (height_patch, width_patch) + + overlap_shape: (height_overlap, width_overlap) Returns: - jnp.array: Array containing the patches - shape: (num_patches, patch_height, patch_width) + jnp.array: Array containing the patches, shape: (num_patches, patch_height, patch_width) """ + #patched: (height_patch, width_patch) + #overlap: (height_overlap, width_overlap) + #add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap) + #create_patches: creates small patches out of the image based on the provided attributes. def __init__(self, img, patch_shape, overlap_shape): self.img = img @@ -90,6 +92,8 @@ def create_patches(self, add_frame=False, center=True): Keyword Args: add_frame: If true the function will add zero frames (increase the dimension) to the image + center: + Returns: jnp.array: Array containing the patches shape: (num_patches, patch_height, patch_width) From 80685ef9a36159e4bf744ef305b28fde84ad9608 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 19:02:40 -0500 Subject: [PATCH 071/121] revised stdp neurocog doc to v3 --- .../neurocog/simple_leaky_integrator.md | 205 +++++------------- docs/tutorials/neurocog/stdp.md | 152 ++++--------- ngclearn/utils/model_utils.py | 16 +- 3 files changed, 109 insertions(+), 264 deletions(-) diff --git a/docs/tutorials/neurocog/simple_leaky_integrator.md b/docs/tutorials/neurocog/simple_leaky_integrator.md index 73068a73..1aa6643e 100644 --- a/docs/tutorials/neurocog/simple_leaky_integrator.md +++ b/docs/tutorials/neurocog/simple_leaky_integrator.md @@ -1,24 +1,18 @@ # Lecture 2A: The Simplified Leaky Integrator Cell -In this tutorial, we will study one of ngc-learn's (simplest) in-built leaky -integrator components, the simplified leaky integrate-and-fire (SLIF). +In this tutorial, we will study one of ngc-learn's (simplest) in-built neuronal cell components, the simplified leaky integrate-and-fire (SLIF). ## Creating and Using a Leaky Integrator ### Instantiating the Leaky Integrate-and-Fire Cell -With our JSON configuration in place, go ahead and create a Python script, -i.e., `run_slif.py`, to write your code for this part of the tutorial. - -Now let's go ahead and set up the controller for this lesson's simulation, -where we will a dynamical system with only a single component, -specifically the simplified LIF (sLIF), like so: +Start by creating a Python script, i.e., `run_slif.py`, to write your code for this part of the tutorial. +Now let's go ahead and set up the controller/model-context for this lesson's simulation, where we will a dynamical system with only a single component, specifically the simplified LIF (sLIF). Write code to do this like so: ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.neurons.spiking.sLIFCell import SLIFCell from ngclearn.utils.viz.spike_plot import plot_spiking_neuron @@ -36,57 +30,28 @@ tau_m = R_m * C ## membrane time constant ## create simple system with only one sLIF with Context("Model") as model: - cell = SLIFCell("z0", n_units=1, tau_m=tau_m, resist_m=R_m, thr=V_thr, - refract_time=ref_T, key=subkeys[0]) + cell = SLIFCell("z0", n_units=1, tau_m=tau_m, resist_m=R_m, thr=V_thr, refract_time=ref_T, key=subkeys[0]) ## set up core commands that drive the simulation - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - cell.j.set(x) +## set up non-compiled utility commands +def clamp(x): + cell.j.set(x) ``` -This node has quite a few compartments and constants but only a handful are important -for understanding how this model governs spiking/firing rates during -a controller's simulation window. Specifically, in this lesson, we will focus on -its electrical current `j` (formally labeled here as $\mathbf{j}_t$), -its voltage `v` (formally labeled: $\mathbf{v}_t$), its spike emission -(or action potential) `s` (formally $\mathbf{s}_t$), and its refractory -variable/marker (formally $\mathbf{r}_t$). The subscript $t$ indicates -that this compartment variable takes on a certain value at a certain time step -$t$ and we will refer to the ngc-learn controller's integration time constant, -the amount of time we move forward by, as $\Delta t$. The constants or -hyper-parameters we will be most interested in are the cell's membrane resistance -`R_m` (formally $R$ with its capacitance $C$ implied), its membrane time -constant `tau_m` (formally $\tau_m$), its refractory period time -`refract_T` (formally $T_{ref}$), and its voltage threshold `v_thr` -(formally $V_thr$). (There are other constants inherent to the -sLIF, but these are sufficient for this exercise.) - -Later on, towards the end of this tutorial, we provide some theoretical -exposition and explanation of the above constants/compartments -(see `On the Dynamics of Leaky Integrators`); for practical -purposes we will now move on to using your `sLIF` node in a simple simulation -to illustrate some of its dynamics. +This node has quite a few compartments and constants but only a handful are important for understanding how this model governs spiking/firing rates during a controller's simulation window. Specifically, in this lesson, we will focus on its electrical current `j` (formally labeled here as $\mathbf{j}_t$), its voltage `v` (formally labeled: $\mathbf{v}_t$), its spike emission (or action potential) `s` (formally $\mathbf{s}_t$), and its refractory variable/marker (formally $\mathbf{r}_t$). The subscript $t$ indicates that this compartment variable takes on a certain value at a certain time step $t$ and we will refer to the ngc-learn controller's integration time constant, the amount of time we move forward by, as $\Delta t$. The constants or hyper-parameters we will be most interested in are the cell's membrane resistance `R_m` (formally $R$ with its capacitance $C$ implied), its membrane time constant `tau_m` (formally $\tau_m$), its refractory period time `refract_T` (formally $T_{ref}$), and its voltage threshold `v_thr` (formally $V_thr$). (There are other constants inherent to the sLIF, but these are sufficient for this exercise.) + +Later on, towards the end of this tutorial, we provide some theoretical exposition and explanation of the above constants/compartments +(see `On the Dynamics of Leaky Integrators`); for practical purposes, we will now move on to using your `sLIF` node in a simple simulation to illustrate some of its dynamics. ### Simulating a Leaky Integrator - -Given our single-cell dynamical system above, let us write some code to use -our `sLIF` node and visualize its spiking pattern by feeding -into it a step current, where the electrical current `j` starts at $0$ then -switches to $0.3$ at $t = 10$ (ms). Specifically, we can plot the input current, -the neuron's voltage `v`, and its output spikes as follows: +Given our single-cell dynamical system above, let us write some code to use our `sLIF` node and visualize its spiking pattern by feeding into it a step current, where the electrical current `j` starts at $0$ then switches to $0.3$ at $t = 10$ (ms). Specifically, we can plot the input current, the neuron's voltage `v`, and its output spikes as follows: ```python # create a synthetic electrical step current @@ -96,81 +61,65 @@ curr_in = [] mem_rec = [] spk_rec = [] -model.reset() +reset_process.run() for ts in range(current.shape[1]): j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts - model.clamp(j_t) - model.advance(t=ts*1., dt=dt) + clamp(j_t) + advance_process.run(t=ts*1., dt=dt) ## naively extract simple statistics at time ts and print them to I/O - v = cell.v.value - s = cell.s.value + v = cell.v.get() + s = cell.s.get() curr_in.append(j_t) mem_rec.append(v) spk_rec.append(s) - print(" {}: s {} ; v {}".format(ts, s, v), end="") + print(f"\r{ts}: s {s} ; v {v}", end="") print() import numpy as np curr_in = np.squeeze(np.asarray(curr_in)) mem_rec = np.squeeze(np.asarray(mem_rec)) spk_rec = np.squeeze(np.asarray(spk_rec)) -plot_spiking_neuron(curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=0., - max_mem_val=1.3, title="SLIF-Node: Constant Electrical Input", - fname="lif_plot.jpg") +plot_spiking_neuron( + curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=0., max_mem_val=1.3, + title="SLIF-Node: Constant Electrical Input", fname="lif_plot.jpg" +) ``` -which produces the following plot (saved as `lif_plot.jpg` locally to disk): +which produces the following plot (saved as `lif_plot.jpg` locally to disk): -where we see that, given a build-up over time in the neuron's membrane potential -(since the current is constant and non-zero after $10$ ms), a spike is emitted -once the value of the membrane potential exceeds the threshold (indicated by -the dashed horizontal line in the middle plot) $V_{thr} = 1$. -Notice that if we play with the value of `ref_T` (the refactory period $T_{ref}$) -and change it to something like `ref_T = 10 * dt` (ten times the integration time -constant), we get the following neuronal dynamics plot: +where we see that, given a build-up over time in the neuron's membrane potential (since the current is constant and non-zero after $10$ ms), a spike is emitted once the value of the membrane potential exceeds the threshold (indicated by the dashed horizontal line in the middle plot) $V_{thr} = 1$. Notice that if we play with the value of `ref_T` (the refactory period $T_{ref}$) and change it to something like `ref_T = 10 * dt` (ten times the integration time constant), we get the following neuronal dynamics plot: -where we see that after the LIF neuron fires, it remains stuck at its resting -potential for a period of $0.01$ ms (the short flat periods in the red curve -starting after the first spike). +where we see that after the LIF neuron fires, it remains stuck at its resting potential for a period of $0.01$ ms (the short flat periods in the red curve starting after the first spike). ## On the Dynamics of Leaky Integrators -Now let us unpack this component by first defining the relevant compartments: +Now let us unpack this component by first defining the relevant compartments: -+ $\mathbf{j}_t$: the current electrical current of the neurons within this node - (note that this current could be the summation of multiple step/pointwise - current sources or be the current sample of an electrical current, itself - modeled by a differential equation); ++ $\mathbf{j}_t$: the current electrical current of the neurons within this node (note that this current could be the summation of multiple step/pointwise current sources or be the current sample of an electrical current, itself modeled by a differential equation); + $\mathbf{v}_t$: the current membrane potential of the neurons within this node; -+ $\mathbf{s}_t$: the current recording/reading of any spikes produced by this - node's neurons; -+ $\mathbf{r}_t$: the current value of the absolute refractory variables - this - accumulates with time (and forces neurons to rest) ++ $\mathbf{s}_t$: the current recording/reading of any spikes produced by this node's neurons; ++ $\mathbf{r}_t$: the current value of the absolute refractory variables - this accumulates with time (and forces neurons to rest) and finally the constants: -+ $V_{thr}$: threshold that a neuron's membrane potential must overcome before - a spike is transmitted; ++ $V_{thr}$: threshold that a neuron's membrane potential must overcome before a spike is transmitted; + $\Delta t$: the integration time constant, on the order of milliseconds (ms); + $R$: the neural (cell) membrane resistance, on the order of mega Ohms ($M \Omega$); + $C$: the neural (cell) membrane capacitance, on the order of picofarads ($pF$); -+ $\tau_{m}$: membrane potential time constant (also $\tau_{m} = R * C$ - - resistance times capacitance); ++ $\tau_{m}$: membrane potential time constant (also $\tau_{m} = R * C$, or resistance times capacitance); + $T_{ref}$: the length of a neuron's absolute refractory period. -With above defined, we can now explicitly lay out the underlying (linear) ordinary -differential equation that the `sLIF` evolves according to: +With above defined, we can now explicitly lay out the underlying (linear) ordinary differential equation that the `sLIF` evolves according to: $$ \tau_m \frac{\partial \mathbf{v}_t}{\partial t} = (-\mathbf{v}_t + R \mathbf{j}_t), \; \mbox{where, } \tau_m = R C $$ -and with some simple mathematical manipulations (leveraging the method of finite differences), -we can derive the Euler integrator employed by the `sLIF` as follows: +and with some simple mathematical manipulations (leveraging the method of finite differences), we can derive the Euler integrator employed by the `sLIF` as follows: $$ \tau_m \frac{\partial \mathbf{v}_t}{\partial t} &= (-\mathbf{v}_t + R \mathbf{j}_t) \\ @@ -178,32 +127,24 @@ $$ \mathbf{v}_{t + \Delta t} &= \mathbf{v}_t + (-\mathbf{v}_t + R \mathbf{j}_t) \frac{\Delta t}{\tau_m } $$ -where we see that above integration tells us that the membrane potential of this node varies -over time as a function of the sum of its input electrical current $\mathbf{j}_t$ -(multiplied by the cell membrane resistance) and a leak (or decay) $-\mathbf{v}_t$ -modulated by the integration time constant divided by the membrane time constant. -The `sLIF` allows you to control the value of $\tau_m$ directly (hence why we -calculated $\tau_m$ externally via our chosen $R$ and $C$; other neuronal cells -allow you to change $\tau_m$ via $R$ and $C$). - - - +in this walkthrough.) +--> -In effect, given the above, every time the `sLIF`'s `.advanceState()` function is -called within a simulation controller (`Controller()`), the above Euler integration of -the membrane potential differential equation is happening each time step. Knowing this, -the last item required to understand ngc-learn's `sLIF` node's computation is -related to its spike $\mathbf{s}_t$. The spike reading is computed simply by -comparing the current membrane potential $\mathbf{v}_t$ to the constant threshold -defined by $V_{thr}$ according to the following piecewise function: +In effect, given the above, every time the `sLIF`'s `.advanceState()` function is called within a simulation controller (`Controller()`), the above Euler integration of the membrane potential differential equation is happening each time step. Knowing this, the last item required to understand ngc-learn's `sLIF` node's computation is related to its spike $\mathbf{s}_t$. The spike reading is computed simply by comparing the current membrane potential $\mathbf{v}_t$ to the constant threshold defined by $V_{thr}$ according to the following piecewise function: $$ \mathbf{s}_{t, i} = \begin{cases} @@ -212,52 +153,8 @@ $$ \end{cases} $$ -where we see that if the $i$th neuron's membrane potential exceeds the threshold -$V_{thr}$, then a voltage spike is emitted. After a spike is emitted, the $i$th -neuron within the node needs to be reset to its resting potential and this is done -with the final compartment that we mentioned, i.e., the refractory -variable $\mathbf{r}_t$. -The refractory variable $\mathbf{r}_t$ is important for hyperpolarizing the -$i$th neuron back to its resting potential (establishing a critical reset mechanism --- otherwise, the neuron would fire out of control after overcoming its -threshold) and reducing the amount of spikes generated over time. This reduction -is one of the key factors behind the power efficiency of biological neuronal systems. -Another aspect of ngc-learn's refractory variable is the temporal length of the reset itself, -which is controlled by the $T_{ref}$ (`T_ref`) constant -- this yields what is known as the -absolute refractory period, or the interval of time at which a second action potential -absolutely cannot be initiated. If $T_{ref}$ is set to be greater than -zero, then the $i$th neuron that fires will be forced to remain at its resting -potential of zero for the duration of this refractory period. - -Note that the reason the `sLIF` contains simplified in its name is that its -internal dynamics and parameterization have been drastically simplified in -comparison to ngc-learn's more standard `LIF` component. Furthermore, the -`sLIF` operates assuming a resting membrane potential of `0` (milliVolts) whereas, -for more intricate leaky integrator models, the resting potential is often -negative, requiring a different and more careful setting of hyper-parameters -(such as the voltage threshold). Nevertheless, although `sLIF` is a simpler -model, it can be used as a rational first step for crafting very useful spiking -neural networks and offers other aspects of functionality not used in this tutorial, -such as adaptive threshold functionality and fast approximate lateral inhibition/recurrence. - -## Optional: Setting Up The Components with a JSON Configuration - -While you are not required to create a JSON configuration file for ngc-learn, -to get rid of the warning that ngc-learn will throw at the start of your -program's execution (indicating that you do not have a configuration set up yet), -all you need to do is create a sub-directory for your JSON configuration -inside of your project code's directory, i.e., `json_files/modules.json`. -Inside the JSON file, you would write the following: - -```json -[ - {"absolute_path": "ngclearn.components", - "attributes": [ - {"name": "SLIFCell"}] - }, - {"absolute_path": "ngcsimlib.operations", - "attributes": [ - {"name": "overwrite"}] - } -] -``` +where we see that if the $i$th neuron's membrane potential exceeds the threshold $V_{thr}$, then a voltage spike is emitted. After a spike is emitted, the $i$th neuron within the node needs to be reset to its resting potential and this is done with the final compartment that we mentioned, i.e., the refractory variable $\mathbf{r}_t$. +The refractory variable $\mathbf{r}_t$ is important for hyperpolarizing the $i$th neuron back to its resting potential (establishing a critical reset mechanism -- otherwise, the neuron would fire out of control after overcoming its threshold) and reducing the amount of spikes generated over time. This reduction is one of the key factors behind the power efficiency of biological neuronal systems. Another aspect of ngc-learn's refractory variable is the temporal length of the reset itself, which is controlled by the $T_{ref}$ (`T_ref`) constant -- this yields what is known as the absolute refractory period, or the interval of time at which a second action potential absolutely cannot be initiated. If $T_{ref}$ is set to be greater than zero, then the $i$th neuron that fires will be forced to remain at its resting potential of zero for the duration of this refractory period. + +Note that the reason the `sLIF` contains simplified in its name is that its internal dynamics and parameterization have been drastically simplified in comparison to ngc-learn's more standard `LIF` component. Furthermore, the `sLIF` operates assuming a resting membrane potential of `0` (milliVolts) whereas, for more intricate leaky integrator models, the resting potential is often negative, requiring a different and more careful setting of hyper-parameters (such as the voltage threshold). Nevertheless, although `sLIF` is a simpler model, it can be used as a rational first step for crafting very useful spiking neural networks and offers other aspects of functionality not used in this tutorial, such as adaptive threshold functionality and fast approximate lateral inhibition/recurrence. + diff --git a/docs/tutorials/neurocog/stdp.md b/docs/tutorials/neurocog/stdp.md index b8e889a0..ea65cc41 100755 --- a/docs/tutorials/neurocog/stdp.md +++ b/docs/tutorials/neurocog/stdp.md @@ -1,40 +1,23 @@ # Lecture 4C: Spike-Timing-Dependent Plasticity -In the context of spiking neuronal networks, one of the most important forms -of adaptation that is often simulated is that of spike-timing-dependent -plasticity (STDP). In this lesson, we will setup and use one -of ngc-learn's standard in-built STDP-based components, visualizing the -changes in synaptic efficacy that it produces in the context of -pre-synaptic and post-synaptic variable traces. +In the context of spiking neuronal networks, one of the most important forms of adaptation that is often simulated is that of spike-timing-dependent plasticity (STDP). In this lesson, we will setup and use one of ngc-learn's standard in-built STDP-based components, visualizing the changes in synaptic efficacy that it produces in the context of pre-synaptic and post-synaptic variable traces. ## Probing Spike-Timing-Dependent Plasticity -Go ahead and make a new folder for this study and create a Python script, -i.e., `run_trstdp.py`, to write your code for this part of the tutorial. +Go ahead and make a new folder for this study and create a Python script, i.e., `run_trstdp.py`, to write your code for this part of the tutorial. -Now let's set up the model for this lesson's simulation and construct a -3-component system made up of two variable traces (`VarTrace`) connected by -one single synapse that is capable of producing changes in connection strength -in accordance with STDP, specifically with a form of the update rule known -as [trace-based STDP](ngclearn.components.synapses.hebbian.traceSTDPSynapse). -Note that the trace components do not really do -anything meaningful unless they receive some input and we will provide -carefully controlled input spike values in order to control their behavior -so as to see how STDP responds to the relative temporal ordering of a pre- and -post-synaptic spike, where the time of spikes is approximated by the -corresponding pre- and post-synaptic traces (which decay exponentially with time -in the absence of input). +Now let's set up the model for this lesson's simulation and construct a 3-component system made up of two variable traces (`VarTrace`) connected by one single synapse that is capable of producing changes in connection strength in accordance with STDP, specifically with a form of the update rule known as [trace-based STDP](ngclearn.components.synapses.hebbian.traceSTDPSynapse). Note that the trace components do not really do anything meaningful unless they receive some input. Therefore, we will provide carefully controlled input spike values in order to control their behavior in order to see how STDP responds to the relative temporal ordering of a pre- and post-synaptic spike, where the timing of the spikes is approximated by the corresponding pre- and post-synaptic traces (which decay exponentially with time in the absence of input). -Writing the above 3-component system can be in the following manner: +Writing the above 3-component system can be done in the following manner: ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess + +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.other.varTrace import VarTrace from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse -import ngclearn.utils.weight_distribution as dist +from ngclearn.utils.distribution_generator import DistributionGenerator ## create seeding keys (JAX-style) dkey = random.PRNGKey(231) @@ -46,55 +29,43 @@ T_max = 100 ## number time steps to simulate with Context("Model") as model: tr0 = VarTrace("tr0", n_units=1, tau_tr=8., a_delta=1.) tr1 = VarTrace("tr1", n_units=1, tau_tr=8., a_delta=1.) - W = TraceSTDPSynapse("W1", shape=(1, 1), eta=0., A_plus=1., A_minus=0.8, - weight_init=dist.uniform(0.0, 0.3), key=subkeys[0]) + W = TraceSTDPSynapse( + "W1", shape=(1, 1), eta=0., A_plus=1., A_minus=0.8, + weight_init=DistributionGenerator.uniform(low=0.0, high=0.3), key=subkeys[0] + ) # wire only relevant compartments to synaptic cable W for demo purposes - W.preTrace << tr0.trace - # self.W1.preSpike << self.z0.outputs ## we disable this as we will manually - ## insert a binary value (for a spike) - W.postTrace << tr1.trace - # self.W1.postSpike << self.z1e.s ## we disable this as we will manually - ## insert a binary value (for a spike) - - evolve_process = (JaxProcess() + tr0.trace >> W.preTrace + # self.z0.outputs >> self.W1.preSpike ## we disable this as we will manually + ## insert a binary value (for a spike) in this tutorial + tr1.trace >> W.postTrace + # self.z1e.s >> self.W1.postSpike ## we disable this as we will manually + ## insert a binary value (for a spike) in this tutorial + + evolve_synapse = (MethodProcess("evolve") >> W.evolve) - model.wrap_and_add_command(jit(evolve_process.pure), name="evolve") - - advance_process = (JaxProcess() - >> tr0.advance_state - >> tr1.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance_traces") - reset_process = (JaxProcess() - >> tr0.reset - >> tr1.reset - >> W.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") + advance_traces = (MethodProcess("advance") + >> tr0.advance_state + >> tr1.advance_state + >> W.advance_state) + reset = (MethodProcess("reset") + >> tr0.reset + >> tr1.reset + >> W.reset) - @Context.dynamicCommand - def clamp_synapse(pre_spk, post_spk): - W.preSpike.set(pre_spk) - W.postSpike.set(post_spk) +## set up some utility functions for the model context +def clamp_synapse(pre_spk, post_spk): + W.preSpike.set(pre_spk) + W.postSpike.set(post_spk) - - @Context.dynamicCommand - def clamp_traces(pre_spk, post_spk): - tr0.inputs.set(pre_spk) - tr1.inputs.set(post_spk) +def clamp_traces(pre_spk, post_spk): + tr0.inputs.set(pre_spk) + tr1.inputs.set(post_spk) ``` -With our carefully constructed STDP-adapted model above, we can then simulate -the changes to synaptic efficacy that it would produce as a function of -the distance between and order between a pre- and a post-synaptic binary spike. -Notice that in the above model, we have set the global learning rate `eta` to -zero, which will prevent the `TraceSTDPSynapse` from actually adjusting -its internal matrix of synaptic weight values using the updates produced by -STDP -- this means our synapses are held fixed throughout this particular -demonstration. Our goal is to produce an approximation of the theoretical synaptic -strength adjustment curve dictated by STDP; this can be done using the -code below: +With our carefully constructed STDP-adapted model above, we can then simulate the changes to synaptic efficacy that it would produce as a function of the distance between and order between a pre- and a post-synaptic binary spike. Notice that in the above model, we have set the global learning rate `eta` to zero, which will prevent the `TraceSTDPSynapse` from actually adjusting its internal matrix of synaptic weight values using the updates produced by STDP -- this means our synapses are held fixed throughout this particular demonstration. Our goal is to produce an approximation of the theoretical synaptic strength adjustment curve dictated by STDP; this can be done using the code below: ```python t_values = [] @@ -118,16 +89,13 @@ for i in range(T_max+1): _pre_trig = jnp.ones((1,1)) _post_trig = jnp.zeros((1,1)) ts = 0. - model.clamp_traces(pre_spk, post_spk) - model.advance_traces(t=dt * i, dt=dt) + clamp_traces(pre_spk, post_spk) + advance_traces.run(t=dt * i, dt=dt) ## get STDP update - W.preSpike.set(_pre_trig) - W.postSpike.set(_post_trig) - W.preTrace.set(tr0.trace.value) - W.postTrace.set(tr1.trace.value) - model.evolve(t=dt * i, dt=dt) - dW = W.dWeights.value + clamp_synapse(_pre_trig, _post_trig) + evolve_synapse.run(t=dt * i, dt=dt) + dW = W.dWeights.get() dW_vals.append(dW) if i >= int(T_max/2): t_values.append(ts) @@ -165,40 +133,12 @@ which should produce a plot similar to the one in the left-hand side below: +------------------------------------------------------------+----------------------------------------------------------------+ ``` -where we have provided a marked-up image of the STDP experimental data produced -and visualized in the classical work done by Bi and Poo in 1998 [1]. -We remark that our approximate STDP synaptic change curve does not perfectly -match/fit that of [1] perfectly by any means but does capture the -general trend and form of the long-term potentiation arc (the roughly -negative exponential curve to the right-hand side of zero) and the long-term -depression curve (the flipped exponential-like function to the left-hand -side of zero). Ultimately, a synaptic component like the `TraceSTDPSynapse` -can be quite useful for constructing spiking neural network architectures -that learn in a biologically-plausible fashion since this rule, as seen by the -above simulation usage, solely depends on information that is locally -available at the pre-synaptic neuron (its spike and a single trace that -tracks its temporal spiking history) and the post-synaptic neuron -(its own spike as well as a trace that tracks its spike history). Notably, -traced-based STDP can be an effective way of adapting the synapses of -biophysically more accurate computational models, such as those that balance -excitatory and inhibitory pressures produced by laterally-wired populations of -leaky integrator neurons, e.g., the -[Diehl and Cook spiking architecture](../../museum/snn_dc) we study in the model -museum in more detail. - -### Other Forms of Spike-Timing-Dependent Plasticity -Finally, beyond trace-based STDP, there are other types of STDP in-built to -ngc-learn, such as event-driven post-synaptic STDP -([eventSTDPSynapse](ngclearn.components.synapses.hebbian.eventSTDPSynapse)), that -you can experiment with and use in your model building and simulation projects. -You can learn more about these in the ngc-learn -[modeling API](../../modeling/components.md). -Beyond this, the ngc-learn dev team is always busy behind the scenes -constructing more standard computational neuroscience building blocks and -synaptic plasticity rules; so keep an eye out for future incoming developments! +Notice that, for the above visual, we have also provided a marked-up image of the STDP experimental data produced and visualized in the classical work done by Bi and Poo in 1998 [1]. We remark that our approximate STDP synaptic change curve does not perfectly match/fit that of [1] perfectly by any means; however, it does capture the general trend and form of the long-term potentiation arc (the roughly negative exponential curve to the right-hand side of zero) and the long-term depression curve (the flipped exponential-like function to the left-hand side of zero). Ultimately, a synaptic component like the `TraceSTDPSynapse` can be quite useful for constructing spiking neural network architectures that learn in a biologically-plausible fashion given that this rule, as seen by the above simulation usage, solely depends on information that is locally available at the pre-synaptic neuron (its spike and a single trace that tracks its temporal spiking history) and the post-synaptic neuron (its own spike as well as a trace that tracks its spike history). Notably, traced-based STDP can be an effective way of adapting the synapses of biophysically more accurate computational models, such as those that balance excitatory and inhibitory pressures produced by laterally-wired populations of leaky integrator neurons, e.g., the [Diehl and Cook spiking architecture](../../museum/snn_dc) that we study in more detail within the context of a model museum exhibit. + +### Other Forms of Spike-Timing-Dependent Plasticity +Finally, beyond trace-based STDP, there are other types of STDP in-built to ngc-learn, such as event-driven post-synaptic STDP ([eventSTDPSynapse](ngclearn.components.synapses.hebbian.eventSTDPSynapse)), which you can experiment with and use in your model building and simulation projects. You can learn more about these and other related biologically-plausible learning rules in the ngc-learn [modeling API](../../modeling/components.md) (specifically in the "Synapses" subsection page). +Beyond this, the ngc-learn dev team is always busy behind the scenes constructing more standard computational neuroscience building blocks and synaptic plasticity rules; so keep an eye out for future incoming developments! ## References -[1] Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modifications in cultured -hippocampal neurons: dependence on spike timing, synaptic strength, and -postsynaptic cell type." Journal of neuroscience 18.24 (1998). +[1] Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type." Journal of neuroscience 18.24 (1998). diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index 64e28f21..2d0c0d03 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -298,8 +298,14 @@ def d_relu(x): @jit def telu(x): """ - Proposed by Fernandez and Mali 24, https://arxiv.org/abs/2412.20269 and https://arxiv.org/abs/2402.02790 - TeLU activation: f(x) = x * tanh(e^x) + The hyperbolic tangent exponential linear (TeLU) function: + + | f(x) = x * tanh(e^x) + + This was proposed by Fernandez and Mali 24 in: + + | https://arxiv.org/abs/2412.20269 and in, + | https://arxiv.org/abs/2402.02790 Args: x: input (tensor) value @@ -312,8 +318,10 @@ def telu(x): @jit def d_telu(x): """ - - Derivative of TeLU: f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x)) + Derivative of the hyperbolic tangent exponential linear (TeLU) function. + Effectively, this is formally: + + | f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x)) Args: x: input (tensor) value From 95176894d846bf7a639acdb579864ca49a3735f5 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 19:09:38 -0500 Subject: [PATCH 072/121] revised traces neurocog tutorial to v3 --- docs/tutorials/neurocog/traces.md | 79 +++++++++---------------------- 1 file changed, 23 insertions(+), 56 deletions(-) diff --git a/docs/tutorials/neurocog/traces.md b/docs/tutorials/neurocog/traces.md index 4d338f65..da038db7 100755 --- a/docs/tutorials/neurocog/traces.md +++ b/docs/tutorials/neurocog/traces.md @@ -1,29 +1,17 @@ # Lecture 1B: Trace Variables and Filtering -Traces represent one very important component tool in ngc-learn as these are -often, in biophysical model simulations, used to produce real-valued -representations of often discrete-valued patterns, e.g., spike vectors within -a spike train, that can facilitate mechanisms such as online biological credit -assignment. In this lesson, we will observe how one of ngc-learn's core -trace components -- the `VarTrace` -- operates. +Traces represent one very important component tool in ngc-learn as these are often, in biophysical model simulations, used to produce real-valued representations of often discrete-valued patterns, e.g., spike vectors within a spike train, that can facilitate mechanisms such as online biological credit assignment. In this lesson, we will observe how one of ngc-learn's core trace components -- the `VarTrace` -- operates. ## Setting Up a Variable Trace for a Poisson Spike Train -To observe the value of a variable trace, we will pair it to another in-built -ngc-component; the `PoissonCell`, which will be configured to emit spikes -approximately at `63.75` Hertz (yielding a fairly sparse spike train). This means -we will construct a two-component dynamical system, where the input -compartment `outputs` of the `PoissonCell` will be wired directly into the -`inputs` compartment of the `VarTrace`. Note that a `VarTrace` has an `inputs` -compartment -- which is where raw signals typically go into -- and a `trace` -output compartment -- which is where filtered signal values/by-products are emitted from. +To observe the value of a variable trace, we will pair it to another in-built ngc-component; the `PoissonCell`, which will be configured to emit spikes approximately at `63.75` Hertz (yielding a fairly sparse spike train). This means we will construct a two-component dynamical system, where the input compartment `outputs` of the `PoissonCell` will be wired directly into the `inputs` compartment of the `VarTrace`. Note that a `VarTrace` has an `inputs` compartment -- which is where raw signals typically go into -- and a `trace` output compartment -- which is where filtered signal values/by-products are emitted from. The code below will instantiate the paired Poisson cell and corresponding variable trace: ```python from jax import numpy as jnp, random, jit -from ngclearn.utils import JaxProcess -from ngcsimlib.context import Context + +from ngclearn import Context, MethodProcess ## import model-specific mechanisms from ngclearn.components.input_encoders.poissonCell import PoissonCell from ngclearn.components.other.varTrace import VarTrace @@ -37,30 +25,24 @@ with Context("Model") as model: trace = VarTrace("tr0", n_units=1, tau_tr=30., a_delta=0.5) ## wire up cell z0 to trace tr0 - trace.inputs << cell.outputs + cell.outputs >> trace.inputs - advance_process = (JaxProcess() + advance_process = (MethodProcess("advance") >> cell.advance_state >> trace.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") - reset_process = (JaxProcess() + reset_process = (MethodProcess("reset") >> cell.reset >> trace.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") - - @Context.dynamicCommand - def clamp(x): - cell.inputs.set(x) +## set up some utility functions for the model context +def clamp(x): + cell.inputs.set(x) ``` ## Running the Paired Cell-Trace System -We can then run the above two-component dynamical system by injecting a fixed -(valid) probability value into the Poisson input encoder and then record the -resulting spikes and trace values. We will do this for `T = 200` milliseconds (ms) -with the code below: +We can then run the above two-component dynamical system by injecting a fixed (valid) probability value into the Poisson input encoder and then record the resulting spikes and trace values. We will do this for `T = 200` milliseconds (ms) with the code below: ```python dt = 1. # ms # integration time constant @@ -70,22 +52,21 @@ probs = jnp.asarray([[0.35]],dtype=jnp.float32) time_span = [] spikes = [] traceVals = [] -model.reset() +reset_process.run() for ts in range(T): - model.clamp(probs) - model.advance(t=ts*1., dt=dt) + clamp(probs) + advance_process.run(t=ts*1., dt=dt) - print("{} {}".format(cell.outputs.value, trace.trace.value), end="") - spikes.append( cell.outputs.value ) - traceVals.append( trace.trace.value ) + print(f"\r{cell.outputs.get()} {trace.trace.get()}", end="") + spikes.append( cell.outputs.get() ) + traceVals.append( trace.trace.get() ) time_span.append(ts * dt) print() spikes = jnp.concatenate(spikes,axis=0) traceVals = jnp.concatenate(traceVals,axis=0) ``` -We can plot the above simulation's trace outputs with the discrete spikes -super-imposed at their times of occurrence with the code below: +We can plot the above simulation's trace outputs with the discrete spikes super-imposed at their times of occurrence with the code below: ```python import matplotlib #.pyplot as plt @@ -100,8 +81,7 @@ stat = jnp.where(spikes > 0.) indx = (stat[0] * 1. - 1.).tolist() spk = ax.vlines(x=indx, ymin=0.985, ymax=1.05, colors='black', ls='-', lw=5) -ax.set(xlabel='Time (ms)', ylabel='Trace Output', - title='Variable Trace of Poisson Spikes') +ax.set(xlabel='Time (ms)', ylabel='Trace Output', title='Variable Trace of Poisson Spikes') #ax.legend([zTr[0],spk[0]],['z','phi(z)']) ax.grid() fig.savefig("poisson_trace.jpg") @@ -111,29 +91,16 @@ to get the following output saved to disk: -Notice that every time a spike is produced by the Poisson encoding cell, the trace -increments by `0.5` -- the result of the `a_delta` hyper-parameter we set when -crafting the model and simulation object -- and then exponentially decays in -the absence of a spike (with the time constant of `tau_tr = 30` milliseconds). +Notice that every time a spike is produced by the Poisson encoding cell, the trace increments by `0.5` -- the result of the `a_delta` hyper-parameter we set when crafting the model and simulation object -- and then exponentially decays in the absence of a spike (with the time constant of `tau_tr = 30` milliseconds). -The variable trace can be further configured to filter signals in different ways -if desired; specifically by manipulating its `decay_type` and `a_delta` arguments. -Notably, if a piecewise-gated variable trace is desired (a very common choice -in some neuronal circuit models), then all one would have to do is set `a_delta = 0`, -yielding the following line in the model creation code earlier in this tutorial: +The variable trace can be further configured to filter signals in different ways if desired; specifically by manipulating its `decay_type` and `a_delta` arguments. Notably, if a piecewise-gated variable trace is desired (a very common choice in some neuronal circuit models), then all one would have to do is set `a_delta = 0`, yielding the following line in the model creation code earlier in this tutorial: ```python trace = VarTrace("tr0", n_units=1, tau_tr=30., a_delta=0., decay_type="exp") ``` -Running the same code from before but with the above alteration would yield the -plot below: +Running the same code from before but with the above alteration would yield the plot below: -Notice that, this time, when a spike is emitted from the Poisson cell, the trace -is "clamped" to the value of one and then exponentially decays. Such a trace -configuration is useful if one requires the maintained trace to never increase -beyond a value of one, preventing divergence or run-away values if a spike train -is particularly dense and yielding friendlier values for biological learning -rules. +Notice that, this time, when a spike is emitted from the Poisson cell, the trace is "clamped" to the value of one and then exponentially decays. Such a trace configuration is useful if one requires the maintained trace to never increase beyond a value of one, preventing divergence or run-away values if a spike train is particularly dense and yielding friendlier values for biological learning rules. From 51c2650569c4581e0bdf583327e5441dfeeb7b03 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 19:53:46 -0500 Subject: [PATCH 073/121] cleaned up utils.optim and wrote compliant NAG optim --- docs/source/ngclearn.utils.optim.rst | 8 +++ ngclearn/utils/optim/adam.py | 19 ++----- ngclearn/utils/optim/nag.py | 84 ++++++++++++++++++++++++++++ ngclearn/utils/optim/optim_utils.py | 5 +- ngclearn/utils/optim/sgd.py | 4 +- 5 files changed, 104 insertions(+), 16 deletions(-) create mode 100644 ngclearn/utils/optim/nag.py diff --git a/docs/source/ngclearn.utils.optim.rst b/docs/source/ngclearn.utils.optim.rst index 20145f98..547b6209 100644 --- a/docs/source/ngclearn.utils.optim.rst +++ b/docs/source/ngclearn.utils.optim.rst @@ -12,6 +12,14 @@ ngclearn.utils.optim.adam module :undoc-members: :show-inheritance: +ngclearn.utils.optim.nag module +------------------------------- + +.. automodule:: ngclearn.utils.optim.nag + :members: + :undoc-members: + :show-inheritance: + ngclearn.utils.optim.optim\_utils module ---------------------------------------- diff --git a/ngclearn/utils/optim/adam.py b/ngclearn/utils/optim/adam.py index 8fa9cba6..5186f035 100644 --- a/ngclearn/utils/optim/adam.py +++ b/ngclearn/utils/optim/adam.py @@ -1,16 +1,11 @@ # %% -# from ngcsimlib.component import Component -# from ngcsimlib.compartment import Compartment -# from ngcsimlib.resolver import resolver - import numpy as np from jax import jit, numpy as jnp, random, nn, lax from functools import partial -import time -def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps): +def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps): """ Runs one step of Adam over a set of parameters given updates. The dynamics for any set of parameters is as follows: @@ -39,17 +34,17 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps): beta2: 2nd moment control factor - time: current time t or iteration step/call to this Adam update + time_step: current time t or iteration step/call to this Adam update eps: numberical stability coefficient (for calculating final update) Returns: - adjusted parameter tensor (same shape as "param") + adjusted parameter tensor (same shape as "param"), adjusted g1, adjusted g2 """ _g1 = beta1 * g1 + (1. - beta1) * update _g2 = beta2 * g2 + (1. - beta2) * jnp.square(update) - g1_unb = _g1 / (1. - jnp.power(beta1, time)) - g2_unb = _g2 / (1. - jnp.power(beta2, time)) + g1_unb = _g1 / (1. - jnp.power(beta1, time_step)) + g2_unb = _g2 / (1. - jnp.power(beta2, time_step)) _param = param - lr * g1_unb/(jnp.sqrt(g2_unb) + eps) return _param, _g1, _g2 @@ -83,9 +78,7 @@ def adam_step(opt_params, theta, updates, eta=0.001, beta1=0.9, beta2=0.999, eps new_g1 = [] new_g2 = [] for i in range(len(theta)): - px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i], - g2[i], eta, beta1, - beta2, time_step, eps) + px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i], g2[i], eta, beta1, beta2, time_step, eps) new_theta.append(px_i) new_g1.append(g1_i) new_g2.append(g2_i) diff --git a/ngclearn/utils/optim/nag.py b/ngclearn/utils/optim/nag.py new file mode 100644 index 00000000..4335b4df --- /dev/null +++ b/ngclearn/utils/optim/nag.py @@ -0,0 +1,84 @@ +# %% + +import numpy as np +from jax import jit, numpy as jnp, random, nn, lax +from functools import partial +import time + + +def step_update(param, update, phi_old, lr, mu, time_step): + """ + Runs one step of Nesterov's accelerated gradient (NAG) over a set of parameters given updates. + The dynamics for any set of parameters is as follows: + + | phi = param - update * lr + | param = phi + (phi - phi_previous) * mu, where mu = 0 iff t <= 1 (first iteration) + + Args: + param: parameter tensor to change/adjust + + update: update tensor to be applied to parameter tensor (must be same + shape as "param") + + phi_old: previous friction/momentum parameter + + lr: global step size value to be applied to updates to parameters + + mu: friction/momentum control factor + + time_step: current time t or iteration step/call to this NAG update + + Returns: + adjusted parameter tensor (same shape as "param"), adjusted momentum/friction variable + """ + phi = param - update * lr ## do a phantom gradient adjustment step + _param = phi + (phi - phi_old) * (mu * (time_step > 1.)) ## NAG-step + _phi_old = phi + return _param, _phi_old + +@jit +def nag_step(opt_params, theta, updates, eta=0.01, mu=0.9): ## apply adjustment to theta + """ + Implements Nesterov's accelerated gradient (NAG) algorithm as a decoupled update rule given adjustments produced + by a credit assignment algorithm/process. + + Args: + opt_params: (ArrayLike) parameters of the optimization algorithm + + theta: (ArrayLike) the weights of neural network + + updates: (ArrayLike) the updates of neural network + + eta: (float, optional) step size coefficient for NAG update (Default: 0.001) + + mu: (float, optional) friction/momentum control factor. (Default: 0.9) + + Returns: + ArrayLike: opt_params. New opt params, ArrayLike: theta. The updated weights + """ + phi, time_step = opt_params + time_step = time_step + 1 + new_theta = [] + new_phi = [] + for i in range(len(theta)): + px_i, phi_i = step_update(theta[i], updates[i], phi[i], eta, mu, time_step) + new_theta.append(px_i) + new_phi.append(phi_i) + return (new_phi, time_step), new_theta + +@jit +def nag_init(theta): + time_step = jnp.asarray(0.0) + phi = [jnp.zeros(theta[i].shape) for i in range(len(theta))] + return phi, time_step + +if __name__ == '__main__': + weights = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])] + updates = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])] + opt_params = nag_init(weights) + opt_params, theta = nag_step(opt_params, weights, updates) + print(f"opt_params: {opt_params}, theta: {theta}") + weights = theta + print("##################") + opt_params, theta = nag_step(opt_params, weights, updates) + print(f"opt_params: {opt_params}, theta: {theta}") diff --git a/ngclearn/utils/optim/optim_utils.py b/ngclearn/utils/optim/optim_utils.py index f02de676..e521c07b 100755 --- a/ngclearn/utils/optim/optim_utils.py +++ b/ngclearn/utils/optim/optim_utils.py @@ -1,17 +1,20 @@ import functools from .sgd import sgd_step, sgd_init +from .nag import nag_step, nag_init from .adam import adam_step, adam_init def get_opt_init_fn(opt='adam'): return { 'adam': adam_init, + 'nag': nag_init, 'sgd': sgd_init }[opt] def get_opt_step_fn(opt='adam', **kwargs): - # **kwargs here is the hyper parameters you want to pass in the optimization function + ## **kwargs here is the hyper-parameters you want to pass in the optimization function return { 'adam': functools.partial(adam_step, **kwargs), + 'nag': functools.partial(nag_step, **kwargs), 'sgd': functools.partial(sgd_step, **kwargs), }[opt] diff --git a/ngclearn/utils/optim/sgd.py b/ngclearn/utils/optim/sgd.py index e0c38e64..0e8aecac 100755 --- a/ngclearn/utils/optim/sgd.py +++ b/ngclearn/utils/optim/sgd.py @@ -15,7 +15,8 @@ def step_update(param, update, lr): @jit def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to theta - """Return a params update + """ + Returns updated parameters in accordance to a stochastic gradient descent (SGD) recipe Args: opt_params: (ArrayLike) parameters of the optimization algorithm @@ -42,7 +43,6 @@ def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to thet def sgd_init(theta): return jnp.asarray(0.0) - if __name__ == '__main__': opt_params, theta = sgd_step((2.0), [1.0, 1.0], [3.0, 4.0], 3e-2) print(f"opt_params: {opt_params}, theta: {theta}") From 05e0a7d16ceb3bd7a4276d2d068f16f77dcb0dc1 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 19 Nov 2025 19:57:25 -0500 Subject: [PATCH 074/121] cleaned up utils.optim and wrote compliant NAG optim --- ngclearn/utils/optim/adam.py | 6 +++--- ngclearn/utils/optim/nag.py | 6 +++--- ngclearn/utils/optim/sgd.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ngclearn/utils/optim/adam.py b/ngclearn/utils/optim/adam.py index 5186f035..4fb5c87a 100644 --- a/ngclearn/utils/optim/adam.py +++ b/ngclearn/utils/optim/adam.py @@ -5,7 +5,7 @@ from functools import partial -def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps): +def step_update(param, update, g1, g2, eta, beta1, beta2, time_step, eps): """ Runs one step of Adam over a set of parameters given updates. The dynamics for any set of parameters is as follows: @@ -28,7 +28,7 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps): g2: second moment factor/correction factor to use in parameter update (must be same shape as "update") - lr: global step size value to be applied to updates to parameters + eta: global step size value to be applied to updates to parameters beta1: 1st moment control factor @@ -45,7 +45,7 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time_step, eps): _g2 = beta2 * g2 + (1. - beta2) * jnp.square(update) g1_unb = _g1 / (1. - jnp.power(beta1, time_step)) g2_unb = _g2 / (1. - jnp.power(beta2, time_step)) - _param = param - lr * g1_unb/(jnp.sqrt(g2_unb) + eps) + _param = param - eta * g1_unb/(jnp.sqrt(g2_unb) + eps) return _param, _g1, _g2 @jit diff --git a/ngclearn/utils/optim/nag.py b/ngclearn/utils/optim/nag.py index 4335b4df..045be116 100644 --- a/ngclearn/utils/optim/nag.py +++ b/ngclearn/utils/optim/nag.py @@ -6,7 +6,7 @@ import time -def step_update(param, update, phi_old, lr, mu, time_step): +def step_update(param, update, phi_old, eta, mu, time_step): """ Runs one step of Nesterov's accelerated gradient (NAG) over a set of parameters given updates. The dynamics for any set of parameters is as follows: @@ -22,7 +22,7 @@ def step_update(param, update, phi_old, lr, mu, time_step): phi_old: previous friction/momentum parameter - lr: global step size value to be applied to updates to parameters + eta: global step size value to be applied to updates to parameters mu: friction/momentum control factor @@ -31,7 +31,7 @@ def step_update(param, update, phi_old, lr, mu, time_step): Returns: adjusted parameter tensor (same shape as "param"), adjusted momentum/friction variable """ - phi = param - update * lr ## do a phantom gradient adjustment step + phi = param - update * eta ## do a phantom gradient adjustment step _param = phi + (phi - phi_old) * (mu * (time_step > 1.)) ## NAG-step _phi_old = phi return _param, _phi_old diff --git a/ngclearn/utils/optim/sgd.py b/ngclearn/utils/optim/sgd.py index 0e8aecac..dfde125c 100755 --- a/ngclearn/utils/optim/sgd.py +++ b/ngclearn/utils/optim/sgd.py @@ -1,16 +1,16 @@ from jax import jit, numpy as jnp -def step_update(param, update, lr): +def step_update(param, update, eta): """ Runs one step of SGD over a set of parameters given updates. Args: - lr: global step size to apply when adjusting parameters + eta: global step size to apply when adjusting parameters Returns: adjusted parameter tensor (same shape as "param") """ - _param = param - lr * update + _param = param - update * eta return _param @jit From c1a21ceb046757f51a33b178672d4448f2232eb3 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 20 Nov 2025 13:42:47 -0500 Subject: [PATCH 075/121] cleanup of components, added leaky-noise-cell, minor edits --- ngclearn/__init__.py | 1 + ngclearn/components/__init__.py | 1 + .../input_encoders/bernoulliCell.py | 4 +- .../components/input_encoders/latencyCell.py | 4 +- .../components/input_encoders/phasorCell.py | 4 +- .../components/input_encoders/poissonCell.py | 4 +- ngclearn/components/neurons/__init__.py | 1 + .../components/neurons/graded/__init__.py | 3 +- .../neurons/graded/bernoulliErrorCell.py | 6 +- .../neurons/graded/gaussianErrorCell.py | 7 +- .../neurons/graded/laplacianErrorCell.py | 7 +- .../neurons/graded/leakyNoiseCell.py | 151 ++++++++++++++++++ .../components/neurons/graded/rateCell.py | 17 +- .../neurons/graded/rewardErrorCell.py | 7 +- ngclearn/components/neurons/spiking/IFCell.py | 4 +- .../components/neurons/spiking/LIFCell.py | 7 +- .../components/neurons/spiking/RAFCell.py | 4 +- .../components/neurons/spiking/WTASCell.py | 6 +- .../components/neurons/spiking/adExCell.py | 4 +- .../neurons/spiking/fitzhughNagumoCell.py | 5 +- .../neurons/spiking/hodgkinHuxleyCell.py | 4 +- .../neurons/spiking/izhikevichCell.py | 4 +- .../components/neurons/spiking/quadLIFCell.py | 4 +- .../components/neurons/spiking/sLIFCell.py | 7 +- ngclearn/components/other/expKernel.py | 7 +- ngclearn/components/other/varTrace.py | 7 +- .../components/synapses/STPDenseSynapse.py | 4 +- ngclearn/components/synapses/alphaSynapse.py | 4 +- .../synapses/convolution/convSynapse.py | 5 +- .../synapses/convolution/deconvSynapse.py | 5 +- .../convolution/hebbianConvSynapse.py | 7 +- .../convolution/hebbianDeconvSynapse.py | 7 +- .../convolution/traceSTDPConvSynapse.py | 7 +- .../convolution/traceSTDPDeconvSynapse.py | 7 +- ngclearn/components/synapses/denseSynapse.py | 4 +- .../components/synapses/doubleExpSynapse.py | 4 +- .../components/synapses/exponentialSynapse.py | 4 +- .../components/synapses/hebbian/BCMSynapse.py | 4 +- .../synapses/hebbian/eventSTDPSynapse.py | 5 +- .../synapses/hebbian/expSTDPSynapse.py | 5 +- .../synapses/hebbian/hebbianSynapse.py | 6 +- .../synapses/hebbian/traceSTDPSynapse.py | 5 +- .../synapses/modulated/MSTDPETSynapse.py | 5 +- .../synapses/patched/hebbianPatchedSynapse.py | 4 +- .../synapses/patched/patchedSynapse.py | 7 +- .../neurons/graded/test_RateCell.py | 11 +- .../synapses/hebbian/test_hebbianSynapse.py | 7 +- 47 files changed, 252 insertions(+), 145 deletions(-) create mode 100755 ngclearn/components/neurons/graded/leakyNoiseCell.py diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index fa1d4030..9f1ece8a 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -32,6 +32,7 @@ from ngcsimlib.context import Context, ContextObjectTypes from ngcsimlib import Component from ngcsimlib.compartment import Compartment +from ngcsimlib.parser import compilable from ngcsimlib import logger diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 6bce427b..d8c4dc67 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -2,6 +2,7 @@ ## point to rate-coded cell component types from .neurons.graded.rateCell import RateCell +from .neurons.graded.leakyNoiseCell import LeakyNoiseCell from .neurons.graded.gaussianErrorCell import GaussianErrorCell from .neurons.graded.laplacianErrorCell import LaplacianErrorCell from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 52441430..4f18f9d8 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -1,7 +1,7 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment import jax from typing import Union diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index 30832afe..3550f086 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -6,8 +6,8 @@ from ngclearn.utils.model_utils import clamp_min, clamp_max -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @partial(jit, static_argnums=[5]) def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1., diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py index 594e3b9d..ada0ddc8 100755 --- a/ngclearn/components/input_encoders/phasorCell.py +++ b/ngclearn/components/input_encoders/phasorCell.py @@ -4,8 +4,8 @@ from typing import Union from ngcsimlib.logger import info, warn -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class PhasorCell(JaxComponent): """ diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index de6b2a97..810776ab 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -4,8 +4,8 @@ from typing import Union from ngcsimlib import deprecate_args -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class PoissonCell(JaxComponent): """ diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py index f367cd02..564577cd 100644 --- a/ngclearn/components/neurons/__init__.py +++ b/ngclearn/components/neurons/__init__.py @@ -1,5 +1,6 @@ ## point to rate-coded cell componet types from .graded.rateCell import RateCell +from .graded.leakyNoiseCell import LeakyNoiseCell from .graded.gaussianErrorCell import GaussianErrorCell from .graded.laplacianErrorCell import LaplacianErrorCell from .graded.bernoulliErrorCell import BernoulliErrorCell diff --git a/ngclearn/components/neurons/graded/__init__.py b/ngclearn/components/neurons/graded/__init__.py index 8d723607..2974d91f 100644 --- a/ngclearn/components/neurons/graded/__init__.py +++ b/ngclearn/components/neurons/graded/__init__.py @@ -1,5 +1,6 @@ -## point to rate-coded cell componet types +## point to rate-coded cell component types from .rateCell import RateCell +from .leakyNoiseCell import LeakyNoiseCell from .gaussianErrorCell import GaussianErrorCell from .laplacianErrorCell import LaplacianErrorCell from .bernoulliErrorCell import BernoulliErrorCell diff --git a/ngclearn/components/neurons/graded/bernoulliErrorCell.py b/ngclearn/components/neurons/graded/bernoulliErrorCell.py index f6666015..978aa1ce 100755 --- a/ngclearn/components/neurons/graded/bernoulliErrorCell.py +++ b/ngclearn/components/neurons/graded/bernoulliErrorCell.py @@ -2,12 +2,10 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit -from ngclearn.utils import tensorstats from ngclearn.utils.model_utils import sigmoid, d_sigmoid -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index d757800c..776dad46 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -2,11 +2,8 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit -from ngclearn.utils import tensorstats - -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py index 56bd5c12..afa833cd 100755 --- a/ngclearn/components/neurons/graded/laplacianErrorCell.py +++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py @@ -2,11 +2,8 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit -from ngclearn.utils import tensorstats - -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell """ diff --git a/ngclearn/components/neurons/graded/leakyNoiseCell.py b/ngclearn/components/neurons/graded/leakyNoiseCell.py new file mode 100755 index 00000000..9b6f2ebc --- /dev/null +++ b/ngclearn/components/neurons/graded/leakyNoiseCell.py @@ -0,0 +1,151 @@ +from jax import numpy as jnp, random, jit +from ngcsimlib.logger import info +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils.model_utils import create_function +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4 + +def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec): ## raw dynamics ODE + dz_dt = -z + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * (sigma_rec) ^ 2) * eps + return dz_dt * (1. / tau_x) + +class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell + """ + A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state. + + Reference: https://pmc.ncbi.nlm.nih.gov/articles/PMC4771709/ + + The specific differential equation that characterizes this cell is (for adjusting x) is: + + | tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_rec)^2) * eps + | where j_in is the set of incoming input signals + | and j_rec is the set of recurrent input signals + | and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1) + + | --- Cell Input Compartments: --- + | j_input - input (bottom-up) electrical/stimulus current (takes in external signals) + | j_recurrent - recurrent electrical/stimulus pressure + | --- Cell State Compartments --- + | x - noisy rate activity / current value of state + | --- Cell Output Compartments: --- + | r - post-rectified activity, i.e., fx(x) = relu(x) + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + tau_x: state membrane time constant (milliseconds) + + act_fx: rectification function (Default: "relu") + + output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) + + integration_type: type of integration to use for this cell's dynamics; + current supported forms include "euler" (Euler/RK-1 integration) and "midpoint" or "rk2" + (midpoint method/RK-2 integration) (Default: "euler") + + :Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of + the cell's evolution at an increase in computational cost (and simulation time) + + sigma_rec: noise scaling factor / standard deviation (Default: 1) + """ + + # Define Functions + def __init__( + self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1., + shape=None, **kwargs + ): + super().__init__(name, **kwargs) + + + self.tau_x = tau_x + self.sigma_rec = sigma_rec ## a "resistance" scaling factor + + ## integration properties + self.integrationType = integration_type + self.intgFlag = get_integrator_code(self.integrationType) + + ## Layer size setup + _shape = (batch_size, n_units) ## default shape is 2D/matrix + if shape is None: + shape = (n_units,) ## we set shape to be equal to n_units if nothing provided + else: + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor + self.shape = shape + self.n_units = n_units + self.batch_size = batch_size + + self.fx, self.dfx = create_function(fun_name=act_fx) + + # compartments (state of the cell & parameters will be updated through stateless calls) + restVals = jnp.zeros(_shape) + self.j_input = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current + self.j_recurrent = Compartment(restVals, display_name="Recurrent Stimulus Current", units="mA") # electrical current + self.x = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity + self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output + + @compilable + def advance_state(self, t, dt): #dt, fx, tau_x, sigma_rec, intgFlag, key, j_input, j_recurrent, x): + key, skey = random.split(self.key.get(), 2) + ### run a step of integration over neuronal dynamics + eps = random.normal(skey[0], shape=self.x.get().shape) ## sample of unit distributional noise + #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag) + + _step_fns = { + 0: step_euler, + 1: step_rk2, + 2: step_rk4, + } + _step_fn = _step_fns.get(self.intgFlag, step_euler) + params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec) + _, x = _step_fn(0., self.x.get(), _dfz_fn, dt, params) ## update state activation dynamics + r = self.fx(x) ## calculate rectified / post-activation function value(s) + + self.key.set(key) + self.x.set(x) + self.r.set(r) + + @compilable + def reset(self): + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + restVals = jnp.zeros(_shape) + self.j_input.set(restVals) + self.j_recurrent.set(restVals) + self.x.set(restVals) + self.r.set(restVals) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "LeakyNoiseCell - evolves neurons according to continuous-time noisy/leaky dynamics " + } + compartment_props = { + "inputs": + {"j_input": "External input stimulus value(s)", + "j_recurrent": "Recurrent/prior-state stimulus value(s)"}, + "states": + {"x": "Update to continuous noisy, leaky dynamics; value at time t"}, + "outputs": + {"r": "A linear rectifier applied to rate-coded dynamics; f(z)"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "batch_size": "Batch size dimension of this component", + "tau_x": "State time constant", + "sigma_rec": "The non-zero degree/scale of noise to inject into this neuron" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "tau_x * dz/dt = -z + j_input + j_recurrent + noise, where noise ~N(0, sigma_rec)", + "hyperparameters": hyperparams} + return info + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = LeakyNoiseCell("X", 9, 0.03) + print(X) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index a03401fc..ddf18f06 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -1,18 +1,16 @@ # %% from jax import numpy as jnp, random, jit -from functools import partial -from ngclearn.utils import tensorstats -# from ngclearn import resolver, Component, Compartment -from ngcsimlib.compartment import Compartment + +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.jaxComponent import JaxComponent from ngclearn.utils.model_utils import create_function, threshold_soft, \ threshold_cauchy from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2, step_rk4 - from ngcsimlib.logger import info -from ngcsimlib.parser import compilable + def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics z_leak = jnp.sign(z) ## d/dx of Laplace is signum @@ -228,9 +226,10 @@ def advance_state(self, dt): dfx_val = self.dfx(z) j = _modulate(j, dfx_val) j = j * self.resist_scale - tmp_z = _run_cell(dt, j, j_td, z, - self.tau_m, leak_gamma=self.priorLeakRate, - integType=self.intgFlag, priorType=self.priorType) + tmp_z = _run_cell( + dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag, + priorType=self.priorType + ) ## apply optional thresholding sub-dynamics if self.thresholdType == "soft_threshold": tmp_z = threshold_soft(tmp_z, self.thr_lmbda) diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py index 479b5c74..91a8056d 100755 --- a/ngclearn/components/neurons/graded/rewardErrorCell.py +++ b/ngclearn/components/neurons/graded/rewardErrorCell.py @@ -2,11 +2,8 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, jit -from ngclearn.utils import tensorstats - -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class RewardErrorCell(JaxComponent): ## Reward prediction error cell """ diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index ec87053a..640d9995 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -7,8 +7,8 @@ triangular_estimator, straight_through_estimator) -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @jit diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index 850f24a8..6fedf559 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -6,8 +6,8 @@ triangular_estimator, straight_through_estimator) -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment def _dfv(t, v, params): ## voltage dynamics wrapper j, rfr, tau_m, refract_T, v_rest, g_L = params @@ -186,8 +186,7 @@ def advance_state(self, dt, t): m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able rS = s * random.uniform(skey, s.shape) - rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], - dtype=jnp.float32) + rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32) s = s * (1. - m_switch) + rS * m_switch self.key.set(key) diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 102a97c4..6c2bdc5d 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -5,8 +5,8 @@ from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment ######################################################################################################################## ## RAF dynamics (multi-dimensional ODEs) diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py index 1d8f0a0e..b4602c74 100755 --- a/ngclearn/components/neurons/spiking/WTASCell.py +++ b/ngclearn/components/neurons/spiking/WTASCell.py @@ -2,10 +2,8 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit, nn from ngcsimlib import deprecate_args -from ngcsimlib.logger import info, warn - -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.utils.model_utils import softmax diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py index ef05d2c2..1e55b55d 100755 --- a/ngclearn/components/neurons/spiking/adExCell.py +++ b/ngclearn/components/neurons/spiking/adExCell.py @@ -4,8 +4,8 @@ from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @jit def _dfv_internal(j, v, w, tau_m, v_rest, sharpV, vT, R_m): ## raw voltage dynamics diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py index d666a2bf..9fe7f603 100755 --- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py +++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py @@ -5,9 +5,8 @@ from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ step_euler, step_rk2 -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @jit def _dfv_internal(j, v, w, a, b, g, tau_m): ## raw voltage dynamics diff --git a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py index 3ee00ca5..87ec823b 100644 --- a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py +++ b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py @@ -4,8 +4,8 @@ from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4 -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment def _calc_biophysical_constants(v): ## computes H-H biophysical constants (which are functions of voltage v) diff --git a/ngclearn/components/neurons/spiking/izhikevichCell.py b/ngclearn/components/neurons/spiking/izhikevichCell.py index 07d89fc0..b94c3402 100755 --- a/ngclearn/components/neurons/spiking/izhikevichCell.py +++ b/ngclearn/components/neurons/spiking/izhikevichCell.py @@ -4,8 +4,8 @@ from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2 -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @jit def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py index b8b93982..6d7c95b6 100755 --- a/ngclearn/components/neurons/spiking/quadLIFCell.py +++ b/ngclearn/components/neurons/spiking/quadLIFCell.py @@ -8,8 +8,8 @@ # triangular_estimator, # straight_through_estimator) -from ngcsimlib.parser import compilable -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.neurons.spiking.LIFCell import LIFCell diff --git a/ngclearn/components/neurons/spiking/sLIFCell.py b/ngclearn/components/neurons/spiking/sLIFCell.py index c644e6a2..6b0c6fd8 100644 --- a/ngclearn/components/neurons/spiking/sLIFCell.py +++ b/ngclearn/components/neurons/spiking/sLIFCell.py @@ -3,14 +3,11 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from functools import partial -from ngclearn.utils import tensorstats -from ngcsimlib.logger import info, warn from ngclearn.utils.diffeq.ode_utils import step_euler from ngclearn.utils.surrogate_fx import secant_lif_estimator -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @jit def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py index a7b25f6a..7c99049f 100644 --- a/ngclearn/components/other/expKernel.py +++ b/ngclearn/components/other/expKernel.py @@ -2,11 +2,8 @@ from jax import numpy as jnp, random, jit from functools import partial from ngclearn.utils import tensorstats -from ngcsimlib import deprecate_args - -from ngcsimlib.logger import info, warn -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @partial(jit, static_argnums=[5,6]) def _apply_kernel(tf_curr, s, t, tau_w, win_len, krn_start, krn_end): diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py index f1ddc2bc..d4de9f47 100644 --- a/ngclearn/components/other/varTrace.py +++ b/ngclearn/components/other/varTrace.py @@ -3,11 +3,8 @@ from ngclearn.components.jaxComponent import JaxComponent from jax import numpy as jnp, random, jit from functools import partial -from ngclearn.utils import tensorstats -from ngcsimlib.parser import compilable -from ngcsimlib.logger import info, warn - -from ngcsimlib.compartment import Compartment +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment @partial(jit, static_argnums=[4]) def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.): diff --git a/ngclearn/components/synapses/STPDenseSynapse.py b/ngclearn/components/synapses/STPDenseSynapse.py index ff3aed4a..31cf7c67 100755 --- a/ngclearn/components/synapses/STPDenseSynapse.py +++ b/ngclearn/components/synapses/STPDenseSynapse.py @@ -2,9 +2,9 @@ from ngcsimlib.logger import info from ngclearn.utils.distribution_generator import DistributionGenerator +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses import DenseSynapse -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable """ diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py index 4470af68..cbdbb8c8 100644 --- a/ngclearn/components/synapses/alphaSynapse.py +++ b/ngclearn/components/synapses/alphaSynapse.py @@ -1,8 +1,8 @@ from jax import random, numpy as jnp, jit +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses import DenseSynapse -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable """ diff --git a/ngclearn/components/synapses/convolution/convSynapse.py b/ngclearn/components/synapses/convolution/convSynapse.py index 62b6ee3a..02186493 100755 --- a/ngclearn/components/synapses/convolution/convSynapse.py +++ b/ngclearn/components/synapses/convolution/convSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngcsimlib.logger import info import ngclearn.utils.weight_distribution as dist from ngclearn.components.synapses.convolution.ngcconv import conv2d diff --git a/ngclearn/components/synapses/convolution/deconvSynapse.py b/ngclearn/components/synapses/convolution/deconvSynapse.py index 32f1dfc8..8355494d 100755 --- a/ngclearn/components/synapses/convolution/deconvSynapse.py +++ b/ngclearn/components/synapses/convolution/deconvSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngcsimlib.logger import info import ngclearn.utils.weight_distribution as dist from ngclearn.components.synapses.convolution.ngcconv import deconv2d diff --git a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py index 4b45d2ce..a66242a4 100755 --- a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py @@ -1,9 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params -import ngclearn.utils.weight_distribution as dist - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding, diff --git a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py index 35def788..d3317728 100755 --- a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py @@ -1,9 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params -import ngclearn.utils.weight_distribution as dist - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv, diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py index c9b4e5f2..86aa33c4 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py @@ -1,9 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params -import ngclearn.utils.weight_distribution as dist - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding, diff --git a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py index a6286cd8..a894213e 100755 --- a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py +++ b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py @@ -1,9 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params -import ngclearn.utils.weight_distribution as dist - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv, diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index 99996e65..db656b7a 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -3,8 +3,8 @@ from ngclearn.utils.distribution_generator import DistributionGenerator from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment class DenseSynapse(JaxComponent): ## base dense synaptic cable """ diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py index ca1fdcdd..91a05d60 100644 --- a/ngclearn/components/synapses/doubleExpSynapse.py +++ b/ngclearn/components/synapses/doubleExpSynapse.py @@ -1,8 +1,8 @@ from jax import random, numpy as jnp, jit +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses import DenseSynapse -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable class DoubleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable """ diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py index e0ee3a6e..dc20c362 100644 --- a/ngclearn/components/synapses/exponentialSynapse.py +++ b/ngclearn/components/synapses/exponentialSynapse.py @@ -1,8 +1,8 @@ from jax import random, numpy as jnp, jit +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses import DenseSynapse -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable """ diff --git a/ngclearn/components/synapses/hebbian/BCMSynapse.py b/ngclearn/components/synapses/hebbian/BCMSynapse.py index c31bba12..ff669a07 100755 --- a/ngclearn/components/synapses/hebbian/BCMSynapse.py +++ b/ngclearn/components/synapses/hebbian/BCMSynapse.py @@ -1,6 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.denseSynapse import DenseSynapse diff --git a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py index 265445e4..826b9ff9 100755 --- a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.denseSynapse import DenseSynapse class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP diff --git a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py index 74312c6f..bb481512 100644 --- a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.denseSynapse import DenseSynapse class ExpSTDPSynapse(DenseSynapse): diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index ff3b796e..bd9f0024 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -4,10 +4,8 @@ from functools import partial from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn -from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses import DenseSynapse from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py index 59098ed8..1c7ac3ab 100755 --- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py +++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.denseSynapse import DenseSynapse diff --git a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py index bbd7dae3..150ebc9b 100755 --- a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py +++ b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py @@ -1,7 +1,6 @@ from jax import random, numpy as jnp, jit -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable -from ngclearn.utils.weight_distribution import initialize_params +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.hebbian import TraceSTDPSynapse diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index 64aabd92..1ef3c94f 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -6,8 +6,8 @@ from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment from ngclearn.components.synapses.patched import PatchedSynapse from ngclearn.utils import tensorstats diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index 3ea00475..7fae13f7 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -7,11 +7,8 @@ from ngclearn.utils.weight_distribution import initialize_params from ngcsimlib.logger import info -from ngcsimlib.compartment import Compartment -from ngcsimlib.parser import compilable - -import math - +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment def create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init): sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models) diff --git a/tests/components/neurons/graded/test_RateCell.py b/tests/components/neurons/graded/test_RateCell.py index 95260c95..ecd1ce9a 100644 --- a/tests/components/neurons/graded/test_RateCell.py +++ b/tests/components/neurons/graded/test_RateCell.py @@ -3,7 +3,7 @@ from jax import numpy as jnp, random, jit import numpy as np np.random.seed(42) -from ngclearn.components import RateCell +from ngclearn.components.neurons.graded.rateCell import RateCell from numpy.testing import assert_array_equal from ngclearn import Context, MethodProcess @@ -23,8 +23,8 @@ def test_RateCell1(): advance_process = (MethodProcess("advance_proc") >> a.advance_state) reset_process = (MethodProcess("reset_proc") >> a.reset) - def clamp(x): - a.j.set(x) + def clamp(x): + a.j.set(x) ## input spike train x_seq = jnp.ones((1, 10)) @@ -35,12 +35,13 @@ def clamp(x): reset_process.run() for ts in range(x_seq.shape[1]): x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t - ctx.clamp(x_t) + clamp(x_t) advance_process.run(t=ts * 1., dt=dt) - outs.append(a.z.value) + outs.append(a.z.get()) outs = jnp.concatenate(outs, axis=1) # print(outs) ## output should equal input # assert_array_equal(outs, y_seq, tol=1e-3) np.testing.assert_allclose(outs, y_seq, atol=1e-3) +#test_RateCell1() diff --git a/tests/components/synapses/hebbian/test_hebbianSynapse.py b/tests/components/synapses/hebbian/test_hebbianSynapse.py index 1b39ff5a..ba5dc463 100644 --- a/tests/components/synapses/hebbian/test_hebbianSynapse.py +++ b/tests/components/synapses/hebbian/test_hebbianSynapse.py @@ -4,7 +4,7 @@ import numpy as np np.random.seed(42) -from ngclearn.components import HebbianSynapse +from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse from numpy.testing import assert_array_equal from ngclearn import Context, MethodProcess @@ -57,12 +57,11 @@ def clamp_post(x): advance_process.run(t=1. * dt, dt=dt) evolve_process.run(t=1. * dt, dt=dt) - print(a.weights.get()) + #print(a.weights.get()) # Basic assertions to check learning dynamics assert a.weights.get().shape == (10, 5), "" assert a.weights.get()[0, 0] == 0.5, "" -test_hebbianSynapse() - +#test_hebbianSynapse() From 78e58da8145ef5ffa870128d6f67fdfd5592ecbc Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 20 Nov 2025 14:22:38 -0500 Subject: [PATCH 076/121] revised leaky-noise-cell, wrote its unit test, test-passed --- .../neurons/graded/leakyNoiseCell.py | 26 ++++++---- .../neurons/graded/test_leakyNoiseCell.py | 47 +++++++++++++++++++ 2 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 tests/components/neurons/graded/test_leakyNoiseCell.py diff --git a/ngclearn/components/neurons/graded/leakyNoiseCell.py b/ngclearn/components/neurons/graded/leakyNoiseCell.py index 9b6f2ebc..85c4cd03 100755 --- a/ngclearn/components/neurons/graded/leakyNoiseCell.py +++ b/ngclearn/components/neurons/graded/leakyNoiseCell.py @@ -6,10 +6,14 @@ from ngclearn.utils.model_utils import create_function from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4 -def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec): ## raw dynamics ODE - dz_dt = -z + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * (sigma_rec) ^ 2) * eps +def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale): ## raw dynamics ODE + dz_dt = -(z * leak_scale) + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * jnp.square(sigma_rec)) * eps return dz_dt * (1. / tau_x) +def _dfz(t, z, params): ## raw dynamics ODE wrapper + j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale = params + return _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale) + class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell """ A non-spiking cell driven by the gradient dynamics entailed by a continuous-time noisy, leaky recurrent state. @@ -55,13 +59,14 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell # Define Functions def __init__( self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1., - shape=None, **kwargs + leak_scale=1., shape=None, **kwargs ): super().__init__(name, **kwargs) self.tau_x = tau_x self.sigma_rec = sigma_rec ## a "resistance" scaling factor + self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1) ## integration properties self.integrationType = integration_type @@ -87,22 +92,23 @@ def __init__( self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output @compilable - def advance_state(self, t, dt): #dt, fx, tau_x, sigma_rec, intgFlag, key, j_input, j_recurrent, x): - key, skey = random.split(self.key.get(), 2) + def advance_state(self, t, dt): ### run a step of integration over neuronal dynamics - eps = random.normal(skey[0], shape=self.x.get().shape) ## sample of unit distributional noise - #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag) + key, skey = random.split(self.key.get(), 2) + eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise + #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag) _step_fns = { 0: step_euler, 1: step_rk2, 2: step_rk4, } - _step_fn = _step_fns.get(self.intgFlag, step_euler) - params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec) - _, x = _step_fn(0., self.x.get(), _dfz_fn, dt, params) ## update state activation dynamics + _step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler) + params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale) + _, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics r = self.fx(x) ## calculate rectified / post-activation function value(s) + ## set compartments to next state values in accordance with dynamics self.key.set(key) self.x.set(x) self.r.set(r) diff --git a/tests/components/neurons/graded/test_leakyNoiseCell.py b/tests/components/neurons/graded/test_leakyNoiseCell.py new file mode 100644 index 00000000..096c4f68 --- /dev/null +++ b/tests/components/neurons/graded/test_leakyNoiseCell.py @@ -0,0 +1,47 @@ +# %% + +from jax import numpy as jnp, random, jit +import numpy as np +np.random.seed(42) +from ngclearn.components.neurons.graded.leakyNoiseCell import LeakyNoiseCell +from numpy.testing import assert_array_equal + +from ngclearn import Context, MethodProcess + + +def test_LeakyNoiseCell1(): + name = "leaky_noise_ctx" + dkey = random.PRNGKey(42) + dkey, *subkeys = random.split(dkey, 100) + dt = 1. # ms + with Context(name) as ctx: + a = LeakyNoiseCell( + name="a", n_units=1, tau_x=50., act_fx="identity", integration_type="euler", batch_size=1, sigma_rec=0., + leak_scale=0. + ) + advance_process = (MethodProcess("advance_proc") >> a.advance_state) + reset_process = (MethodProcess("reset_proc") >> a.reset) + + def clamp(x): + a.j_input.set(x) + + ## input spike train + x_seq = jnp.ones((1, 10)) + ## desired output/epsp pulses + y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32) + + outs = [] + reset_process.run() + for ts in range(x_seq.shape[1]): + x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t + clamp(x_t) + advance_process.run(t=ts * 1., dt=dt) + outs.append(a.x.get()) + outs = jnp.concatenate(outs, axis=1) + # print(outs) + # print(y_seq) + ## output should approximately equal input + # assert_array_equal(outs, y_seq, tol=1e-3) + np.testing.assert_allclose(outs, y_seq, atol=1e-3) + +#test_LeakyNoiseCell1() From 88ce190a0f2b701614bef70a4043ce546389f013 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Thu, 20 Nov 2025 18:29:17 -0500 Subject: [PATCH 077/121] some revisions/updates to toc/pointer/general tutorial docs --- docs/installation.md | 57 ++++++++++------------------------- docs/ngclearn_papers.md | 27 +++++------------ docs/tutorials/foundations.md | 18 +++-------- docs/tutorials/index.rst | 6 +--- docs/tutorials/intro.md | 34 ++++----------------- docs/tutorials/theory.md | 47 +++-------------------------- 6 files changed, 40 insertions(+), 149 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 03bbe8a2..2c482c43 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,65 +1,41 @@ # Installation -**ngc-learn** officially supports Linux on Python 3. It can be run with or -without a GPU. +**ngc-learn** officially supports Linux on Python 3. It can be run with or without a GPU. -Setup: ngc-learn, -in its entirety (including its supporting utilities), -requires that you ensure that you have installed the following base dependencies in -your system. Note that this library was developed and tested on Ubuntu 22.04 (and earlier versions on 18.04/20.04). -Specifically, ngc-learn requires: +Setup: NGC-Learn, in its entirety (including its supporting utilities), requires that you ensure that you have installed the following base dependencies in your system. Note that this library was developed and tested on Ubuntu 22.04 (with much earlier versions on Ubuntu 18.04/20.04). +Specifically, NGC-Learn requires: * Python (>=3.10) -* ngcsimlib (>=1.0.0), (official page) +* ngcsimlib (>=2.0.0), (official page) * NumPy (>=1.22.0) * SciPy (>=1.7.0) * JAX (>= 0.4.28; and jaxlib>=0.4.28) * Matplotlib (>=3.8.0), (for `ngclearn.utils.viz`) * Scikit-learn (>=1.6.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`) -Note that the above requirements are taken care of if one installs ngc-learn -through either `pip`. One can either install the CPU version of ngc-learn (if no JAX is -pre-installed or only the CPU version of JAX is installed currently) via +Note that the above requirements are taken care of if one installs NGC-Learn through either `pip`. One can either install the CPU version of NGC-Learn (if no JAX is pre-installed or only the CPU version of JAX is currently installed) via: ```console $ pip install ngclearn ``` -or install the GPU version of ngc-learn by first installing the -CUDA 12 -version of JAX before running the above pip command. +or install the GPU version of NGC-Learn by first installing the CUDA 12 version of JAX before running the above pip command. -Alternatively, one may locally, step-by-step (see below), install and setup -ngc-learn from source after pulling from the repo. +Alternatively, one may locally, step-by-step (see below), install and setup NGC-Learn from source after pulling from the repo. -Note that installing the official pip package without any form of JAX installed -on your system will default to downloading the CPU version of ngc-learn (see -below for installing the GPU version). +Note that installing the official pip package without any form of JAX installed on your system will default to downloading the CPU version of NGC-Learn (see below for installing the GPU version). ## Install from Source -0. Install ngc-sim-lib first (as an editable install); visit the repo -https://github.com/NACLab/ngc-sim-lib for details. +1. Install NGC-Sim-Lib first (as an editable install); visit the repo https://github.com/NACLab/ngc-sim-lib for details. -1. Clone the ngc-learn repository: +2. Clone the NGC-Learn repository: ```console $ git clone https://github.com/NACLab/ngc-learn.git $ cd ngc-learn ``` -2. (Optional; only for GPU version) Install JAX for either CUDA 12 , depending - on your system setup. Follow the - installation instructions - on the official JAX page to properly install the CUDA 11 or 12 version. +3. (Optional; only for GPU version) Install JAX for either CUDA 12 , depending on your system setup. Follow the installation instructions on the official JAX page to properly install the CUDA 11 or 12 version. - - -3. Install the ngc-learn package via: +4. Install the NGC-Learn package via: ```console $ pip install . ``` @@ -68,22 +44,21 @@ or, to install as an editable install for development, run: $ pip install -e . ``` -If the installation was successful, you should see the following if you test -it against your Python interpreter, i.e., run the $ python command -and complete the following sequence of steps as depicted in the screenshot below:
- +If the installation was successful, you should see the following if you test it against your Python interpreter, i.e., run the $ python command and complete the following sequence of steps as depicted in the screenshot below:
```console Python 3.11.4 (main, MONTH DAY YEAR, TIME) [GCC XX.X.X] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import ngclearn >>> ngclearn.__version__ -'2.0.2' +'3.0.0' ``` + diff --git a/docs/ngclearn_papers.md b/docs/ngclearn_papers.md index 3fa9b1df..b89d005f 100644 --- a/docs/ngclearn_papers.md +++ b/docs/ngclearn_papers.md @@ -1,30 +1,19 @@ # List of Papers/Publications -The following is a list of current papers that use ngc-learn (this list will be -actively updated as we discover others that use ngc-learn): +The following is a list of current papers that use ngc-learn (this list will be actively updated as we discover others that use ngc-learn): -1. Ororbia, A., and Kifer, D. The neural coding framework for learning -generative models. Nature Communications 13, 2064 (2022). +1. Ororbia, A., and Kifer, D. The neural coding framework for learning generative models. Nature Communications 13, 2064 (2022). -2. Ororbia, A., and Mali, A. Backprop-free reinforcement learning with active -neural generative coding. Proceedings of the AAAI Conference on Artificial -intelligence (2022). +2. Ororbia, A., and Mali, A. Backprop-free reinforcement learning with active neural generative coding. Proceedings of the AAAI Conference on Artificial intelligence (2022). -3. Ororbia, A. "Spiking neural predictive coding for continual learning -from data streams." arXiv preprint arXiv:1908.08655 (2019). +3. Ororbia, A. "Spiking neural predictive coding for continual learning from data streams." Neurocomputing 544: 126292 (2022). -4. Ororbia, A, and Kelly, M. Alex. "CogNGen: constructing the kernel of -a hyperdimensional predictive processing cognitive architecture." -Proceedings of the Annual Meeting of the Cognitive Science Society (CogSci), Volume 44 (2022). +4. Ororbia, A, and Kelly, M. Alex. "CogNGen: constructing the kernel of a hyperdimensional predictive processing cognitive architecture." Proceedings of the Annual Meeting of the Cognitive Science Society (CogSci), Volume 44 (2022). -5. Ororbia, A., and Kelly, M. Alex. "Learning using a hyperdimensional predictive processing cognitive -architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022). +5. Ororbia, A., and Kelly, M. Alex. "Learning using a hyperdimensional predictive processing cognitive architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022). -6. Ororbia, A., Mali, A., Kifer, D., & Giles, C. L. "Lifelong neural predictive coding: Learning cumulatively online without -forgetting." Thirty-sixth Conference on Neural Information Processing Systems (NeurIPS) (2022). +6. Ororbia, A., Mali, A., Kifer, D., & Giles, C. L. "Lifelong neural predictive coding: Learning cumulatively online without forgetting." Thirty-sixth Conference on Neural Information Processing Systems (NeurIPS) (2022). 7. Ororbia, A., Friston, K., Rao, Rajesh P. N. "Meta-representational predictive coding: Biomimetic self-supervised learning." arXiv preprint arXiv:2503.21796 (2025). -Note: Please let us know if your work uses ngc-learn so we can update this page to accurately track -ngc-learn's use and include your work in the accumulating body of work in predictive processing -and/or brain-inspired computational modeling. +Note: Please let us know if your work uses ngc-learn so we can update this page to accurately track ngc-learn's use and include your work in the accumulating body of work in predictive processing and/or brain-inspired computational modeling. diff --git a/docs/tutorials/foundations.md b/docs/tutorials/foundations.md index 15887da3..822dc8f8 100644 --- a/docs/tutorials/foundations.md +++ b/docs/tutorials/foundations.md @@ -1,18 +1,8 @@ # Foundational Elements -In this set of tutorials/walkthroughs, we go through the some of the core elements -and mechanisms underlying ngc-learn in order understand how its simulation -scheme (and the nodes-and-cables system) works and to help in writing your -own custom elements. +In this set of tutorials/walkthroughs, we go through some of the core elements and mechanisms underlying NGC-Learn in order understand how its simulation scheme (and the nodes-and-cables system) works and to help in writing your own custom elements. The foundational walkthroughs are organized as follows: -1. [Using Model Contexts](../tutorials/foundations/contexts.md): This lesson goes - the fundamentals of the primary simulation construct you need to set up models, the - (simulation) context. -2. [Understanding Commands](../tutorials/foundations/commands.md): This lesson will - walk you through the basics of a command -- an essential part of building a - simulation controller in ngc-learn and ngcsimlib -- and offer some useful - points for designing new ones. -3. [Operations](../tutorials/foundations/operations.md): Here, the basics - of bundle rules, a commonly use mechanism for crafting complex biophysical - systems, will be presented. +1. [Using Model Contexts](../tutorials/foundations/contexts.md): This lesson goes the fundamentals of the primary simulation construct you need to set up models, the (simulation) context. +2. [Understanding Commands](../tutorials/foundations/commands.md): This lesson will walk you through the basics of a command -- an essential part of building a simulation controller in ngc-learn and ngcsimlib -- and offer some useful points for designing new ones. +3. [Operations](../tutorials/foundations/operations.md): Here, the basics of bundle rules, a commonly-used mechanism for crafting complex biophysical systems, will be presented. diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 0efdacbf..430924cf 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -5,11 +5,7 @@ Tutorial Contents ================= -Lessons/tutorials go through the very basics of constructing a dynamical system in -ngc-learn, core elements and tools of neurocognitive modeling using ngc-learn's -in-built components and simulation tools, and finally providing foundational insights -into how ngc-learn and its backend, ngc-sim-lib, work (particularly with respect -to configuration). +Lessons/tutorials go through the very basics of constructing a dynamical system in NGC-Learn, core elements and tools of neurocognitive modeling using NGC-Learn's in-built components and simulation tools, and finally providing foundational insights into how NGC-Learn and its backend, NGC-Sim-Lib, work (particularly with respect to configuration). .. toctree:: :maxdepth: 1 diff --git a/docs/tutorials/intro.md b/docs/tutorials/intro.md index 5651a453..b32cd184 100755 --- a/docs/tutorials/intro.md +++ b/docs/tutorials/intro.md @@ -1,37 +1,15 @@ # Introduction -ngc-learn is a general-purpose library for modeling biomimetic/neuro-mimetic -complex systems. While the library is designed to provide flexibility on the -experimenter/designer side -- allowing one to design their own dynamics and -evolutionary processes -- at its foundation are a few standard components, the -basic modeling nodes for simulating some common biophysical systems computationally, -that useful to know in getting started and quickly building some classical/historical -models. If you are interested in knowing some of the neurophysiological theory -behind ngc-learn's design philosophy, [this section](../tutorials/theory) might -be of interest. +NGC-Learn is a general-purpose library for modeling complex dynamical systems, particularly those that are useful for computational neuroscience, neuroscience-motivated artificial intelligence (NeuroAI), and brain-inspired computing. + +While the library is designed to provide flexibility on the experimenter/designer side -- allowing one to develop their own dynamics and evolutionary processes -- at its foundation are a few standard components. These are basic modeling nodes for simulating some common biophysical systems computationally, which are useful to know when getting started and for quickly building some classical/historical models. If you are interested in knowing some of the neurophysiological theory behind NGC-Learn's design philosophy, [this section](../tutorials/theory) might be of interest. -Specifically, to make best use of ngc-learn, it is important to get the -hang of its "nodes-and-cables system" (as it was historically referred to) in -order to build simulation objects. This set of tutorials will walk through, -step-by-step, the key aspects of the library you need to know so you can build -and run simulations of computational biophysical models. In addition, we -provide walkthroughs of some of the central mechanisms underlying -ngcsimlib, the simulation -dependency library that drives ngc-learn; these are particularly useful for not -only understanding why and how things are done by ngc-learn's simulation -backend but also for those who want to design new, custom extensions of ngc-learn -either for their own research or to contribute to the development of the main library. +Specifically, to make best use of NGC-Learn, it is important to get the hang of its "nodes-and-cables system" (the historical name for its backend engine) in order to build simulation objects. This set of tutorials will walk you through, step-by-step, the key aspects of the library that you will need to know so that you can build +and run simulations of computational biophysical models. In addition, we provide walkthroughs of some of the central mechanisms underlying NGC-Sim-Lib, the simulation dependency library that drives NGC-Learn; these lessons are particularly useful for not only understanding why and how things are done by NGC-Learn's simulation backend engine but also for those who want to design new, custom extensions of NGC-Learn either for their own research or to help contribute to the development of the main library. ## Organization of Tutorials -The core tutorials and lessons for using ngc-learn can be found [here, in the -tutorial table of contents](../tutorials/index.rst) and go through: the basic -configuration and use of ngc-learn and ngc-sim-lib to construct simulations -of dynamical systems, the essentials of neurocognitive modeling (such as -building and analyzing neuronal dynamics and synaptic plasticity), as well -as the coverage of some key foundational ideas/tools worth knowing about -ngc-learn (and its backend, ngc-sim-lib) particularly to facilitate easier -debugging, experimental configuration, and advanced model tools like `bundle rules`. +The core tutorials and lessons for using NGC-Learn can be found [here, in the tutorial table of contents](../tutorials/index.rst) which essentially go through: the basic configuration and use of NGC-Learn and NGC-Sim-Lib to construct simulations of dynamical systems, the essentials of neurocognitive modeling (such as building and analyzing models of neuronal dynamics and synaptic plasticity), as well as the coverage of some key foundational ideas/tools worth knowing about NGC-Learn (and its backend, NGC-Sim-Lib) particularly to facilitate easier debugging, experimental configuration, and advanced modeling tools. +In NGC-Learn, it is possible to construct other forms of learning from the very base learning/plasticity components already in-built into the base library. Notably, a class of learning and inference systems that adapt through a process known as contrastive Hebbian learning (CHL) can be constructed and simulated with ngc-learn. + +In this walkthrough, we will design a simple Harmonium, also known as the restricted Boltzmann machine (RBM). We will specifically focus on learning its synaptic connections with an algorithmic recipe known +as contrastive divergence (CD), which can be considered to be a stochastic form of CHL. After going through this exhibit, you will: + +1. Learn how to construct an `NGCGraph` that emulates the structure of an RBM and adapt the NGC settling process to calculate approximate synaptic weight gradients in accordance to contrastive divergence. +2. Simulate fantasized image samples using the block Gibbs sampler implicitly defined by the negative phase graph. + +Note that the folders of interest to this walkthrough are: ++ `ngc-museum/exhibits/harmonium/`: this contains the necessary simulation scripts (which can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/harmonium)); ++ `ngc-museum/data/mnist/`: this contains the zipped copy of the MNIST digit image arrays + +## On the Harmonium Probabilistic Graphical Model + +A harmonium is a generative model implemented as a stochastic, two-layer neural system (a type of probabilistic graphic model; PGM) that attempts to learn a probability distribution over sensory input $\mathbf{x}$, i.e., the goal of a harmonium is to learn $p(\mathbf{x})$, the underlying probability/likelihood of a given (training) dataset. Fundamentally, the approach to estimating $p(\mathbf{x})$ that carried out by a harmonium is through the optimization of an energy function $E(\mathbf{x})$ (a concept motivated by statistical mechanics), where the system searches for an internal configuration, i.e., the values of its synapses, that assigns low energy (values) to sample patterns that come from the true data distribution $p(\mathbf{x})$ and high energy (values) to patterns that do not (or those that do not come from the training dataset). + +```{eval-rst} +.. table:: + :align: center + + +-----------------------------------------------------------------+ + | .. image:: ../images/museum/harmonium/rbm_arch.jpg | + | :scale: 65% | + | :align: center | + +-----------------------------------------------------------------+ +``` + +The most common, simplest harmonium is one where input nodes (one per dimension of the data observation space) are modeled as binary/Boolean sensors -- or "visible units" $\mathbf{z}^0$ (observed variables) that are clamped to actual data patterns -- connected to a layer of (stochastic) binary latent feature detectors -- or "hidden units" $\mathbf{z}^1$ (unobserved or latent variables). Notably, the synaptic connections between the latent and visible units are symmetric. Furthermore, as a result of a key restriction imposed on the harmonium's network structure, i.e., no lateral connections between the neurons/units within $\mathbf{z}^0$ as well as those within $\mathbf{z}^1$, computing the latent and visible states is as straightforward as the following: + +$$ +p(\mathbf{z}^1 | \mathbf{z}^0) &= sigmoid(\mathbf{W} \cdot \mathbf{z}^0 + \mathbf{b}), +\; \mathbf{z}^1 \sim p(\mathbf{z}^1 | \mathbf{z}^0) \\ +p(\mathbf{z}^0 | \mathbf{z}^1) &= sigmoid(\mathbf{W}^T \cdot \mathbf{z}^1 + \mathbf{c}), +\; \mathbf{z}^0 \sim p(\mathbf{z}^0 | \mathbf{z}^1) +$$ + +where $\mathbf{c}$ is the visible bias vector, $\mathbf{b}$ is the latent bias vector, +and $\mathbf{W}$ is the synaptic weight matrix that connects $\mathbf{z}^0$ to $\mathbf{z}^1$ (and its transpose $\mathbf{W}^T$ is used to make predictions of the input itself). Note that $\cdot$ means matrix/vector multiplication and $\sim$ denotes that we would sample from a probability (vector). In the above harmonium's case, samples will be drawn treating the conditionals such as $p(\mathbf{z}^1 | \mathbf{z}^0)$ as multivariate Bernoulli distributions. +$\mathbf{z}^0$ would typically be clamped/set to the actual sensory input data $\mathbf{x}$. + +The energy function of the harmonium's joint configuration $(\mathbf{z}^0,\mathbf{z}^1)$ (similar to that of a Hopfield network) is specified as follows: + +$$ +E(\mathbf{z}^0,\mathbf{z}^1) = -\sum_i \mathbf{c}_i \mathbf{z}^0_i - +\sum_j \mathbf{b}_j \mathbf{z}^1_j - \sum_i \sum_j \mathbf{z}^0_i \mathbf{W}_{ij} \mathbf{z}^1_j . +$$ + +Notice that, in the equation above, we sum over vector dimension indices, e.g., $\mathbf{z}^0_i$ retrieves the $i$th scalar element of (vector) $\mathbf{z}^0$ while $\mathbf{W}_{ij}$ retrieves the scalar element at position $(i,j)$ within matrix $\mathbf{W}$. With this energy function, one can write out the probability that a harmonium PGM assigns to a data point as: + +$$ +p(\mathbf{z}^0 = \mathbf{x}) = \frac{1}{Z} \exp( -E(\mathbf{z}^0,\mathbf{z}^1) ) +$$ + +where $Z$ is the normalizing constant (or, in statistical mechanics, the partition function) needed to obtain proper probability values[^1]. +When one works through the derivation of the gradient of the log probability $\log p(\mathbf{x})$ with respect to the synapses such as $\mathbf{W}$, they get a (contrastive) Hebbian-like update rule as follows: + +$$ +\Delta \mathbf{W} = <\mathbf{z}^0_i \mathbf{z}^1_j>_{data} - <\mathbf{z}^0_i \mathbf{z}^1_j>_{model} +$$ + +where the angle brackets $< >$ tell us that we need to take the expectation of the values within the brackets under a certain distribution (such as the data distribution denoted by the subscript $data$). The above rule can also be considered to be a stochastic form of a general recipe known as contrastive Hebbian learning (CHL) [4]. + +Technically, to compute the update above, obtaining the first term +$<\mathbf{z}^0_i \mathbf{z}^1_j>_{data}$ is easy since we only need to take the product of a data point and its corresponding latent state under the harmonium. However, obtaining the second term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ is very costly, since we would need to +initialize the value of $\mathbf{z}^0$ to a random initial state and then run a (block) Gibbs sampler for many iterations to accurately approximate the second term. Fortunately, it was shown in work such as [3], that learning a harmonium is still possible by replacing the term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ with $<\mathbf{z}^0_i \mathbf{z}^1_j>_{recon}$, which is simply computed by using the +first term's latent state $\mathbf{z}^1$ to reconstruct the input and then using this reconstruction once more in order to obtain its corresponding binary latent state. This is known as "contrastive divergence" (CD-1), and, although this approximation has been shown to not actual follow the gradient of any known objective function, it works well in practice when learning a harmonium-based generative model. Finally, the vectorized form of the CD-1 update is: + +$$ +\Delta \mathbf{W} = \Big[ (\mathbf{z}^0_{pos})^T \cdot \mathbf{z}^1_{pos} \Big] - \Big[ (\mathbf{z}^0_{neg})^T \cdot \mathbf{z}^1_{neg} \Big] +$$ + +where the first term (in brackets) is labeled as the "positive phase" (or the positive, data-dependent statistics -- where $\mathbf{z}^0_{pos}$ denotes the positive phase sample of $\mathbf{z}^0$) while the second term is labeled as the "negative phase" (or the negative, data-independent statistics -- where $\mathbf{z}^0_{neg}$ denotes the negative phase sample of $\mathbf{z}^0$). Note that simpler rules of a similar form can be worked out for the latent/visible bias vectors as well. + +In NGC-Learn, to simulate the above harmonium PGM and its CD-1 update, we will model the positive and negative phases as simulated co-models, each responsible for producing the relevant statistics that we will require in order to adjust synapses. Additionally, we will find that we can further re-purpose the created co-models to construct a block Gibbs sampler for confabulating "fantasized" +data patterns from a harmonium that has been fit to data. + + +## Boltzmann Machines: Positive and Negative Co-Models + +We begin by first specifying the structure of the harmonium system that we would like to simulate. In NGC shorthand, the above positive and negative phase graphs would simply be (under one complete generative model): + +``` +z0 -(z0-z1)-> z1 +z1 -(z1-z0) -> z0 +Note: z1-z0 = (z0-z1)^T (transpose-tied synapses) +``` + +In order to construct the desired harmonium, particularly the structure needed to implement CD-1, we will need to break up the model into its key "phases", i.e., a positive phase and a negative phase. We will model each phase as its own simulated nodes-and-cables structure within one single model context, allowing us to craft a general approach that permits a CD-based learning. Notably, we will use the negative-phase co-model to emulate the crucial MCMC sampling step to synthesize data from the trained RBM. + +Building the positive phase of our harmonium can be done as follows: + +```python +with Context("Circuit") as self.circuit: + ## set up positive-phase graph + self.z0 = BernoulliStochasticCell("z0", n_units=obs_dim, is_stoch=False) + self.z1 = BernoulliStochasticCell("z1", n_units=hid_dim, key=subkeys[0]) + + self.W1 = HebbianSynapse( + "W1", shape=(obs_dim, hid_dim), eta=0., weight_init=dist.gaussian(mean=0., std=sigma), + bias_init=dist.constant(value=0.), w_bound=0., optim_type="sgd", sign_value=1., key=subkeys[1] + ) + ## wire up z0 to z1 via synaptic project W1 + self.z0.s >> self.W1.inputs + self.W1.outputs >> self.z1.inputs +``` + +To gather the rest of the statistics that we require, we will need to build the negative phase of our model (which is responsible for "dreaming up" or "confabulating" data samples from its internal model of the world). Constructing the negative-phase co-model, under the same model `Context` above can be done as follows: + +```python + ## set up negative-phase graph + self.z0neg = BernoulliStochasticCell("z0neg", n_units=obs_dim, key=subkeys[3]) + self.z1neg = BernoulliStochasticCell("z1neg", n_units=hid_dim, key=subkeys[4]) + + self.E1 = DenseSynapse( ## E1 = W1.T + "E1", shape=(hid_dim, obs_dim), weight_init=dist.gaussian(mean=0., std=sigma), + bias_init=dist.constant(value=0.), resist_scale=1., key=subkeys[2] + ) + self.E1.weights.set(self.W1.weights.get().T) + self.V1 = HebbianSynapse( ## V1 = W1 + "V1", shape=(obs_dim, hid_dim), eta=0., weight_init=dist.gaussian(mean=0., std=sigma), + bias_init=None, w_bound=0., optim_type="sgd", sign_value=1., key=subkeys[1] + ) + self.V1.weights.set(self.W1.weights.get()) + self.V1.biases.set(self.W1.biases.get()) + + ## wire up z1 to z0(neg) via E1=(W1)^T, and z0(neg) to z1(neg) via V1=W1 + self.z1.s >> self.E1.inputs + self.E1.outputs >> self.z0neg.inputs + self.z0neg.p >> self.V1.inputs ## drive hiddens by probs of visibles + self.V1.outputs >> self.z1neg.inputs +``` + +The above chunk of code effectively sets up the propagation of information from the latent neurons `z1` back down to `z0` (obtaining the negative phase values of `z0`, i.e., `z0neg`) and then the propagation of the reconstructed values back up to `z1` one last time (obtaining the negative phase values of `z1`, i.e., `z0neg`). + +To build a CHL-based form of plasticity, allowing us to build the CD-1 learning process, we will then need to wire up a set of 2-factor Hebbian rules like so: + +```python + ## set up contrastive Hebbian learning rule (pos-stats - neg-stats) + self.z0.s >> self.W1.pre ## positive-phase pre-synaptic term + self.z1.p >> self.W1.post ## positive-phase post-synaptic term + self.z0neg.p >> self.V1.pre ## negative-phase pre-synaptic term + self.z1neg.p >> self.V1.post ## negative-phase pre-synaptic term +``` + +the results of these two Hebbian rules are then used in an exhibit-specific function (`_update_via_CHL()`) written in the [`Harmonium` class](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py). +While we observe that our "negative phase" co-model allows us to emulate the CD learning recipe[^2], technically, the negative phase of a harmonium should be run for a very high value of steps (approaching infinity) in order to obtain a proper sample from the PGM's equilibrium/steady state distribution. However, this would be extremely costly to simulate and, as early studies [3] observed, often only a few or even a single step of this Markov chain proved to work quite well, approximating the contrastive divergence objective (the learning algorithm's namesake) instead of direct maximum likelihood. + +Note that the full code, containing the snippets above, can be found in the Model Museum `Harmonium` model structure class. One could further generalize our CD-1 framework to variations, such as "persistent" CD (where we, instead of running `z1` back down through `E1` synapses, we inject random noise instead (to sample the harmonium's latent prior), or even an algorithm known as parallel tempering, where we would maintain multiple co-models and draw samples from all of them to obtain negative-phase statistics. + +Finally, within the `Harmonium` class, we have written a routine for drawing samples from the model directly, i.e., we implement a block Gibbs sampler in order synthesize data from the RBM's current set of parameters. + +## Using the Harmonium to Dream Up Handwritten Digits + +We finally take the harmonium that we have constructed above and fit it to some MNIST digits. Specifically, we will leverage the [Harmonium](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py), model in the Model Museum since it implements all of the above core mechanisms (and more) internally. In the script `sim_harmonium.py`, you will find a general training that will fit our harmonium to the MNIST database (unzip the file `mnist.zip` in the `ngc-museum/exhibits/data/` directory if you have not already) by cycling through it several times, saving the final +(best) resulting to disk within the `exp/` sub-directory. Go ahead and execute the training process as follows: + +```console +$ python sim_harmonium.py +``` + +which will fit/adapt your harmonium to MNIST. This should produce per-training iteration output, printed to I/O, similar to the following: + +```console +--- Initial RBM Synaptic Stats --- +W1: min -0.0494 ; max 0.0445 mu -0.0000 ; norm 4.4734 +b1: min -4.0000 ; max -4.0000 mu -4.0000 ; norm 64.0000 +c0: min -11.6114 ; max 0.0635 mu -3.8398 ; norm 135.2238 +-1| Test: E(X) = 99.8526 err(X) = 54.3889 +0| Test: E(X) = 116.6596 err(X) = 46.8236; Train: E(X) = 112.0452 err(X) = 52.7418 +1| Test: E(X) = 89.5413 err(X) = 36.8690; Train: E(X) = 102.4642 err(X) = 41.3630 +2| Test: E(X) = 75.7558 err(X) = 31.8582; Train: E(X) = 82.9692 err(X) = 34.5511 +3| Test: E(X) = 66.6632 err(X) = 28.6253; Train: E(X) = 72.1229 err(X) = 30.4615 +4| Test: E(X) = 60.8256 err(X) = 26.2317; Train: E(X) = 64.3613 err(X) = 27.6882 +5| Test: E(X) = 55.5070 err(X) = 24.3207; Train: E(X) = 58.9254 err(X) = 25.5485 +6| Test: E(X) = 51.7455 err(X) = 22.8012; Train: E(X) = 54.4092 err(X) = 23.8361 +7| Test: E(X) = 49.4866 err(X) = 21.6163; Train: E(X) = 51.1574 err(X) = 22.4523 +8| Test: E(X) = 46.2826 err(X) = 20.5934; Train: E(X) = 48.2617 err(X) = 21.3355 +9| Test: E(X) = 43.8611 err(X) = 19.7679; Train: E(X) = 46.0239 err(X) = 20.4297 +10| Test: E(X) = 42.2886 err(X) = 19.0672; Train: E(X) = 44.3544 err(X) = 19.6835 +11| Test: E(X) = 41.7468 err(X) = 18.4881; Train: E(X) = 42.9321 err(X) = 19.0372 +... + +... +91| Test: E(X) = 65.5179 err(X) = 11.0443; Train: E(X) = 65.0850 err(X) = 10.9832 +92| Test: E(X) = 65.4790 err(X) = 11.0118; Train: E(X) = 64.8345 err(X) = 10.9820 +93| Test: E(X) = 65.9917 err(X) = 11.0013; Train: E(X) = 64.4392 err(X) = 10.9586 +94| Test: E(X) = 64.0737 err(X) = 10.9874; Train: E(X) = 64.2096 err(X) = 10.9312 +95| Test: E(X) = 64.0479 err(X) = 10.9906; Train: E(X) = 63.8461 err(X) = 10.9274 +96| Test: E(X) = 63.5719 err(X) = 10.9712; Train: E(X) = 63.3354 err(X) = 10.8940 +97| Test: E(X) = 64.1757 err(X) = 10.9589; Train: E(X) = 62.8447 err(X) = 10.8960 +98| Test: E(X) = 63.8886 err(X) = 10.9563; Train: E(X) = 62.6391 err(X) = 10.8727 +99| Test: E(X) = 62.2265 err(X) = 10.9347; Train: E(X) = 62.3147 err(X) = 10.8671 +--- Final RBM Synaptic Stats --- +W1: min -1.8648 ; max 1.3757 mu -0.0012 ; norm 70.6230 +b1: min -7.5815 ; max 0.2337 mu -2.3395 ; norm 53.3993 +c0: min -11.6316 ; max -2.4227 mu -5.3259 ; norm 161.5646 +``` + +You will find, after the training script has finished executing, several outputs in the `exp/filters/` model sub-directory that is created for you. Concretely, you will find a grid-plot of the (first `100` of the) harmonium's acquired filters (or "receptive fields"), much as we did for the sparse coding exhibit, that will look similar to the following: + + + +Interestingly enough, we see that our harmonium/RBM has extracted what appears to be rough stroke features, which is what it uses when sampling its binary latent feature detectors to compose final synthesized image patterns (each binary feature detector serves as Boolean function that emits a decision of `1` if the filter is to be used and a `0` if not). In particular, we remark notice that the filters that our harmonium has acquired are a bit more prominent due to the fact our exhibit employs some weight decay (specifically, Gaussian/L2 decay -- with intensity `l2_lambda=0.01` -- to the `W1` synaptic matrix of our RBM). +Weight decay of this form is particularly useful to not only mitigate against the harmonium overfitting to its training data but also to ensure that the Markov chain inherent to its negative-phase mixes more effectively [5] (which ensures better-quality samples from the block Gibbs sampler, which we will use next). + +Finally, you will also find in the `exp/filters/` model sub-folder another grid-plot containing some (about `100`) of the RBM's reconstructions of held-out development data. This plot should look similar to the one below: + + + +### Sampling the Harmonium + +Once the training process has completed, you can then run the following to sample from trained model using block Gibbs sampling: + +```console +$ python sample_harmonium.py +``` + +which will take your trained harmonium's negative-phase co-model and use it to synthesize some digit patterns. You should see inside the `exp/samples/` sub-directory three sample-image grids (i.e., `samples_0.jpg`, `samples_1.jpg`, and `samples_2.jpg`) similar to what is shown below: + +```{eval-rst} +.. image:: ../images/museum/harmonium/samples_0.jpg + :width: 30% +.. image:: ../images/museum/harmonium/samples_1.jpg + :width: 30% +.. image:: ../images/museum/harmonium/samples_2.jpg + :width: 30% +``` + +Furthermore, you will see three corresponding GIFs that have been generated for you that visualize how each of the three simulated sampling Markov chains change with time (i.e., these are the files: `markov_chain_0.gif`, `markov_chain_1.gif`, and `markov_chain_2.gif`). + + + +It is important to understand that the three grids of samples shown above come from particular points in the block Gibbs sampling process. +(Note that one reads these sample grid plots left-column to right-column, and top-row to bottom-row; this way of reading the plot follows the ordering of samples extracted from the specific Markov chain sequence.) +Note that, although each chain is run for many total steps, the `sample_harmonium.py` script "thins" out each Markov chain by only pulling out a fantasized pattern every `20` steps (further "burning" in each chain before collecting samples). Each chain is merely initialized with random Bernoulli noise. Note that higher-quality samples can be obtained if one modifies the earlier harmonium to learn with persistent CD or parallel tempering. + +### Final Notes + +The harmonium that we have built in this exhibit is a classical Bernoulli harmonium/RBM, which is a neural PGM that assumes that the input data features are binary in nature. If one wants to model data that is continuous/real-valued, then the harmonium model above would need to be extended to utilize visible units that follow a continuous distribution; for instance, if one modeled a multivariate Gaussian distribution, this would yield a Gaussian restricted Boltzmann machine (GRBM). + + +## References +[1] Smolensky, P. "Information Processing in Dynamical Systems: Foundations of Harmony Theory" (Chapter 6). Parallel distributed processing: explorations in the microstructure of cognition 1 (1986).
+[2] Geoffrey Hinton. Products of Experts. International conference on artificial neural networks (1999).
+[3] Hinton, Geoffrey E. "Training products of experts by maximizing contrastive likelihood." Technical Report, Gatsby computational neuroscience unit (1999).
+[4] Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models. Morgan Kaufmann, 1991. 10-17.
+[5] Hinton, Geoffrey E. "A practical guide to training restricted Boltzmann machines." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 599-619. + + +[^1]: In fact, it is intractable to compute the partition function $Z$ for any reasonably-sized harmonium; fortunately, we will not need to calculate $Z$ in order to learn and sample from a Harmonium. +[^2]: In general, CD-1 means contrastive divergence where the negative phase is only run for one single step, i.e., `K=1`. The more general form of CD is known as CD-K, the K-step CD algorithm where `K > 1`. (Sometimes, CD-1 is just referred to as just "CD".) diff --git a/docs/museum/index.rst b/docs/museum/index.rst index 25f75dfc..ff62017f 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -17,5 +17,6 @@ relevant, referenced publicly available ngc-learn simulation code. sparse_coding snn_dc snn_bfa + harmonium sindy rl_snn diff --git a/docs/museum/snn_dc.md b/docs/museum/snn_dc.md index c42c1a45..b7a5af9e 100755 --- a/docs/museum/snn_dc.md +++ b/docs/museum/snn_dc.md @@ -313,24 +313,8 @@ neuroscience 9 (2015): 99. [^1]: Note that the `LIFCell` is not the same as ngc-learn's -[sLIFCell](ngclearn.components.neurons.spiking.sLIFCell), which is a particular -cell that simplifies the spiking dynamics greatly and is not meant to operate -in the negative milliVolt range like the `LIFCell` does. -[^2]: While both forms of modeling electrical current are easily doable in - ngc-learn, the `DC_SNN` exhibit model opts for the second approach for simplicity - and additional simulation speed. -[^3]: Trace components have also been used in the `DC_SNN` exhibit model, specifically -those built with the [variable trace](ngclearn.components.other.varTrace) component. -Note that the variable trace effectively applies a low-pass filter iteratively -to the spikes produced by a spike train. -[^4]: In the NAC group's -experience, observing the mean and Frobenius norm of synaptic values can be a -useful starting point for determining unhealthy behavior or some degenerate cases -in the context of spiking neural network credit assignment. -[^5]: To load in the exact synaptic efficacies we obtained to get the images -above, you can unzip the folder `dcsnn_syn.zip`, which contains all of the -model's numpy array values, and simply copy all of the compressed numpy arrays -into your `exp/snn_stdp/custom/` folder, which is where ngc-learn/ngc-sim-lib -look for pre-trained value arrays when loading in a previously constructed model. -Once you do this, running `analyze_dcsnn.py` with the same arguments as above -should produce plots/images much like those in this walkthrough. +[sLIFCell](ngclearn.components.neurons.spiking.sLIFCell), which is a particular cell that simplifies the spiking dynamics greatly and is not meant to operate in the negative milliVolt range like the `LIFCell` does. +[^2]: While both forms of modeling electrical current are easily doable in NGC-Learn, the `DC_SNN` exhibit model opts for the second approach for simplicity and additional simulation speed. +[^3]: Trace components have also been used in the `DC_SNN` exhibit model, specifically those built with the [variable trace](ngclearn.components.other.varTrace) component. Note that the variable trace effectively applies a low-pass filter iteratively to the spikes produced by a spike train. +[^4]: In the NAC group's experience, observing the mean and Frobenius norm of synaptic values can be a useful starting point for determining unhealthy behavior or some degenerate cases in the context of spiking neural network credit assignment. +[^5]: To load in the exact synaptic efficacies we obtained to get the images above, you can unzip the folder `dcsnn_syn.zip`, which contains all of the model's numpy array values, and simply copy all of the compressed numpy arrays into your `exp/snn_stdp/custom/` folder, which is where ngc-learn/ngc-sim-lib look for pre-trained value arrays when loading in a previously constructed model. Once you do this, running `analyze_dcsnn.py` with the same arguments as above should produce plots/images much like those in this walkthrough. From 7026c8c69ee7c1627e1129d8eb6ad8df5a8c2d5e Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Wed, 3 Dec 2025 08:12:53 -0500 Subject: [PATCH 104/121] Update __init__.py Added the config/logging back to the init --- ngclearn/__init__.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index fa1d4030..64e7d7d7 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -32,24 +32,9 @@ from ngcsimlib.context import Context, ContextObjectTypes from ngcsimlib import Component from ngcsimlib.compartment import Compartment - -from ngcsimlib import logger +from ngcsimlib import logger, configure if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload - # configure() - # logger.init_logging() - # from ngcsimlib.configManager import get_config - # pkg_config = get_config("packages") - # if pkg_config is not None: - # use_base_numpy = pkg_config.get("use_base_numpy", False) - # if use_base_numpy: - # import numpy as numpy - # else: - # from jax import numpy - # else: - # from jax import numpy - # - # - # preload_modules() - a = 2 + configure() + logger.init_logging() From 81fbf4d36c5ea4cfeae234d6dc1659d5ff14fef5 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 11:24:30 -0500 Subject: [PATCH 105/121] placed pointer to rao-ballard1999 exhibit; updates to docs --- docs/museum/index.rst | 16 ++++++++++------ docs/museum/model_museum.md | 17 +++-------------- docs/museum/pc_rao_ballard1999.md | 12 ++++++++++++ 3 files changed, 25 insertions(+), 20 deletions(-) create mode 100644 docs/museum/pc_rao_ballard1999.md diff --git a/docs/museum/index.rst b/docs/museum/index.rst index ff62017f..8b0ce9d3 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -5,18 +5,22 @@ Model Exhibits ============== -Models are presented in ngc-learn's model museum in the form of "exhibits", -which are effectively model-specific walkthroughs and analyses, based on the -relevant, referenced publicly available ngc-learn simulation code. +Models are presented in ngc-learn's model museum in the form of "exhibits", which are effectively model-specific walkthroughs and analyses, based on the relevant, referenced publicly available ngc-learn simulation code. .. toctree:: :maxdepth: 1 - :caption: Neuromimetic Models + :caption: Neuroscience Models - pcn_discrim sparse_coding + pc_rao_ballard1999 snn_dc + rl_snn + +.. toctree:: + :maxdepth: 1 + :caption: NeuroAI / Neuro-mimetic Models + + pcn_discrim snn_bfa harmonium sindy - rl_snn diff --git a/docs/museum/model_museum.md b/docs/museum/model_museum.md index bd154524..c03cb2d8 100644 --- a/docs/museum/model_museum.md +++ b/docs/museum/model_museum.md @@ -1,20 +1,9 @@ # The Model Museum -There is an ever-growing galaxy of neurobiological models and credit assignment -processes [1, 2]. One of ngc-learn's aims, in the spirit of scientific -reproducibility, is to capture a snapshot of as many of these -biomimetic/neuro-mimetic models as possible, in the form of a digital "museum". -This museum is further designed with the notion of exhibits and exhibitors, -aiding to facilitate credit assignment and respectful citation to the ideas and -the work of those that have helped to lay the foundations for the progress -observed today. Recently, we have separated out the model museum into its own -particular maintained repository called -[ngc-museum](https://github.com/NACLab/ngc-museum), where you can find and -access/run historical models and agents built with ngc-learn to perform -different experimental tasks. +There is an ever-growing galaxy of neurobiological models and credit assignment processes [1, 2]. One of ngc-learn's aims, in the spirit of scientific reproducibility, is to capture a snapshot of as many of these biomimetic/neuro-mimetic models as possible, in the form of a digital "museum". +This museum is further designed with the notion of exhibits and exhibitors, aiding to facilitate credit assignment and respectful citation to the ideas and the work of those that have helped to lay the foundations for the progress observed today. Recently, we have separated out the model museum into its own particular maintained repository called [ngc-museum](https://github.com/NACLab/ngc-museum), where you can find and access/run historical models and agents built with ngc-learn to perform different experimental tasks. -Please refer to the [table of contents](../museum/index.rst) for walkthroughs on -using and running various historical models in the museum. +Please refer to the [table of contents](../museum/index.rst) for walkthroughs and guidance on using and running various historical model exhibits in the museum. ## References [1] Ororbia, Alexander G. "Brain-inspired machine intelligence: A survey diff --git a/docs/museum/pc_rao_ballard1999.md b/docs/museum/pc_rao_ballard1999.md new file mode 100644 index 00000000..c6de5a76 --- /dev/null +++ b/docs/museum/pc_rao_ballard1999.md @@ -0,0 +1,12 @@ +# Hierarchical Predictive Coding (Rao & Ballard) + +In this exhibit, we create, simulate, and visualize the +internally acquired receptive fields of the predictive coding model originally proposed in (Rao & Ballard, 1999) [1]. + +The model code for this +exhibit can be found +[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon). + + +## References +[1] Rao, Rajesh PN, and Dana H. Ballard. "Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87. \ No newline at end of file From 09d0375a5ff2824f39848cd967e02b16aa950d5f Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 13:35:54 -0500 Subject: [PATCH 106/121] updates to docs/revisions --- docs/index.rst | 3 +- docs/museum/sparse_coding.md | 266 +++++++++++++---------------------- docs/ngclearn_papers.md | 6 +- docs/ngclearn_talks.md | 13 ++ 4 files changed, 114 insertions(+), 174 deletions(-) create mode 100644 docs/ngclearn_talks.md diff --git a/docs/index.rst b/docs/index.rst index 55af9aaa..7ab63e25 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -49,9 +49,10 @@ Welcome to ngc-learn's documentation! .. toctree:: :maxdepth: 1 - :caption: Papers that use NGC-Learn + :caption: NGC-Learn Papers & Media ngclearn_papers + ngclearn_talks Indices and tables ================== diff --git a/docs/museum/sparse_coding.md b/docs/museum/sparse_coding.md index d36dbf50..929fd4c0 100755 --- a/docs/museum/sparse_coding.md +++ b/docs/museum/sparse_coding.md @@ -1,86 +1,59 @@ # Sparse Coding and Iterative Thresholding -In this exhibit, we create, simulate, and visualize the -internally acquired filters/atoms of variants of a sparse coding system based -on the classical model proposed by (Olshausen & Field, 1996) [1]. +In this exhibit, we create, simulate, and visualize the internally acquired filters/atoms of variants of a sparse coding system based on the classical model proposed by (Olshausen & Field, 1996) [1]. After going through this demonstration, you will: -1. Learn how to build a 2-layer sparse coding model of natural image patterns, -using the original dataset used in [1]. -2. Visualize the acquired filters of the learned dictionary models and examine -the results of imposing a kurtotic prior as well as a thresholding function -over latent codes. +1. Learn how to build a 2-layer sparse coding model of natural image patterns, using the original dataset used in [1]. +2. Visualize the acquired filters of the learned dictionary models and examine the results of imposing a kurtotic prior as well as a thresholding function over latent codes. -The model code for this -exhibit can be found -[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/olshausen_sc). +The model code for this exhibit can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/olshausen_sc). -Note: You will need to unzip the data arrays in `exhibits/data/natural_scenes.zip` -to the folder `exhibits/data/` to work through this exhibit. +Note: You will need to unzip the data arrays in `exhibits/data/natural_scenes.zip` to the folder `exhibits/data/` to work through this exhibit. ## On Dictionary Learning -Dictionary learning poses a very interesting question for statistical learning: -can we extract "feature detectors" from a given database (or collection of patterns) -such that only a few of these detectors play a role in reconstructing any given, -original pattern/data point? -The aim of dictionary learning is to acquire or learn a matrix, also called the -"dictionary", which is meant to contain "atoms" or basic elements inside this dictionary -(such as simple fundamental features such as the basic strokes/curves/edges -that compose handwritten digits or characters). Several atoms (or rows of this -matrix) inside the dictionary can then be linearly combined to reconstruct a -given input signal or pattern. A sparse dictionary model is able to reconstruct -input patterns with as few of these atoms as possible. Typical sparse dictionary -or coding models work with an over-complete spanning set, or, in other words, -a latent dimensionality (which one could think of as the number of neurons -in a single latent state node of an ngc-learn system) that is greater than the -dimensionality of the input itself. - -From a neurobiological standpoint, sparse coding emulates a fundamental property -of neural populations -- the activities among a neural population are sparse where, -within a period of time, the number of total active neurons (those that are firing) -is smaller than the total number of neurons in the population itself. When sensory -inputs are encoded within this population, different subsets (which might overlap) of -neurons activate to represent different inputs (one way to view this is that they -"fight" or compete for the right to activate in response to different stimuli). -Classically, it was shown in [1] that a sparse coding model trained on natural -image patches learned within its dictionary non-orthogonal filters that resembled -receptive fields of simple-cells (found in the visual cortex). +Dictionary learning poses a very interesting question for statistical learning: can we extract "feature detectors" from a given database (or collection of patterns) such that only a few of these detectors play a role in reconstructing any given, original pattern/data point? +The aim of dictionary learning is to acquire or learn a matrix, also called the "dictionary", which is meant to contain "atoms" or basic elements inside this dictionary (such as simple fundamental features such as the basic strokes/curves/edges that compose handwritten digits or characters). Several atoms (or rows of this matrix) inside the dictionary can then be linearly combined to reconstruct a given input signal or pattern. A sparse dictionary model is able to reconstruct input patterns with as few of these atoms as possible. Typical sparse dictionary or coding models work with an over-complete spanning set, or, in other words, a latent dimensionality (which one could think of as the number of neurons in a single latent state node of an ngc-learn system) that is greater than the dimensionality of the input itself. + +From a neurobiological standpoint, sparse coding emulates a fundamental property of neural populations -- the activities among a neural population are sparse where, within a period of time, the number of total active neurons (those that are firing) is smaller than the total number of neurons in the population itself. When sensory inputs are encoded within this population, different subsets (which might overlap) of neurons activate to represent different inputs (one way to view this is that they "fight" or compete for the right to activate in response to different stimuli). +Classically, it was shown in [1] that a sparse coding model trained on natural image patches learned within its dictionary non-orthogonal filters that resembled receptive fields of simple-cells (found in the visual cortex). ## Constructing a Sparse Coding System -To build a sparse coding model, we can manually craft a model using ngc-learn's -nodes-and-cables system. First, we specify the underlying generative model we -aim to emulate. Formally, we seek to optimize a set of latent codes according -to the following differential equation: +To build a sparse coding model, we can manually craft a model using ngc-learn's nodes-and-cables system. First, we specify the underlying generative model we aim to emulate. Formally, we seek to optimize a set of latent codes according to the following differential equation: $$ \tau_m \frac{\partial \mathbf{z}_t}{\partial t} = \big(\mathbf{W}^T \cdot \mathbf{e}(t) \big) + \lambda \Omega\big(\mathbf{z}(t)\big) $$ -where $\tau_m$ is the latent code time constant and the error neurons $\mathbf{e}(t)$ -at the sensory input layer made at time $t$ are specified as: +where the above is also referred to as the E-step (since the optimization carried out for most sparse coding models is done within the framework of expectation-maximization -- E-step refers to updates to the latent variables whereas M-step refers to updates to synaptic/dictionary parameters) and $\tau_m$ is the latent code time constant and the error neurons $\mathbf{e}(t)$ at the sensory input layer made at time $t$ are specified as: $$ \mathbf{e}(t) = -\big(\mathbf{o}_t - (\mathbf{W} \cdot \mathbf{z}(t)) \big) $$ -where we see that we aim to learn a two-layer generative system that specifically -imposes a prior distribution `p(z)` over the latent feature detectors (via the -constraint function $ \Omega\big(\mathbf{z}(t)\big) $ ) that we hope -to extract in node `z`. Note that this two-layer model (or single latent-variable layer -model) could either be the linear generative model from [1] or one similar to the -model learned through ISTA [2] if a (soft) thresholding function is used instead. +where we see that we aim to learn a two-layer generative system that specifically imposes a prior distribution `p(z)` over the latent feature detectors (via the constraint function $ \Omega\big(\mathbf{z}(t)\big) $ ) that we hope to extract in node `z`. Note that this two-layer model (or single latent-variable layer model) could either be the linear generative model from [1] or one similar to the model learned through ISTA [2] if a (soft) thresholding function is used instead. + +Furthermore, the synaptic weight updates (the M-step) to our sparse coding model generally adhere to the following differential equation: + +$$ +\tau_m \frac{\partial \mathbf{W}}{\partial t} = -\mathbf{W} + \big(\mathbf{e}(t) \cdot (\mathbf{z}(t))^T \big) +$$ -Constructing the above system for (Olshausen & Field, 1996) is done, much -like we do in the `SparseCoding` agent constructor in the model museum exhibit -code, as follows: +Constructing the above system for (Olshausen & Field, 1996) is done, much like we do in the `SparseCoding` agent constructor in the model museum exhibit code, as follows: ```python -from ngcsimlib.context import Context -from ngclearn.components import GaussianErrorCell as ErrorCell, RateCell, HebbianSynapse, StaticSynapse +from ngclearn.utils.io_utils import makedir +from ngclearn.utils.viz.synapse_plot import visualize +from jax import numpy as jnp, random, jit +from ngclearn import Context, MethodProcess, JointProcess +from ngclearn.components.neurons.graded.rateCell import RateCell +from ngclearn.components.synapses.denseSynapse import DenseSynapse +from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse +from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell as ErrorCell from ngclearn.utils.model_utils import normalize_matrix +from ngclearn.utils.distribution_generator import DistributionGenerator as dist in_dim = # ... dimension of patch data ... hid_dim = # ... number of atoms in the dictionary matrix @@ -88,174 +61,127 @@ dt = 1. # ms T = 300 # ms # (OR) number of E-steps to take during inference # ---- build a sparse coding linear generative model with a Cauchy prior ---- with Context("Circuit") as circuit: - z1 = RateCell("z1", n_units=hid_dim, tau_m=20., act_fx="identity", - prior=("cauchy", 0.14), integration_type="euler") + z1 = RateCell( + "z1", n_units=hid_dim, tau_m=20, act_fx="identity", prior=("cauchy", 0.14), integration_type="euler" + ) e0 = ErrorCell("e0", n_units=in_dim) - W1 = HebbianSynapse("W1", shape=(hid_dim, in_dim), - eta=1e-2, wInit=("fan_in_gaussian", 0., 1.), - bInit=None, w_bound=0., optim_type="sgd", signVal=-1.) - E1 = StaticSynapse("E1", shape=(in_dim, hid_dim), - wInit=("uniform", -0.2, 0.2), Rscale=1.) + W1 = HebbianSynapse( + "W1", shape=(hid_dim, in_dim), eta=1e-2, weight_init=dist.fan_in_gaussian(), bias_init=None, w_bound=0., optim_type="sgd", sign_value=-1. + ) + E1 = DenseSynapse( ## E1 = (W1)^T + "E1", shape=(in_dim, hid_dim), weight_init=dist.uniform(-0.2, 0.2), + resist_scale=1. + ) + E1.weights.set(W1.weights.get().T) + ## wire z1.zF to e0.mu via W1 - W1.inputs << z1.zF - e0.mu << W1.outputs - ## wire e0.dmu to z1.j - E1.inputs << e0.dmu - z1.j << E1.outputs - ## Setup W1 for its 2-factor Hebbian update - W1.pre << z1.zF - W1.post << e0.dmu - - reset_cmd, reset_args = circuit.compile_by_key( - W1, E1, z1, e0, - compile_key="reset") - advance_cmd, advance_args = circuit.compile_by_key( - W1, E1, z1, e0, - compile_key="advance_state") - evolve_cmd, evolve_args = circuit.compile_by_key(W1, compile_key="evolve") + z1.zF >> W1.inputs + W1.outputs >> e0.mu + ## wire e0.dmu back up to z1.j via E1 (for E-step) + e0.dmu >> E1.inputs + E1.outputs >> z1.j + + ## Setup W1 for its 2-factor Hebbian update (for M-step) + z1.zF >> W1.pre + e0.dmu >> W1.post + + ## Inference process + advance = (MethodProcess(name="advance") + >> W1.advance_state + >> E1.advance_state + >> z1.advance_state + >> e0.advance_state) + ## Reset-to-baseline process + reset = (MethodProcess(name="reset") + >> W1.reset + >> E1.reset + >> z1.reset + >> e0.reset) + ## Learning process + evolve = (MethodProcess(name="evolve") + >> W1.evolve) ``` -Notice that, in our model `circuit`, we have taken care to set the `.param_axis` -variable to be equal to `1` -- this will, whenever we call `apply_constraints()`, -tell the NGC system to normalize the Euclidean norm of the columns -of the dictionary matrix to be equal to a value of one. This is a particularly -important constraint to apply to sparse coding models as this prevents the -trivial solution of simply growing out -the magnitude of the dictionary synapses to solve the underlying constrained -optimization problem (and, in general, constraining the rows or -columns of generative models helps to facilitate a more stable training process). -This norm constraint is configured in the agent constructor's dynamic -compile function, specifically in the snippet below: +There is one important co-routine we also need to make sure we include for our sparse coding `circuit` that needs to happen after each update to the synapses -- synaptic weight normalization. Specifically, we want to normalize the Euclidean norm of the columns of the dictionary matrix to be equal to a value of one. + +This is a particularly important constraint to apply to sparse coding models as this prevents the trivial solution of simply growing out the magnitude of the dictionary synapses to solve the underlying constrained optimization problem (and, in general, constraining the rows or columns of generative models helps to facilitate a more stable training process). This norm constraint can be simply written as below: ```python -@Context.dynamicCommand def norm(): - W1.weights.set(normalize_matrix(W1.weights.value, 1., order=2, axis=1)) + W1.weights.set(normalize_matrix(W1.weights.get(), 1., order=2, axis=1)) ``` -To build the version of our model (the ISTA model) using a thresholding function, -instead of using a factorial prior over the latents, we can write the following: +To build the version of our model (the ISTA model) using a thresholding function, instead of using a factorial prior over the latents, we can write the following: ```python # ---- build a sparse coding generative model w/ a thresholding function ---- with Context("Circuit") as circuit: - z1 = RateCell("z1", n_units=hid_dim, tau_m=20., act_fx="identity", - threshold=("soft_threshold", 5e-3), integration_type="euler") + z1 = RateCell( + "z1", n_units=hid_dim, tau_m=20, act_fx="identity", threshold=("soft_threshold", 5e-3), integration_type="euler" + ) e0 = ErrorCell("e0", n_units=in_dim) - W1 = HebbianSynapse("W1", shape=(hid_dim, in_dim), - eta=1e-2, wInit=("fan_in_gaussian", 0., 1.), - bInit=None, w_bound=0., optim_type="sgd", signVal=-1.) - E1 = StaticSynapse("E1", shape=(in_dim, hid_dim), - wInit=("uniform", -0.2, 0.2), Rscale=1.) + W1 = HebbianSynapse( + "W1", shape=(hid_dim, in_dim), eta=1e-2, weight_init=dist.fan_in_gaussian(), bias_init=None, w_bound=0., optim_type="sgd", sign_value=-1. + ) + E1 = DenseSynapse( + "E1", shape=(in_dim, hid_dim), weight_init=dist.uniform(-0.2, 0.2), + resist_scale=1. + ) + E1.weights.set(W1.weights.get().T) ## ...rest of the code is the same as the Cauchy prior model... ``` -Note that the above two models are built and configured for you in the -Model Museum, in the `museum/exhibits/olshausen_sc/sparse_coding.py` -agent constructor, which internally implements the model contexts depicted above -as well as the necessary task-specific functions needed to reproduce the -correct experimental setup (these get compiled in the constructor's -`dynamic()` method. For both the Cauchy prior model of [1] -and the iterative thresholding model of [2], we track, in the -training script `train_patch_sc.py`, various dictionary synaptic -statistics and a measurement of the model reconstruction loss. The -reconstruction loss is a key part of the objective that both models -optimize, i.e., both SC models effectively optimize an -energy function that is a sum of its reconstruction error of its sensory -input and the sparsity of its single latent state layer `z1`). +Note that the above two models are built and configured for you in the Model Museum, in the `museum/exhibits/olshausen_sc/sparse_coding.py` agent constructor, which internally implements the model contexts depicted above as well as the necessary task-specific functions needed to reproduce the correct experimental setup (these get compiled in the constructor's `dynamic()` method. For both the Cauchy prior model of [1] and the iterative thresholding model of [2], we track, in the training script `train_patch_sc.py`, various dictionary synaptic statistics and a measurement of the model reconstruction loss. The reconstruction loss is a key part of the objective that both models optimize, i.e., both SC models effectively optimize an energy function that is a sum of its reconstruction error of its sensory input and the sparsity of its single latent state layer `z1`). ## Learning Latent Feature Detectors -We will now simulate the learning of feature detectors using the two -sparse coding models specified above. The code provided in -`train_patch_sc.py` will execute a simulation of the above -two models on the natural images found in `exhibits/data/natural_scenes.zip`), -which is a dataset composed of several images of the American Northwest. - -First, navigate to the `exhibits/` directory to access the example/demonstration -code and further enter the `exhibits/data/` sub-folder. Unzip the file -`natural_scenes.zip` to create one more sub-folder that contains two numpy arrays, -the first labeled `natural_scenes/raw_dataX.npy` and another labeled as -`natural_scenes/dataX.npy`. The first one contains the original, `512 x 512` raw pixel -image arrays (flattened) while the second contains the pre-processed, whitened/normalized -(and flattened) image data arrays (these are the pre-processed image patterns used -in [1]). You will, in this demonstration, only be working with `natural_scenes/dataX.npy`. +We will now simulate the learning of feature detectors using the two sparse coding models specified above. The code provided in `train_patch_sc.py` will execute a simulation of the above two models on the natural images found in `exhibits/data/natural_scenes.zip`), which is a dataset composed of several images of the American Northwest. + +First, navigate to the `exhibits/` directory to access the example/demonstration code and further enter the `exhibits/data/` sub-folder. Unzip the file `natural_scenes.zip` to create one more sub-folder that contains two numpy arrays, the first labeled `natural_scenes/raw_dataX.npy` and another labeled as `natural_scenes/dataX.npy`. The first one contains the original, `512 x 512` raw pixel image arrays (flattened) while the second contains the pre-processed, whitened/normalized (and flattened) image data arrays (these are the pre-processed image patterns used in [1]). You will, in this demonstration, only be working with `natural_scenes/dataX.npy`. Two (raw) images sampled from the original dataset (`raw_dataX.npy`) are shown below: | | | |---|---| | ![](../images/museum/data_img1.png) | ![](../images/museum/data_img2.png) | -With the data unpacked and ready, we can now run the training process in -the model exhibit by either executing its Python simulation script like so: +With the data unpacked and ready, we can now run the training process in the model exhibit by either executing its Python simulation script like so: ```console $ python train_patch_sc.py --dataX="$DATA_DIR/dataX.npy" \ --n_iter=200 --model_type="sc_cauchy" ``` -or simply running the convenience Bash script `$ ./sim.sh` (which cleans up the model -experimental output folder each time you call the training script in order -to reduce memory clutter on your system). Running either the Python or Bash -script will then train a sparse coding model with a Cauchy prior on `16 x 16` -pixel patches from the natural image dataset in [1].[^1] After the simulation -terminates, i.e., once `200` iterations/passes through the data have been made, -you will notice in the `exp/filters/` sub-directory a visual plot -of your trained model's filters which should look like the one below: +or simply running the convenience Bash script `$ ./sim.sh` (which cleans up the model experimental output folder each time you call the training script in order to reduce memory clutter on your system). Running either the Python or Bash script will then train a sparse coding model with a Cauchy prior on `16 x 16` pixel patches from the natural image dataset in [1].[^1] After the simulation terminates, i.e., once `200` iterations/passes through the data have been made, you will notice in the `exp/filters/` sub-directory a visual plot of your trained model's filters which should look like the one below: -If you modify either the Bash script or Python script call to use -with a different model argument like so: +If you modify either the Bash script or Python script call to use with a different model argument like so: ```console $ python train_patch_sc.py --dataX="$DATA_DIR/dataX.npy" \ --n_iter=200 --model_type="sc_ista" ``` -you will now train your sparse coding using a latent soft-thresholding function -(emulating ISTA). After this simulated training process ends, you should see -in your `exp/filters/` sub-directory a filter plot like the one below: +you will now train your sparse coding using a latent soft-thresholding function (emulating ISTA). After this simulated training process ends, you should see in your `exp/filters/` sub-directory a filter plot like the one below: -The filter plots, notably, visually indicate that the dictionary atoms in both -sparse coding systems learned to function as edge detectors, each tuned to -a particular position, orientation, and frequency. These learned feature detectors, -as discussed in [1], appear to behave similar to the primary visual area (V1) -neurons of the cerebral cortex in the brain. In the end, even though the edge -detectors learned by both our models qualitatively appear to be similar, -we should note that the latent codes (when inferring them given sensory input) -for the model that used the thresholding function will ultimately sparser -(given the direct clamping to zero values it imposes mathematically). -Furthermore, the filters for the model with thresholding appear to smoother -and with fewer occurrences of less-than-useful slots than the Cauchy model -(or filters that did not appear to extract any particularly interpretable -features). +The filter plots, notably, visually indicate that the dictionary atoms in both sparse coding systems learned to function as edge detectors, each tuned to a particular position, orientation, and frequency. These learned feature detectors, as discussed in [1], appear to behave similar to the primary visual area (V1) neurons of the cerebral cortex in the brain. In the end, even though the edge detectors learned by both our models qualitatively appear to be similar, we should note that the latent codes (when inferring them given sensory input) for the model that used the thresholding function will ultimately sparser (given the direct clamping to zero values it imposes mathematically). +Furthermore, the filters for the model with thresholding appear to smoother and with fewer occurrences of less-than-useful slots than the Cauchy model (or filters that did not appear to extract any particularly interpretable features). ### Computing Hardware Note: -This tutorial was tested and run on an `Ubuntu 22.04.2 LTS` operating system -using an `NVIDIA GeForce RTX 2070` GPU with `CUDA Version: 12.1` -(`Driver Version: 530.41.03`). Note that the times reported in any tutorial -screenshot/console snippets were produced on this system. +This tutorial was tested and run on an `Ubuntu 22.04.2 LTS` operating system using an `NVIDIA GeForce RTX 2070` GPU with `CUDA Version: 12.1` (`Driver Version: 530.41.03`). Note that the times reported in any tutorial screenshot/console snippets were produced on this system. ## References [1] Olshausen, B., Field, D. Emergence of simple-cell receptive field properties diff --git a/docs/ngclearn_papers.md b/docs/ngclearn_papers.md index b89d005f..04460db5 100644 --- a/docs/ngclearn_papers.md +++ b/docs/ngclearn_papers.md @@ -1,6 +1,6 @@ -# List of Papers/Publications +# List of Papers and Publications -The following is a list of current papers that use ngc-learn (this list will be actively updated as we discover others that use ngc-learn): +The following is a list of current papers that use NGC-Learn (this list will be actively updated as we discover others that use NGC-Learn): 1. Ororbia, A., and Kifer, D. The neural coding framework for learning generative models. Nature Communications 13, 2064 (2022). @@ -16,4 +16,4 @@ The following is a list of current papers that use ngc-learn (this list will be 7. Ororbia, A., Friston, K., Rao, Rajesh P. N. "Meta-representational predictive coding: Biomimetic self-supervised learning." arXiv preprint arXiv:2503.21796 (2025). -Note: Please let us know if your work uses ngc-learn so we can update this page to accurately track ngc-learn's use and include your work in the accumulating body of work in predictive processing and/or brain-inspired computational modeling. +Note: Please let us know if your work uses NGC-Learn so we can update this page to accurately track NGC-Learn's use and include your work in the space of computational neuroscience, NeuroAI, and/or brain-inspired computational modeling. diff --git a/docs/ngclearn_talks.md b/docs/ngclearn_talks.md new file mode 100644 index 00000000..421da729 --- /dev/null +++ b/docs/ngclearn_talks.md @@ -0,0 +1,13 @@ +# Talks and Media Related to NGC-Learn + +The following is a list of talks and any media related to NGC-Learn: + +1. "NGC-Learn V3: A Fast, Modular, Computational Neuroscience Library". William Gebhardt (NAC Lab). Link (Youtube Video) (2025) + +2. "An Introduction to NGC-Learn: A Computational Neuroscience Library". William Gebhardt (NAC Lab). Link (Youtube Video) (2024)[^1] + +Keep an eye out on the [NAC Lab Youtube channel](https://www.youtube.com/@TheNACLab/featured) for additional future videos related to tutorials, educational material, and research related to NGC-Learn. + +Note: Please let us know if you give any tutorials/talks on work related that makes use of NGC-Learn so we can update this page to make these useful educational materials available to a wider audience. + +[^1]: Note that this talk is related to NGC-Learn (v2)/NGC-Sim-Lib (v1). \ No newline at end of file From 091ee742361ea5d0362df9579e5964f58014164b Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 16:16:26 -0500 Subject: [PATCH 107/121] removed flag from bernoulli/latency-cells for now; minor edit to doc --- docs/installation.md | 2 +- .../components/input_encoders/bernoulliCell.py | 4 ++-- .../components/input_encoders/latencyCell.py | 18 +++++++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 2c482c43..75cf5c21 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -2,7 +2,7 @@ **ngc-learn** officially supports Linux on Python 3. It can be run with or without a GPU. -Setup: NGC-Learn, in its entirety (including its supporting utilities), requires that you ensure that you have installed the following base dependencies in your system. Note that this library was developed and tested on Ubuntu 22.04 (with much earlier versions on Ubuntu 18.04/20.04). +Setup: NGC-Learn, in its entirety (including its supporting utility sub-packages), requires that you ensure that you have installed the following base dependencies in your system. Note that this library was developed and tested on Ubuntu 22.04 (with much earlier versions on Ubuntu 18.04/20.04). Specifically, NGC-Learn requires: * Python (>=3.10) * ngcsimlib (>=2.0.0), (official page) diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 4f18f9d8..87b965fc 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -30,8 +30,8 @@ def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax. super().__init__(name=name, key=key) ## Layer Size Setup - self.batch_size = Compartment(batch_size, fixed=True) - self.n_units = Compartment(n_units, fixed=True) + self.batch_size = Compartment(batch_size) + self.n_units = Compartment(n_units) restVals = jnp.zeros((batch_size, n_units)) self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py index 3550f086..c0708e3d 100755 --- a/ngclearn/components/input_encoders/latencyCell.py +++ b/ngclearn/components/input_encoders/latencyCell.py @@ -151,18 +151,18 @@ def __init__( super().__init__(name=name, key=key) ## latency meta-parameters - self.first_spike_time = Compartment(first_spike_time, fixed=True) - self.tau = Compartment(tau, fixed=True) - self.threshold = Compartment(threshold, fixed=True) - self.linearize = Compartment(linearize, fixed=True) - self.clip_spikes = Compartment(clip_spikes, fixed=True) + self.first_spike_time = Compartment(first_spike_time) + self.tau = Compartment(tau) + self.threshold = Compartment(threshold) + self.linearize = Compartment(linearize) + self.clip_spikes = Compartment(clip_spikes) ## normalize latency code s.t. final spike(s) occur w/in num_steps - self.normalize = Compartment(normalize, fixed=True) - self.num_steps = Compartment(num_steps, fixed=True) + self.normalize = Compartment(normalize) + self.num_steps = Compartment(num_steps) ## Layer Size Setup - self.batch_size = Compartment(batch_size, fixed=True) - self.n_units = Compartment(n_units, fixed=True) + self.batch_size = Compartment(batch_size) + self.n_units = Compartment(n_units) ## Compartment setup restVals = jnp.zeros((batch_size, n_units)) From 633a63deb87a222df4f749bc126cdc674dc16675 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 16:26:14 -0500 Subject: [PATCH 108/121] updates to theory doc --- docs/tutorials/theory.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/theory.md b/docs/tutorials/theory.md index a39133cf..1622a6bf 100755 --- a/docs/tutorials/theory.md +++ b/docs/tutorials/theory.md @@ -1,11 +1,15 @@ # Theory and Design Motivation ## Cable Theory and Neural Compartments -At its core, part of NGC-Learn's internal design is inspired by (neural) cable theory , where neuronal units, which are arranged in complex connectivity structures, are viewed as performing dendritic calculations (of varying complexity). In essence, a particular neuron integrates information from different input signal sources (for example, signals produced by other neurons), in often highly nonlinear ways through a complex dendritic tree. +At its core, part of NGC-Learn's internal design is inspired by (neural) cable theory and neuronal compartment models [1], where neuronal units, which are arranged in complex connectivity structures, are viewed as performing dendritic calculations (of varying complexity). In essence, a particular neuron integrates information from different input signal sources (for example, signals produced by other neurons), in often highly nonlinear ways through a complex dendritic tree. -Although modeling a complete neuronal system through the lens of cable theory is complex and intricate in of itself, NGC-Learn is built with this direction in mind. NGC-Learn starts with the idea that a neuron (or a cluster of them) can be viewed as a node or nodal component -- specifically a type of "cell" component (in NGC-Learn, many of these are component classes that end with the suffix `Cell`) -- and each bundle of synapses which connects pairs of nodes can -be viewed as a cable -- specifically a "synapse" component (these component classes usually end with the suffix `Synapse` or `SynapticCable`) -- that performs some sort of transformation of its pre-synaptic signal (also treated as another component in terms of abstract simulation) and often differentiated by its form of plasticity. See the [Neurons](../modeling/neurons) specification for the base available neuronal cells and the [Synapses](../modeling/synapses) specification for the base available synaptic cables. Note that these two types of nodal components can be combined with other types such as [Input Encoders](../modeling/input_encoders) and [Operations](../modeling/other_ops) to build gradually more complex dynamical biomimetic/neuro-mimetic systems. +Although modeling a complete neuronal system through the lens of cable theory and compartmental structures is complex and intricate in of itself, NGC-Learn is built with this direction in mind. NGC-Learn starts with the idea that a neuron (or a cluster of them) can be viewed as a node or nodal component -- specifically a type of "cell" component (in NGC-Learn, many of these are component classes that end with the suffix `Cell`). Each bundle of synapses that connects pairs of nodes can +be viewed as a cable -- specifically a "synapse" component (these component classes usually end with the suffix `Synapse` or `SynapticCable`) -- which performs some sort of transformation of its pre-synaptic signal (also treated as another component in terms of abstract simulation); a synaptic bundle in NGC-Learn is often differentiated by its form of plasticity. See the [Neurons](../modeling/neurons) specification for the base available neuronal cells and the [Synapses](../modeling/synapses) specification for the base available synaptic cables. Note that these two types of nodal components can be combined with other types such as [Input Encoders](../modeling/input_encoders) and [Operations](../modeling/other_ops) to build gradually more complex dynamical biomimetic/neuro-mimetic/NeuroAI systems. Each neuronal cell component/node has multiple, different (named) "compartments", which are regions or slots within the node that other nodes can deposit information/signals into. These compartments allow a node to collect information from many different connected/related nodes and then decide how to combine these different signals in order calculate its own output activity (either in the form of a rate-coded firing rate or binary spikes) using the integration logic defined within its own specific `advance_state()` function. When a biomimetic system, composed of many of these nodes/components, is simulated over a period of time (processing some form of sensory input), its underlying simulation object (the `Context` controller) calls the `advance_state()` routine of each constituent node, shifting that nodes internal time by one discrete step. The order in which the node `advance_state()` routines are called is governed by "run cycles", which are defined by the experimenter at the object initialization of the controller. For example, a user might want one set of nodes to first execute their internal step logic before another set is able to -- this could be done by specifying two distinct cycles in the order desired. As a result, many nodes, and the synaptic cables that connect them, result in a simulated biomimetic system where each node is itself, in general, treated as a stateful computation even if we are processing inherently non-temporal data such as static images. + +## References + +[1] Talevi, Alan, and Carolina Leticia Bellera. "Compartmental pharmacokinetic models." In ADME Processes in Pharmaceutical Sciences: Dosage, Design, and Pharmacotherapy, pp. 173-192. Cham: Springer Nature Switzerland, 2024. From bb7f4539f6c6ddb61db78274a253233aca6797f3 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 16:50:00 -0500 Subject: [PATCH 109/121] updated history log --- history.txt | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/history.txt b/history.txt index 108b82d6..e03221af 100644 --- a/history.txt +++ b/history.txt @@ -19,8 +19,7 @@ History * NGCGraph .compile() further tweaked to use an injection/clamping look-up system to allow for dynamic changes to occur w/in a static graph compiled simulated NGC system - * Cable API slightly modified to increase flexiblity (demonstrations and - tests modified to reflect updated API) + * Cable API slightly modified to increase flexibility (demonstrations and tests modified to reflect updated API) * Demonstration 6 released showcasing how to use ngc-learn to construct/fit a restricted Boltzmann machine @@ -80,3 +79,16 @@ History * integration of reinforce-synapse, block/partitioned synapse component ("patched-synapse") * basic unit-tests (pytest framework) integrated to support dev * includes support for Intel's lava-nc emulator (several spiking/stdp components that play with ngc-lava) + + 3.0.0 + — — — — — — — — - + * revisions made / upgrades applied to framework/simulation back-end to integrate major version v2 of ngc-sim-lib + * new harmonium/RBM model-museum exhibit written and tutorial integrated + * clean-up of utils and new integration of mixture models for utils.density (Gaussian, Bernoulli, & exponential mixtures) + * addition of new BernoulliErrorCell (binary cross-entropy node); added leakyNoiseCell to support contus-time RNNs + * model museum (ngc-museum) and tutorials updated to reflect newest ngc-sim-lib format + * clean-up/upgrade of docs to reflect new v3 version (and patches) + * all model-museum (standard/main) exhibits revised/updated to operate with new v3 ngclearn / v2 ngcsimlib + * integration/addition of RL-SNN model in model-museum + * integration of full dynamics synapses -- alpha, exponential, and double-exponential synaptic cables + * new metrics/clean-up of metrics in utils.metric_utils (e.g., KL divs, etc.) From 953068c39a8f515645d91ac958660ee750348cdf Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Wed, 3 Dec 2025 17:36:23 -0500 Subject: [PATCH 110/121] minor clean-up of ngclearn.utils.viz.dim_reduce --- ngclearn/utils/viz/dim_reduce.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py index 4fd8c244..3f32057d 100755 --- a/ngclearn/utils/viz/dim_reduce.py +++ b/ngclearn/utils/viz/dim_reduce.py @@ -3,8 +3,8 @@ default_cmap = plt.cm.jet import numpy as np -from sklearn.decomposition import IncrementalPCA -from sklearn.manifold import TSNE +from sklearn.decomposition import IncrementalPCA ## sci-kit learning dependency +from sklearn.manifold import TSNE ## sci-kit learning dependency def extract_pca_latents(vectors): ## PCA mapping routine """ @@ -20,7 +20,6 @@ def extract_pca_latents(vectors): ## PCA mapping routine """ batch_size = 50 z_dim = vectors.shape[1] - z_2D = None if z_dim != 2: ipca = IncrementalPCA(n_components=2, batch_size=batch_size) ipca.fit(vectors) @@ -31,26 +30,25 @@ def extract_pca_latents(vectors): ## PCA mapping routine def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine """ - Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) - visualization space via the t-distributed stochastic neighbor embedding - algorithm (t-SNE). This algorithm also uses PCA to produce an - intermediate project to speed up the t-SNE final mapping step. Note that - if the input already has a 2D dimensionality, the original input is returned. + Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the + t-distributed stochastic neighbor embedding algorithm (t-SNE). This algorithm also uses PCA to produce an + intermediate project to speed up the t-SNE final mapping step. Note that if the input already has a 2D + dimensionality, the original input is returned. Args: vectors: a matrix/codebook of (K x D) vectors to project perplexity: the perplexity control factor for t-SNE (Default: 30) - batch_size: number of sampled embedding vectors to use per iteration - of online internal PCA + n_pca_comp: number of PCA top components (sorted by eigen-values) to retain/extract before continuing + with t-SNE dimensionality reduction + + batch_size: number of sampled embedding vectors to use per iteration of online internal PCA Returns: a matrix (K x 2) of projected vectors (to 2D space) """ - #batch_size = 500 #50 z_dim = vectors.shape[1] - z_2D = None if z_dim != 2: print(" > Projecting latents via iPCA...") n_comp = n_pca_comp #32 #10 #16 #50 @@ -69,11 +67,10 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): z_2D = vectors return z_2D -def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., - cmap=None): +def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., cmap=None): """ - Produces a label-overlaid (label map to distinct colors) scatterplot for - visualizing two-dimensional latent codes (produced by either PCA or t-SNE). + Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes + (produced by either PCA or t-SNE). Args: code_vectors: a matrix of shape (K x 2) with vectors to plot/visualize @@ -92,8 +89,7 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting print(" > Plotting 2D latent encodings...") curr_backend = plt.rcParams["backend"] - matplotlib.use( - 'Agg') ## temporarily go in Agg plt backend for tsne plotting + matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting lab = labels if lab.shape[1] > 1: ## extract integer class labels from a one-hot matrix lab = np.argmax(lab, 1) From 5e43ad20d058a73154bc4e233a063630d26f5f0c Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Thu, 4 Dec 2025 11:48:08 -0500 Subject: [PATCH 111/121] Update jaxComponent.py Added support for turning off autosave --- ngclearn/components/jaxComponent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py index 6e2ccec7..8cffa49a 100755 --- a/ngclearn/components/jaxComponent.py +++ b/ngclearn/components/jaxComponent.py @@ -37,7 +37,7 @@ def save(self, directory: str): file_name = directory + "/" + self.name + ".npz" data = {} for comp_name, comp in self.compartments: - if not comp.targeted: + if not comp.targeted and comp.auto_save: data[comp_name] = comp.get() jnp.savez(file_name, **data) From eb534c4f52cf26f19dbb6209a7c870e9f192dac7 Mon Sep 17 00:00:00 2001 From: Viet Dung Nguyen Date: Thu, 4 Dec 2025 13:39:29 -0500 Subject: [PATCH 112/121] update hebbian synapse saving --- .../synapses/hebbian/hebbianSynapse.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index 71c9a4a3..e8e46e58 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -1,5 +1,7 @@ # %% +import jax +import pickle from jax import random, numpy as jnp, jit from functools import partial from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn @@ -206,10 +208,26 @@ def __init__( self.dBiases = Compartment(jnp.zeros(shape[1])) #key, subkey = random.split(self.key.value) + # NOTE: we don't save this compartment directly because it is a tuple can cannot be saved directly by numpy self.opt_params = Compartment( - get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]) + get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]), + auto_save=False ) + def save(self, directory: str): + super().save(directory) + # Also save the optimizer parameters + file_name = directory + "/" + self.name + "_opt_params" + ".pkl" + with open(file_name, 'wb') as f: + pickle.dump(self.opt_params.get(), f) + + def load(self, directory: str): + super().load(directory) + file_name = directory + "/" + self.name + "_opt_params" + ".pkl" + with open(file_name, 'rb') as f: + data = pickle.load(f) + self.opt_params.set(data) + @staticmethod def _compute_update( w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights @@ -332,3 +350,4 @@ def help(cls): ## component help function Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam', sign_value=-1.0, prior=("l1l2", 0.001)) print(Wab) + print(Wab.opt_params.get()) From 6281f1a65f4d8721c1d6095a34a318749e7a9ba3 Mon Sep 17 00:00:00 2001 From: Viet Nguyen Date: Thu, 4 Dec 2025 14:26:53 -0500 Subject: [PATCH 113/121] update saving and loading utils, making hebbian synapse use these utils for custom optimizer params saving and loading --- .../components/synapses/hebbian/hebbianSynapse.py | 11 ++++------- ngclearn/utils/io_utils.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py index e8e46e58..f0814443 100644 --- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py +++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py @@ -11,6 +11,7 @@ from ngclearn.components.synapses import DenseSynapse from ngclearn.utils import tensorstats from ngcsimlib import deprecate_args +from ngclearn.utils.io_utils import save_pkl, load_pkl @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9]) def _calc_update( @@ -217,16 +218,12 @@ def __init__( def save(self, directory: str): super().save(directory) # Also save the optimizer parameters - file_name = directory + "/" + self.name + "_opt_params" + ".pkl" - with open(file_name, 'wb') as f: - pickle.dump(self.opt_params.get(), f) + save_pkl(directory, self.name + "_opt_params", self.opt_params.get()) def load(self, directory: str): super().load(directory) - file_name = directory + "/" + self.name + "_opt_params" + ".pkl" - with open(file_name, 'rb') as f: - data = pickle.load(f) - self.opt_params.set(data) + # load the optimizer parameters in a custom way + self.opt_params.set(load_pkl(directory, self.name + "_opt_params")) @staticmethod def _compute_update( diff --git a/ngclearn/utils/io_utils.py b/ngclearn/utils/io_utils.py index 0c6eec96..8553af44 100755 --- a/ngclearn/utils/io_utils.py +++ b/ngclearn/utils/io_utils.py @@ -4,6 +4,7 @@ # import jax # from jax import numpy as jnp, grad, jit, vmap, random, lax import os, sys, pickle +from typing import Any def serialize(fname, object): ## object "saving" routine """ @@ -65,3 +66,15 @@ def makedirs(directories): """ for dir in directories: makedir(dir) + + +def save_pkl(directory: str, name: str, value: Any) -> None: + file_name = directory + "/" + name + ".pkl" + with open(file_name, 'wb') as f: + pickle.dump(value, f) + +def load_pkl(directory: str, name: str) -> Any: + file_name = directory + "/" + name + ".pkl" + with open(file_name, 'rb') as f: + data = pickle.load(f) + return data From 74def15f384ad981076f97d7495a3e309ac06b00 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 00:08:55 -0500 Subject: [PATCH 114/121] minor revisions/polish --- docs/museum/harmonium.md | 46 ++++++++++---------- ngclearn/components/synapses/denseSynapse.py | 20 ++++----- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/docs/museum/harmonium.md b/docs/museum/harmonium.md index 4e7c827b..ab55209e 100644 --- a/docs/museum/harmonium.md +++ b/docs/museum/harmonium.md @@ -60,7 +60,7 @@ where $Z$ is the normalizing constant (or, in statistical mechanics, the part When one works through the derivation of the gradient of the log probability $\log p(\mathbf{x})$ with respect to the synapses such as $\mathbf{W}$, they get a (contrastive) Hebbian-like update rule as follows: $$ -\Delta \mathbf{W} = <\mathbf{z}^0_i \mathbf{z}^1_j>_{data} - <\mathbf{z}^0_i \mathbf{z}^1_j>_{model} +\Delta \mathbf{W}_{ij} = <\mathbf{z}^0_i \mathbf{z}^1_j>_{data} - <\mathbf{z}^0_i \mathbf{z}^1_j>_{model} $$ where the angle brackets $< >$ tell us that we need to take the expectation of the values within the brackets under a certain distribution (such as the data distribution denoted by the subscript $data$). The above rule can also be considered to be a stochastic form of a general recipe known as contrastive Hebbian learning (CHL) [4]. @@ -170,31 +170,31 @@ which will fit/adapt your harmonium to MNIST. This should produce per-training i W1: min -0.0494 ; max 0.0445 mu -0.0000 ; norm 4.4734 b1: min -4.0000 ; max -4.0000 mu -4.0000 ; norm 64.0000 c0: min -11.6114 ; max 0.0635 mu -3.8398 ; norm 135.2238 --1| Test: E(X) = 99.8526 err(X) = 54.3889 -0| Test: E(X) = 116.6596 err(X) = 46.8236; Train: E(X) = 112.0452 err(X) = 52.7418 -1| Test: E(X) = 89.5413 err(X) = 36.8690; Train: E(X) = 102.4642 err(X) = 41.3630 -2| Test: E(X) = 75.7558 err(X) = 31.8582; Train: E(X) = 82.9692 err(X) = 34.5511 -3| Test: E(X) = 66.6632 err(X) = 28.6253; Train: E(X) = 72.1229 err(X) = 30.4615 -4| Test: E(X) = 60.8256 err(X) = 26.2317; Train: E(X) = 64.3613 err(X) = 27.6882 -5| Test: E(X) = 55.5070 err(X) = 24.3207; Train: E(X) = 58.9254 err(X) = 25.5485 -6| Test: E(X) = 51.7455 err(X) = 22.8012; Train: E(X) = 54.4092 err(X) = 23.8361 -7| Test: E(X) = 49.4866 err(X) = 21.6163; Train: E(X) = 51.1574 err(X) = 22.4523 -8| Test: E(X) = 46.2826 err(X) = 20.5934; Train: E(X) = 48.2617 err(X) = 21.3355 -9| Test: E(X) = 43.8611 err(X) = 19.7679; Train: E(X) = 46.0239 err(X) = 20.4297 -10| Test: E(X) = 42.2886 err(X) = 19.0672; Train: E(X) = 44.3544 err(X) = 19.6835 -11| Test: E(X) = 41.7468 err(X) = 18.4881; Train: E(X) = 42.9321 err(X) = 19.0372 +-1| Test: err(X) = 54.3889 +0| Test: |d.E(X)| = 16.8070 err(X) = 46.8236; Train: err(X) = 52.7418 +1| Test: |d.E(X)| = 27.1183 err(X) = 36.8690; Train: err(X) = 41.3630 +2| Test: |d.E(X)| = 13.7855 err(X) = 31.8582; Train: err(X) = 34.5511 +3| Test: |d.E(X)| = 9.0927 err(X) = 28.6253; Train: err(X) = 30.4615 +4| Test: |d.E(X)| = 5.8375 err(X) = 26.2317; Train: err(X) = 27.6882 +5| Test: |d.E(X)| = 5.3187 err(X) = 24.3207; Train: err(X) = 25.5485 +6| Test: |d.E(X)| = 3.7614 err(X) = 22.8012; Train: err(X) = 23.8361 +7| Test: |d.E(X)| = 2.2589 err(X) = 21.6163; Train: err(X) = 22.4523 +8| Test: |d.E(X)| = 3.2040 err(X) = 20.5934; Train: err(X) = 21.3355 +9| Test: |d.E(X)| = 2.4215 err(X) = 19.7679; Train: err(X) = 20.4297 +10| Test: |d.E(X)| = 1.5725 err(X) = 19.0672; Train: err(X) = 19.6835 +11| Test: |d.E(X)| = 0.5418 err(X) = 18.4881; Train: err(X) = 19.0372 ... ... -91| Test: E(X) = 65.5179 err(X) = 11.0443; Train: E(X) = 65.0850 err(X) = 10.9832 -92| Test: E(X) = 65.4790 err(X) = 11.0118; Train: E(X) = 64.8345 err(X) = 10.9820 -93| Test: E(X) = 65.9917 err(X) = 11.0013; Train: E(X) = 64.4392 err(X) = 10.9586 -94| Test: E(X) = 64.0737 err(X) = 10.9874; Train: E(X) = 64.2096 err(X) = 10.9312 -95| Test: E(X) = 64.0479 err(X) = 10.9906; Train: E(X) = 63.8461 err(X) = 10.9274 -96| Test: E(X) = 63.5719 err(X) = 10.9712; Train: E(X) = 63.3354 err(X) = 10.8940 -97| Test: E(X) = 64.1757 err(X) = 10.9589; Train: E(X) = 62.8447 err(X) = 10.8960 -98| Test: E(X) = 63.8886 err(X) = 10.9563; Train: E(X) = 62.6391 err(X) = 10.8727 -99| Test: E(X) = 62.2265 err(X) = 10.9347; Train: E(X) = 62.3147 err(X) = 10.8671 +91| Test: |d.E(X)| = 0.4870 err(X) = 11.0443; Train: err(X) = 10.9832 +92| Test: |d.E(X)| = 0.0390 err(X) = 11.0118; Train: err(X) = 10.9820 +93| Test: |d.E(X)| = 0.5127 err(X) = 11.0013; Train: err(X) = 10.9586 +94| Test: |d.E(X)| = 1.9180 err(X) = 10.9874; Train: err(X) = 10.9312 +95| Test: |d.E(X)| = 0.0258 err(X) = 10.9906; Train: err(X) = 10.9274 +96| Test: |d.E(X)| = 0.4760 err(X) = 10.9712; Train: err(X) = 10.8940 +97| Test: |d.E(X)| = 0.6038 err(X) = 10.9589; Train: err(X) = 10.8960 +98| Test: |d.E(X)| = 0.2870 err(X) = 10.9563; Train: err(X) = 10.8727 +99| Test: |d.E(X)| = 1.6622 err(X) = 10.9347; Train: err(X) = 10.8671 --- Final RBM Synaptic Stats --- W1: min -1.8648 ; max 1.3757 mu -0.0012 ; norm 70.6230 b1: min -7.5815 ; max 0.2337 mu -2.3395 ; norm 53.3993 diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index cc99fc93..977f2464 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -44,8 +44,6 @@ def __init__( super().__init__(name, **kwargs) self.batch_size = batch_size - self.weight_init = weight_init - self.bias_init = bias_init ## Synapse meta-parameters self.shape = shape @@ -54,13 +52,11 @@ def __init__( ## Set up synaptic weight values tmp_key, *subkeys = random.split(self.key.get(), 4) - if self.weight_init is None: + if weight_init is None: info(self.name, "is using default weight initializer!") # self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} - # weights = initialize_params(subkeys[0], self.weight_init, shape) - self.weight_init = DistributionGenerator.uniform(0.025, 0.8) - #weights = initialize_params(subkeys[0], self.weight_init, shape) - weights = self.weight_init(shape, subkeys[0]) + weight_init = DistributionGenerator.uniform(0.025, 0.8) + weights = weight_init(shape, subkeys[0]) if 0. < p_conn < 1.: ## Modifier/constraint: only non-zero and <1 probs allowed p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape) @@ -74,12 +70,12 @@ def __init__( self.outputs = Compartment(postVals) self.weights = Compartment(weights) ## Set up (optional) bias values - if self.bias_init is None: + if bias_init is None: info(self.name, "is using default bias value of zero (no bias kernel provided)!") - self.biases = Compartment(self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0) - # self.biases = Compartment(initialize_params(subkeys[2], bias_init, - # (1, shape[1])) - # if bias_init else 0.0) + self.biases = Compartment(bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0) + ## pin weight/bias initializers to component + self.weight_init = weight_init + self.bias_init = bias_init @compilable def advance_state(self): From 222932f785f9143cca61ab117f46aeedac5445cd Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 00:19:26 -0500 Subject: [PATCH 115/121] modded docs to include v3 foundations --- docs/tutorials/foundations.md | 6 +- docs/tutorials/foundations/commands.md | 137 ------------------ docs/tutorials/foundations/compartments.md | 44 ++++++ docs/tutorials/foundations/compiling.md | 98 +++++++++++++ docs/tutorials/foundations/components.md | 32 ++++ docs/tutorials/foundations/context.md | 67 +++++++++ docs/tutorials/foundations/contexts.md | 70 --------- docs/tutorials/foundations/global_state.md | 52 +++++++ docs/tutorials/foundations/monitors.md | 72 --------- docs/tutorials/foundations/operations.md | 60 -------- docs/tutorials/foundations/processes.md | 90 ++++++++++++ docs/tutorials/index.rst | 10 +- docs/tutorials/model_basics/model_building.md | 6 +- 13 files changed, 397 insertions(+), 347 deletions(-) delete mode 100755 docs/tutorials/foundations/commands.md create mode 100644 docs/tutorials/foundations/compartments.md create mode 100644 docs/tutorials/foundations/compiling.md create mode 100644 docs/tutorials/foundations/components.md create mode 100644 docs/tutorials/foundations/context.md delete mode 100644 docs/tutorials/foundations/contexts.md create mode 100644 docs/tutorials/foundations/global_state.md delete mode 100644 docs/tutorials/foundations/monitors.md delete mode 100755 docs/tutorials/foundations/operations.md create mode 100644 docs/tutorials/foundations/processes.md diff --git a/docs/tutorials/foundations.md b/docs/tutorials/foundations.md index 822dc8f8..56b1fe49 100644 --- a/docs/tutorials/foundations.md +++ b/docs/tutorials/foundations.md @@ -3,6 +3,8 @@ In this set of tutorials/walkthroughs, we go through some of the core elements and mechanisms underlying NGC-Learn in order understand how its simulation scheme (and the nodes-and-cables system) works and to help in writing your own custom elements. The foundational walkthroughs are organized as follows: -1. [Using Model Contexts](../tutorials/foundations/contexts.md): This lesson goes the fundamentals of the primary simulation construct you need to set up models, the (simulation) context. -2. [Understanding Commands](../tutorials/foundations/commands.md): This lesson will walk you through the basics of a command -- an essential part of building a simulation controller in ngc-learn and ngcsimlib -- and offer some useful points for designing new ones. +1. [Using Model Contexts](../tutorials/foundations/context.md): This lesson goes the fundamentals of the primary simulation construct you need to set up models, the (simulation) context. +2. [Understanding Processes](../tutorials/foundations/processes.md): This lesson will walk you through the basics of a command -- an essential part of building a simulation controller in ngc-learn and ngcsimlib -- and offer some useful points for designing new ones. + diff --git a/docs/tutorials/foundations/commands.md b/docs/tutorials/foundations/commands.md deleted file mode 100755 index e3fff8f8..00000000 --- a/docs/tutorials/foundations/commands.md +++ /dev/null @@ -1,137 +0,0 @@ -# Understanding Commands - -## Overview -Commands are one of the central pillars of -ngcsimlib, the dependency -library that drives ngc-learn's simulation backend. -In general, commands provide the instructions and logic for what each component -should be doing at any given time. In addition, they are the normal way that an -outside user would interact with ngc-learn models. Commands live inside a model's -controller and are generally made with the `add_command` method. - -## Abstract Command -Contained within ngcsimlib is an abstract class for every command included in -ngcsimlib. It is strongly recommended that custom commands are built using this -base class (but there is nothing enforcing this inside of ngcsimlib). - -At its base the abstract command forces two things: firstly, the constructor -for the base class requires a list of components, and a list of attributes that -each component should have. Secondly, all commands must implement their -`__call__` command, taking in only `*args` and `**kwargs`. - -## Constructing Commands -It is common that commands will need to have values passed into them to control -their internal behavior, such as a value to clamp, or a flag for freezing -synaptic weight values. -To do this, we introduce the notion of binding keywords to commands. -Specifically, commands will take strings in during their construction and then -look for those strings when called inside the list of keyword arguments in order -to get their arguments. - -## Calling Commands -When commands are called, they will take in only `*args` and `**kwargs`. -While custom commands can break this by adding in additional arguments -without any problem, it is not recommended to do this as multiple instances -of a command with different parameters will then use the same keyword for their -call. - -## Creating Custom Commands -It is recommended that all custom commands inherit from the base class -provided within ngcsimlib. This provides a good starting point for designing a -component that will seamlessly interact with ngcsimlib's internal simulation mechanics. -These mechanics, which characterize the core operation of a simulation controller, -entail that, for each command supplied to a controller, a command will call the -same function with the same parameters on each component provided -to that very command. It is also expected that there is error handling within the -constructor to catch as many runtime errors as possible. Note that base -command class provides a list to check required calls such as `reset` or `evolve`. - -It is important to note that, if commands are going to be constructed via a -controller, they should have keyword arguments with default values that -error out on bad input instead of positional arguments. - -## Example Command (reset) - -Below, we present the key bits of source code that characterize a reset command --- a very commonly used, built-in command for models designed in ngc-learn -- and -its internal operation: - -```python -from ngcsimlib.commands import Command -from ngcsimlib.utils import extract_args -from ngcsimlib.logger import warn, error - -class Reset(Command): - def __init__(self, components=None, reset_name=None, command_name=None, - **kwargs): - super().__init__(components=components, command_name=command_name, - required_calls=['reset']) - if reset_name is None: - error(self.name, "requires a \'reset_name\' to bind to for construction") - self.reset_name = reset_name - - def __call__(self, *args, **kwargs): - try: - vals = extract_args([self.reset_name], *args, **kwargs) - except RuntimeError: - warn(self.name, ",", self.reset_name, - "is missing from keyword arguments and no positional arguments were provided") - return - - if vals[self.reset_name]: - for component in self.components: - self.components[component].reset() -``` - -## Custom Command Template - -Here, we show the generic command template which shows how one would go about -designing the key operational bits that make up a useful command. - -```python -from ngcsimlib.commands.command import Command -from ngcsimlib.utils import extract_args -from ngcsimlib.logger import error - - -class CustomCommand(Command): - def __init__(self, components=None, BINDING_VALUE=None, ADDITIONAL_INPUT=None, command_name=None, - **kwargs): - super().__init__(components=components, command_name=None, required_calls=['CUSTOM_CALL']) - # Make sure additional input is passed in - if ADDITIONAL_INPUT is None: - error(self.name, "requires a \'ADDITIONAL_INPUT\' for construction") - - # Make sure command is bound to a value - if BINDING_VALUE is None: - error(self.name, "requires a \'BINDING_VALUE\' to bind to for construction") - - self.BOUND_VALUE = BINDING_VALUE - self.ADDITION_VALUE = ADDITIONAL_INPUT - - def __call__(self, *args, **kwargs): - # Extract the bound value from the arguments - try: - vals = extract_args([self.BOUND_VALUE], *args, **kwargs) - except RuntimeError: - error(self.name, ",", str(self.BOUND_VALUE), "is missing from keyword arguments or a positional " - "arguments can be provided") - - #Use extracted value to call a method on each component - for component in self.components: - self.components[component].CUSTOM_CALL(self.ADDITION_VALUE, vals[self.BOUND_VALUE]) -``` - -## Notes -All components added to commands must have a `name` attribute and the word -`name` is automatically appended to any provided list of required attributes -to the base class constructor. - -As all built-in commands use `extract_args` when called with a controller via -`myController.COMMAND(ARGUMENT)`, there is no need to use keywords as it will -use `args` if there are no keyword arguments. (Keywords will still work, however.) - -When commands are constructed via a controller, they are also provided with the -keyword arguments `controller` and `command_name`. It is not recommended to -use these for any core logic (just use them for error messages), unless -it using them is absolutely essential in achieving the desired functionality. diff --git a/docs/tutorials/foundations/compartments.md b/docs/tutorials/foundations/compartments.md new file mode 100644 index 00000000..d2dfed25 --- /dev/null +++ b/docs/tutorials/foundations/compartments.md @@ -0,0 +1,44 @@ +# Compartments + +Within NGC-Sim-Lib, the global state serves as the backbone of any given model. +This global state is essentially the culmination of all of the dynamic or changing parts of the model itself. Each +value that builds this state is stored within a special "container" that helps track these changes over time -- this +is referred to as a `Compartment`. + +## Practical Information + +Practically, when working with compartments, there are a few simple things to keep in mind despite the fact that most +of NGC-Sim-Lib's primary operation is behind-the-scenes bookkeeping. The two main points to note are: +1. Each compartment holds a value and, thus, setting a compartment with `myCompartment = newValue` will not function as + intended since this will overwrite the Python object, i.e., the compartment with `newValue`. Instead, it is + important to make use of the `.set()` method to update the value stored inside a compartment so + `myCompartment = newValue` becomes `myCompartment.set(newValue)`. +2. In order to retrieve a value from a compartment, use `myCompartment.get()`. These methods of getting and setting + data inside a compartment are important to use when both working with and designing a multi-compartment component + (i.e., `Component`). + +## Technical Information + +The follow sections are devoted to explication of more technical information regarding how a compartment functions +with in the broader scope of NGC-Sim-Lib and, furthermore, to explain how to leverage this information. + +### How Data is Stored (Within a Model Context) + +The data stored inside of a compartment is not actually physically stored within a compartment. Instead, it is stored +inside of the global state and each compartment effectively holds the path or `key` to the right spot in the global +state, allowing it to pull out a specific piece of information. As such, it is technically possible to manipulate the +value of a compartment without actually touching the compartment object itself within any given component. By default, +compartments have in-built safeguards in order to prevent this from happening accidentally; however, directly addressing +the compartment within the global state directly has no such safeguards. + +### What is "Targeting"? + +As discussed in the model building section, there is notion of "wiring" together different compartments of different +components -- this is at the core of NGC-Learn's and NGC-Sim-Lib's "nodes-and-cables system". These wires are created +through the concept of "targeting,", which is, in essence, just the updating of the path stored within a compartment +using the path of a different compartment. This means that, if the targeted compartment goes to retrieve the value +stored within it, it will actually retrieve the value of a different compartment (as dictated by the target). When a +compartment is in this state -- where it is targeting another compartment -- it is set to read-only, which only means that +it cannot modify a different compartment. + + diff --git a/docs/tutorials/foundations/compiling.md b/docs/tutorials/foundations/compiling.md new file mode 100644 index 00000000..7f4aad06 --- /dev/null +++ b/docs/tutorials/foundations/compiling.md @@ -0,0 +1,98 @@ +# Compiling + +The term "compiling" for NGC-Sim-Lib refers to automatic step that happens +inside of a context that produces a transformed method for all of its +components. This step is the most complicated part of the library and, in +general, does not need to be touched or interacted with. Nevertheless, this +section will cover most of the steps that the NGC-Sim-Lib compilation process +does at a high level. This section contains advanced technical/developer-level +information: there is an expectation that the reader has an understanding of +Python abstract syntax trees (ASTs), Python namespaces, and how to +dynamically compile Python code and execute it. + +## The Decorator + +In NGC-Sim-Lib, there is a decorator marked as `@compilable` which is used to +add a flag to methods that the user wants to compile. On its own, this will not +do anything; however, this decorator lets the parser distinguish between methods +that should be compiled and methods that should be ignored. + +## The Step-by-Step NGC-Sim-Lib Parsing Process + +The process starts by telling the parser to compile a specific object. + +### Step 1: Compile Children + +The first step to compile any object is to make sure that all of the +"compilable" objects of the top level object are compiled. As a +result, NGC-Sim-Lib will loop through all of the whole object and will compile +each part that it finds that is flagged as compilable (via the decorators +mentioned above) and is, furthermore, an instance of a class. + +### Step 2: Extract Methods to Compile + +While the parser is looping through all of the parts of the top-level object, it +is also extracting the methods on/embedded to the object that are flagged as +compilable (with the decorator above). NGC-Sim-Lib stores them for later; +however, this lets the parser only loop over the object once. + +### Step 3: Parse Each Method + +As each method is its own entry-point into the transformer, this step will run +for each method in the top-level object. + +### Step 3a: Set up a Transformer + +This step sets up a `ContextTransformer`, which further makes use of a +`ast.NodeTransformer`, and will convert methods from class methods (with the use +of `self`), as well as other methods that need to be removed / ignored, into +their more context-friendly counterparts. + +### Step 3b: Transform the Function + +There are quite a few pieces of common Python that need to be transformed. This +step happens with the overall goal of replacing all object-focused parts with a +more global view. This means that a compartment's `.get` and `.set` calls are +replaced with direct setting and getting from the global state, based on the +compartment's target. This also means that all temporally constant values -- +such as `batch_size` -- are moved into the globals space for that specific file +and ultimately replaced with the naming convention of `object_path_constant`. +One more key step that is performed is to ensure that there is no branching in +the code. Specifically, if there is a branch, i.e., an if-statement, NGC-Sim-Lib +will evaluate it and only keep the branch it will traverse down. This means that +there cannot be any branch logic based on inputs or computed values (this is a +common restriction for just-in-time compiling). + +### Step 3c: Parse Sub-Methods + +Since it is possible to have other class methods that are not marked as +entry-points for compilation but still need to be compiled, as step 3b happens, +NGC-Sim-Lib tracks all of the sub-methods required. Notably, this step goes +through and repeats steps 3a and 3b for each of the (sub-)methods with a naming +convention similar to the temporally constant values for each method. + +### Step 3d: Compile the Abstract Syntax Tree (AST) + +Once we have all of the namespace and globals needed to execute the +properly-transformed method, the method is compiled with Python and finally +executed. + +### Step 3e: Binding + +The final step per method is to bind each to their original method; this +replaces each method with an object which, when called, will act like the +normal, uncompiled version but has the addition of the `.compiled` attribute. +This attribute contains all of the compiled information to be used later (for +model / system simulation). This crucially allows for the end user to +call `myComponent.myMethod.compiled()` and have it run. The exact type for +a `compiled` value can be found +in `ngcsimlib._src.parser.utils:CompiledMethod`. + +### Step 4: Finishing Up / Final Processing + +Some objects, such as the processes, entail additional steps to modify +themselves or their compiled methods in order to align themselves with needed +functionality. However, this operation/functionality is found within each +class's expanded `compile` method and should be referred to by looking at those +methods specifically. + diff --git a/docs/tutorials/foundations/components.md b/docs/tutorials/foundations/components.md new file mode 100644 index 00000000..c72f149d --- /dev/null +++ b/docs/tutorials/foundations/components.md @@ -0,0 +1,32 @@ +# Components + +Living one step above compartments in the NGC-Learn dynamical systems hierachy rests the component. +A component (`ngcsimlib.Component`) holds a collection of both temporally constant values as well as dynamic (time-evolving) +compartments. In addition, they are the core place where logic governing the dynamics of a system are +defined. Generally, components serve as the building blocks that are to be reused multiple times +when constructing a complete model of a dynmical system. + +## Temporally Constant versus Dynamic Compartments + +One important distinction that needs to be highlighted within a component is the +difference between a temporally constant value and a dynamic (time-varying) compartment. +Compartments themselves house values that change over time and, generally, they will have the +type `ngcsimlib.Compartment`; note that compartments are to be used to track the internal values +of a component. These internal values can be ones such inputs, decaying values, counters, etc. +The second kind of values found within a component are known as temporally constant values; these +are values (e.g., hyper-parameters, structural parameters, etc.) that will remain fixed +within constructed model dynamical system. These types of values tend to include common configuration +and meta-parameter settings, such as matrix shapes and coefficients. + +## Defining Compilable Methods + +Inside of a component, it is expected that there will be methods defined that govern the +temporal dynamics of the system component. These compilable methods are decorated +with `@compilable` and are defined like any other regular (Python) method. Within a compilable +method, there will be access to `self`, which means that, to reference a compartment's +value, one must write out such a call as: `self.myCompartment.get()`. The only requirement is +that any method that is decorated cannot have a return value; values should be stored +inside their respective compartments (by making an appeal to their respective set routine, i.e., +`self.myCompartment.set(value)`). In an external (compilation) step, outside of the developer's +definition of a component, an NGC-Sim-Lib transformer will change/convert all of these (decorated) +methods into ones that function with the rest of the NGC-Sim-Lib back-end. diff --git a/docs/tutorials/foundations/context.md b/docs/tutorials/foundations/context.md new file mode 100644 index 00000000..1e8e06b2 --- /dev/null +++ b/docs/tutorials/foundations/context.md @@ -0,0 +1,67 @@ +# Contexts + +Contexts, in NGC-Sim-Lib, are the top-level containers that hold everything used to +define a model / dynamical system. On their own, contexts have no runtime logic; +they rely on their internal processes and components to build a complete, working model. + +## Defining a Context + +To define a context (`ngcsimlib.Context`), NGC-Sim-Lib leverages the `with` block; this +means that to create a new context, simply start with the statement +`with Context("myContext") as ctx:` and a new context will be created. +(Important Note: names are unique; if a context is created with the same name, +they will be the same context and, thus, there might be conflicts). +A defined context does not do anything on its own. + +## Adding Components + +To add components to a context, simply initialize components while inside +the `with` block of the context. Any component defined while inside this block +will automatically be added and tacked-on to the context object. + +## Wiring Components + +Inside of a model / dynamical system, components will need to pass data to one +another; this is configured within the context. To connect the compartments of +two components, follow the pattern: `mySource.output >> myDestination.input`, +where `output` and `input` are compartments inside their respective components. +This format will ensure that, when processes are being run, the value will +flow properly from component to component. + +### Operators + +There is a special type of wire called an operator; this performs a simple +operation on the compartment values as the data flows from one component to +another. Generally, these are use for simple mathematical operations, such as +negation `Negate(mySource.output) >> myDestination.input` or the summation of +multiple compartments into +one `Summation(mySource1.output, mySource2.output, ...) >> myDestination.input`. +Note that operators can be chained, so it would be possible to negate one or +more of the inputs that flow into the summation. + +## Adding Processes + +To add processes to a context, simply initialize the process and add all of its +steps while inside the `with`-block of the process. + +## Exiting the `with` block + +When the context exits the `with`-block, it will re-compile the entire model. +Behind the scenes, this is calling `recompile` on the context +itself; it is possible to manually trigger the recompile step, but doing so can +break certain connections (between components/compartments), so use this +functionality sparingly. + +## Saving and Loading + +The context's one unique job is the handling of the "saving" (serialization) and +"loading" (de-serialization) of models to disk. By default, calling +`save_to_json(...)` will create the correct file structure as well as the core files +needed and load the context in the future. To load / de-serialize a model, +calling `Context.load(...)` will load the context in from a directory; something +important to note is that loading in a context entails effectively +recreating the components with their initial values using their arguments as well as +keywords arguments (excluding those that cannot be serialized). This means that, +if you have a trained model, ensure that your components have a save method +defined that will handle the saving and loading of all values within their compartments. + diff --git a/docs/tutorials/foundations/contexts.md b/docs/tutorials/foundations/contexts.md deleted file mode 100644 index b025cd8d..00000000 --- a/docs/tutorials/foundations/contexts.md +++ /dev/null @@ -1,70 +0,0 @@ -# What are Contexts - -A context in ngclearn is a container that holds all the information for your model and can be used as an access point to -reference different models in a multi-model system. Some of the information that contexts hold is all the components -defined in the context, all the wiring information for each of the components, as well as all the commands defined on -the context through various means. - -## How to make a Context - -To make a context first import it from ngclearn with `from ngclearn import Context`. This will give you access to not -only the constructor for new contexts but also the ability to get previously defined contexts. The general use case for -this is - -```python -from ngclearn import Context - -with Context("Model1") as model1: - pass -``` - -This will make a context named "Model1" and also drops you into a with block where you can define the various parts of -the model. The call `Context("Model1")` will always return the same context. So if there is already a model with that -name defined earlier in the code this instance of `model1` will have all the same object defined previously. - -## Adding Components - -The best way to add components to a context is by using components that have implemented the `MetaComponent` metaclass. -In ngclearn the base `Component` class does this. If using these components all that is needed to have them added to -the context is calling their constructors inside a with block of the context. For example - -```python -from ngclearn import Context -from ngclearn.components import LIFCell - -with Context("Model1") as model1: - z1 = LIFCell("z1", n_units=10, tau_m=100) -``` - -## Creating Cables - -To add connections between components and their compartments in a model we do that also in a context. Just like with -components there are no special actions that need to be taken to add them beyond doing so in a with block. To connect -to compartments the `<<` operator is used following the outline of `destination << source`. For example - -```python -with model1: - w1.inputs << z1.s -``` - -## Dynamic Commands - -When building models it can be desirable to use the same training and testing scrips while having commands do different -actions. For example if two different models had different clamp procedures to set inputs and labels it is possible to -dynamically add a generic clamp command to each model and call them the same way despite them doing different things. -As an example -```python -with model1: - @model1.dynamicCommand - def clamp(inputs, labels): - z0.inputs.clamp(inputs) - z2.labels.clamp(labels) - -with model2: - @model2.dynamicCommand - def clamp(inputs, labels): - z0.inputs.clamp(inputs) - z0_p.inputs.clamp(inputs) - z2.labels.clamp(labels) -``` -In both these cases later we can just call clamp and each one will call their own version of the clamp command. diff --git a/docs/tutorials/foundations/global_state.md b/docs/tutorials/foundations/global_state.md new file mode 100644 index 00000000..1f78ac6e --- /dev/null +++ b/docs/tutorials/foundations/global_state.md @@ -0,0 +1,52 @@ +# The Global State + +Since NGC-Sim-Lib is a simulation library focused on temporal models and dynamical +systems, or models that change over time there, it is foundational that all models +(and their respective elements) have some concept of a "state". These states +might be comprised of a single value that changes/evolves or of a complex set of values +that, when combined all together, make up the full dynamical system that underwrites the +final model. In both cases, these sets of values are stored in what is known as the +global state. + +## Interacting with the Global State + +Since the global state will contain a large amount of information describing a given +model, there will be a need to facilitate interaction with and modification of the values +contained within the global state. In most use-cases, this is not done directly. The +most common way to interact with the global state is through the use of the state-manager. +The state-manager exists to provide a set of helper methods for interacting with the +global state itself. Note that, although the manager is there to assist you, it will not stop +you from changing the state (or "breaking" the state). When changing the state -- beyond +setting it through the specificaiton of processes -- be careful to not add or remove +anything that is needed for your actual model. + +### Adding New Fields to the Global State + +If you are new to using NGC-Sim-Lib and looking for a way to add values to the +global state directly and explicitly, stop for a moment and reconsider. Unless +you know exactly what you are doing (i.e., doing core development), it is strongly +advised to not manually add values to the global state; instead, work through the +mechanisms afforded by compartments and/or components, as these are built to afford you the +most common ways for adding fields to the global state itself. The dynamical systems +semantics inherent to compartments and components is meant to ensure carefully-constrained +design and simulation of flexible models. + +If you actually intend to manually and directly add values to the global state itself, it +is done through the use of the `add_key` method. This will create the appropriate key in +the global state for the given path and name; furthermore, its value can be retrieved +with `from_key` calls. This value, however, is not linked to a compartment and, therefore, +will be hard to get working properly in the compiled methods without some specific references. +Please take extra care when working directly and explicitly with the global state. + +### Getting the Current State + +To get the current state, simply call `global_state_manager.state`; this will give +you a (shallow) copy of the current state, which means that any modifications made to it will +not be reflected in the global state. + +### Updating the Global State + +To manually update the global state after modifying a local copy; please write an overriding +call command: `global_state_manager.state = new_state`. This will update the state with the +`.update` call to its underlying dictionaries, which means that a partial state will still update correctly. + diff --git a/docs/tutorials/foundations/monitors.md b/docs/tutorials/foundations/monitors.md deleted file mode 100644 index f835a59d..00000000 --- a/docs/tutorials/foundations/monitors.md +++ /dev/null @@ -1,72 +0,0 @@ -# NGC Monitors - -NGC-monitors are a way of storing a rolling window of compartment values -automatically. Their intended purpose is not to be used "inside" of a model but -just as an auxiliary way to view the internal state of the model even when it is -compiled. A monitor will track the last `n` values it has observed within the -compartment with the oldest value being at `index=0` and the newest being at -`index=n-1`. - -## Building a Monitor - -Monitors are constructed exactly like regular components are for general models. -To use one, simply import the monitor `from ngclearn.components import Monitor`. Now, -inside of your model, build it like a regular component: - -```python -with Context("model") as model: - M = Monitor("M", default_window_length=100) -``` - -## Watching compartments - -There are then two key ways of watching compartments, the first way looks similar -to the wiring paradigm found in connecting standard ngclearn components -together. The primary difference is that connecting compartments to the monitor -does not require a compartment, they are wired directly into the `Monitor` -following the pattern below: - -```python - M << z0.s -``` - -This will wire the spike output of `z0` into the monitor with a view window -length of the `default_window_length`. In the event that you want a view window -that is not the default viewing length, you can use the `watch()` method -instead as in below: - -```python - M.watch(z0.s, customWindowLength) -``` - -There is no limit to the number of compartments that a monitor can watch or the -length of the window that it can store. However, as it is constantly shifting -values, tracking large matrices, such as those containing synapse values -over many timesteps, may get expensive. - -For the monitor to run during your `advance_state` and `reset` calls, make sure -to add to the monitor to the list of components to compile. Currently, -monitors do not work with non-compiled methods -(This is a planned feature for future developments of ngc-learn). - -## Extracting Values - -To look at the currently stored window of any compartment being tracked, there -are two methods available to you. The first method requires that you have -access to the compartment that the monitor is watching. To read out the -monitors values, you can call: - -```python -M.view(z0.s) -``` - -In the event that you do not have access to the compartment, all of the stored -values can be found via the path using the following: - -```python -M.get_store("path/to/compartment") -``` - -The stored windows are kept in a tree of dictionaries, where each node is a part -of the path and the leaves are compartment objects holding the -tracked value windows. \ No newline at end of file diff --git a/docs/tutorials/foundations/operations.md b/docs/tutorials/foundations/operations.md deleted file mode 100755 index 7584d2d2..00000000 --- a/docs/tutorials/foundations/operations.md +++ /dev/null @@ -1,60 +0,0 @@ -# The Basics of Operations - -## What are Operations? - -The underlying method for passing data/signals from component to component inside of -contexts is through the use of cables. A large amount of the time compartments will have a single cable being passed -connected to it that overwrites the previous value in that compartment. However, there are times when this is not the -case and then cable operations must be used. - -## Built-in operations - -By default, ngclearn comes with four operations defined, `overwrite`, `negate`, `summation` and `add`. Of these four -operations the default one used by all cables is the overwrite operation. This operations will take the value of its -source compartment and place it into the destination compartment overwriting the value currently there. The negate -operation has a similar effect as the overwrite operation with the added functionality of applying the `-` operation to -the value being transmitted. The summation operation takes in any number of source compartments and sums together all -their values and overwrites the previous value with the sum. Finally, the add operation does the same thing as the -summation operation but instead adds the sum to the previous value instead of overwriting it. - -## Building Custom Operation - -At its core, an operation is a static method that does all the runtime logic of the operation with the source -compartments, and a resolver that does clean up and assignment of the output of the operation to the destination -compartment. - -> General Form of an Operation: -> ```python -> class operationName(BaseOp): -> @staticmethod -> def operation(*sources): -> #Runtime Logic -> return computed_value -> ``` - -> Example Operation (Summation) -> ```python -> class summation(BaseOp): -> @staticmethod -> def operation(*sources): -> s = None -> for source in sources: -> if s is None: -> s = source -> else: -> s += source -> return s -> ``` - -## Notes - -- Every cable coming into or out of a compartment can have a different operation. - -- The order of these operations should be the order they are wired in, but this is not guaranteed. - -- Only the logic that exists in the static method `operation` is used for a compiled operation, all logic existing in an overwritten resolve method is not captured. - -- Some operations have a flag of `is_compilable` set to false. This is checked during compile to flag if the model can be compiled. - -- Operations can be nested so `summaion(negate(c1), c2)` would be a valid operation and will work while compiled - diff --git a/docs/tutorials/foundations/processes.md b/docs/tutorials/foundations/processes.md new file mode 100644 index 00000000..a7b9b8f5 --- /dev/null +++ b/docs/tutorials/foundations/processes.md @@ -0,0 +1,90 @@ +# Processes + +Processes in NGC-Sim-Lib offer a central way of defining a specific transition to be +taken within a given model (this effectively sets up the behavior of the state-machine +that defines the desired dynamical system one wants to simulate). In effect, processes +take in as many compilable methods as possible across any number of +components; they work to produce a single top-level method and a varying number of +sub-methods needed to execute the entire chain of compilable methods in one (single) step. +This is ultimately done to interface nicely with just-in-time (JIT) compilers, such as +the one inherent to JAX, and to minimize the amount of read and write calls done across +a chain of methods. + +## Building the (Command) Chain + +Building the chain that a process will use is done through an iterative process. Once +the process object is created, steps are added using either `.then()` or `>>`. +As an example: + +``` +myProcess.then(myCompA.forward).then(myCompB.forward).then(myCompA.evolve).then(myCompB.evolve) +``` + +or + +``` +myProcess >> myCompA.forward >> myCompB.forward >> myCompA.evolve >> myCompB.evolve +``` + +In both cases, this process will chain the four methods together into a single +step, only updating the final state after all steps are complete. + +## Types of Processes + +There are two types of processes: the above example would be with what is +referred to as a `MethodProcess` -- these are used to chain together any +compilable methods from any number of different components. The other second +type of process, called a `JointProcess` in NGC-Sim-Lib, is used to chain +together entire processes. +JointProcesses are especially useful if there are multiple method processes that +need to be called but different orders of the processes are needed at different +times. These allow for the specification of complex events / behaviors in a +dynamical system that one will simulate. + +## Extra Elements + +There are a few extra methods that come standard with each process type which can +be useful for both regular operation as well as debugging. + +### Viewing the Compiled Method + +Behind the scenes, a process is transforming and compiling down all of the steps +used to build it; this means that the exact code it is running to do its +set of calculations will ultimately not be what the user wrote. To allow for +the end user to view and make sure that the two pieces of code -- theirs and +the compiled version -- are equivalent (and yielding expected behavior), every +process has a `view_compiled_method()` call which can be used after the (final) model +is compiled. This call will return the code (block) that it will be running as a +string. There will be some stark differences between the produced/generated code and +the code in the (Python) components used to build the steps. Please refer to the +compiling page for a more in-depth guide to comparing the outputs between these +two stages of code. + +### Needed Keywords + +Since some methods will require external values such as `t` (for time) or `dt` +(for integration time / the temporal delta) for a given execution, a process +will also track all the keyword arguments that are needed to run their compiled +process. To view which keywords a given process is expecting, one may use the +command: `get_keywords()`. +This is mostly used for debugging and/or as a sanity check. + +### Packing Keywords + +To add onto the needed keywords, the process also provides an interface to +produce the keywords needed to run in the form of two methods. The first method +is `pack_keywords(...)`; this method packs together a single row of values that +are needed to run a single execution (step) of the process. The arguments are +the `row_seed`, which is a seed that is to be passed to all of the keyword +generators (only needed if generators are being used). +The second set of arguments are keyword arguments that are either constant, +such as `dt=0.1`, or generators, such as `lambda row_seed: 0.1 * row_seed`. +The second method for generating the keywords for a process is with `pack_rows(...)`. +This method will create many sets of keywords that are needed to run multiple +iterations of the process. Note that the arguments are slightly different: first, +it now utilizes a `length` argument to indicate the number of rows being produced and, +second, it features a `seed_generator` that is used to generate the seed of each row +(for instance, to have only even seed values: `seed_generator = lambda x: 2 * x`); if +the generator is `None`, then `seed_generator = lamda x: x` is used. +After this, the same keyword arguments to define the needed parameters are used as in `pack_keywords`. + diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 430924cf..06ee94f1 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -21,7 +21,9 @@ Lessons/tutorials go through the very basics of constructing a dynamical system :caption: II. NGC-Learn/Sim-Lib Foundations foundations - foundations/contexts - foundations/commands - foundations/operations - foundations/monitors + foundations/global_state + foundations/components + foundations/compartments + foundations/context + foundations/processes + foundations/compiling diff --git a/docs/tutorials/model_basics/model_building.md b/docs/tutorials/model_basics/model_building.md index 88707aa9..2f431e5c 100755 --- a/docs/tutorials/model_basics/model_building.md +++ b/docs/tutorials/model_basics/model_building.md @@ -42,8 +42,10 @@ must be wired to the input compartment of `b`. In code, this is done as follows: ``` Finally, to make our dynamical system do something for each step of simulated -time, we must append a few basic commands -(see [Understanding Commands](../foundations/commands.md) to the context. +time, we must append a few basic processes +(see [Understanding Processes](../foundations/processes.md)) +to the context. The commands we will want, as implied by our JSON configuration that we put together at the start of this tutorial, include a `reset` (which will initialize the compartments within each node to their resting values, From 290404094060cde4b2b9ed15c143553b679d3f65 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Fri, 5 Dec 2025 12:20:28 -0500 Subject: [PATCH 116/121] updates to init for logging --- ngclearn/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index a3676a60..210f5121 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -32,11 +32,13 @@ from ngcsimlib.context import Context, ContextObjectTypes from ngcsimlib import Component from ngcsimlib.compartment import Compartment -from ngcsimlib import logger, configure +from ngcsimlib import logger, get_config, provide_namespace from ngcsimlib.parser import compilable +from ngcsimlib.operations import Summation, Product if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py": if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload + from ngcsimlib import configure configure() logger.init_logging() From ee34ef0221b8ecca7ece042782390510ab4edb52 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Fri, 5 Dec 2025 12:20:36 -0500 Subject: [PATCH 117/121] Updates to lessons --- docs/tutorials/model_basics/configuration.md | 62 ++++---- docs/tutorials/model_basics/json_modules.md | 154 ------------------- 2 files changed, 34 insertions(+), 182 deletions(-) delete mode 100644 docs/tutorials/model_basics/json_modules.md diff --git a/docs/tutorials/model_basics/configuration.md b/docs/tutorials/model_basics/configuration.md index 480fea98..8646ea93 100644 --- a/docs/tutorials/model_basics/configuration.md +++ b/docs/tutorials/model_basics/configuration.md @@ -1,39 +1,35 @@ -# Lesson 1: Configuring ngcsimlib +# Lesson 1: Configuring NGC-Sim-Lib ## Basics -There are various global configurations that can be made to ngcsimlib, the systems simulation backend for ngc-learn. These include the ability to point to custom locations for the `json_modules` files as well as setting up the logger. In both of these cases, the configuration will generally persist between different models that might be loaded and, thus, it will need to exist outside of the scope of the model's files. To solve this problem, ngcsimlib provides `config.json` as well as the `--config` flag mechanisms. +There are various global configurations that can be made to NGC-Sim-Lib, the +systems simulation backend for NGC-Learn. The primary built in use for a +configuration file is to modify the built-in logger. Generally to control the +configuration running any script with the flag +`--config="path/to/your/config.json`. -The `config.json` file contains one large json object with sections set up for different parts of the configuration, broke up into sub-objects. There is no limit to the size or the number of these objects, meaning that the user is free to define and use them as they so choose. However, there are some general design principals that govern ngcsimlib that are worth knowing about. Specifically, this mechanism will not configure any parts of individual models. `config.json` configurations should be used to select/generally set up experiments and control global level flags and not to set hyperparameters for models. - -## Built-in Configurations - -There are a couple configurations that ngcsimlib will look for while it is initializing. Specifically `modules` and `logging`. While neither of these is needed to get up and running some aspects of ngcsimlib, useful debugging tools such as logging to files and more verbosity are locked behind flags set up here. - -### Modules - -The modules configuration only contains one value, `module_path`. This value is the location of the `modules.json`, the model-level/experiments-level configuration file one should be setting up when building their experiments. For additional information for configuring this file please -see modules.json. - -> Example Modules -> -> ```json -> { -> "modules": { -> "module_path": "custom/path/to/json/files/modules.json" -> } -> } -> ``` +The `config.json` file contains one large json object with sections set up for +different parts of the configuration, broke up into sub-objects. There is no +limit to the size or the number of these objects, meaning that the user is free +to define and use them as they so choose. ### Logging -The logging configuration mechanism sets up and controls the instance of the python logger built into ngcsimlib. This mechanism (or JSON section) has three values found within it. Specifically, `logging_level`, `logging_file`, and `hide_console`. The logging levels are the same ones built into the python logger and the value words used are either the standard Python string representation of the level or the numeric equivalent. The `logging file`, if defined, is a file that the logger will append all logging messages to for a more permanent history of all messages. Finally, `hide console`, if set to true, will hide all logging output to the console. +The logging configuration mechanism sets up and controls the instance of the +python logger built into ngcsimlib. This mechanism (or JSON section) has three +values found within it. Specifically, `logging_level`, `logging_file`, +and `hide_console`. The logging levels are the same ones built into the python +logger and the value words used are either the standard Python string +representation of the level or the numeric equivalent. The `logging file`, if +defined, is a file that the logger will append all logging messages to for a +more permanent history of all messages. Finally, `hide console`, if set to true, +will hide all logging output to the console. > Default Config > ```json > { > "logging": { -> "logging_level": "WARNING", +> "logging_level": "ERROR", > "hide_console": false > } > } @@ -52,21 +48,31 @@ The logging configuration mechanism sets up and controls the instance of the pyt ## Using a Configuration -To use a configuration, there are a few options. The first option is to simply use the configuration as a python dictionary. This is done by importing the `get_config` method from `ngcsimlib.configManager` and providing the name of the configuration section to the method. +To use a configuration, there are a few options. The first option is to simply +use the configuration as a python dictionary. This is done by importing +the `get_config` method from `ngclearn` and providing the name of +the configuration section to the method. > Example get_config >```python ->from ngcsimlib.configManager import get_config +>from ngclearn import get_config > >loggerConfig = get_config("logger") >level = loggerConfig['logging_level'] >``` -The other way you can access a configuration is through a provided namespace. This makes use of python's `SimpleNamespace` to map all the dictionary's key values to properties of an object to be used. One important note about namespaces is that, unlike a python dictionary where the `get` method can be provided a default value for missing keys, namespaces do not have this functionality. Therefore, if keys are missing it has the potential to cause errors. Below is an example of how one could use the namespace for logging configuration. +The other way you can access a configuration is through a provided namespace. +This makes use of python's `SimpleNamespace` to map all the dictionary's key +values to properties of an object to be used. One important note about +namespaces is that, unlike a python dictionary where the `get` method can be +provided a default value for missing keys, namespaces do not have this +functionality. Therefore, if keys are missing it has the potential to cause +errors. Below is an example of how one could use the namespace for logging +configuration. > Example provide_namespace > ```python -> from ngcsimlib.configManager import provide_namespace +> from ngclearn import provide_namespace > > loggerConfig = provide_namespace("logger") > level = loggerConfig.logging_level diff --git a/docs/tutorials/model_basics/json_modules.md b/docs/tutorials/model_basics/json_modules.md deleted file mode 100644 index 54f8a81f..00000000 --- a/docs/tutorials/model_basics/json_modules.md +++ /dev/null @@ -1,154 +0,0 @@ -# Lesson 2: Configuring with the modules.json File - -## Basic Usage: - -The basic usage for the `modules.json` file is to provide ngclearn with a list of modules to import and associated -classes that are needed to build the models it will be loading. If there is a need to use the imported -modules outside of these cases, use `ngcsimlib.utils.load_attribute` and the loaded -attribute will be returned. - -By default, ngcsimlib, the backend -dependency of ngc-learn, looks for `json_files/modules.json` in your project path. -However, this can be changed inside the -configuration file. In -the event that this -file is missing, ngcsimlib will not break but its ability to load saved models will be limited. - -## Motivation - -The motivation behind the use of `modules.json` versus the registering all the -various parts of the model at the top of the file is reusability. When all the -parts have to be registered/imported at the top of every test file, or be placed into specific locations can be limiting -and slows down development. With a single project wide modules file all loaded models can look there to load components. -This also allows for components to be saved in humanreadable formats not as a pickled object as we can save and load all -the relevant class information from the class name and the modules file. - -## Structure - -A complete schema for the modules file can be found in `modules.schema` - -The general structure of the modules file can be thought of as a transformation -of python import statements to JSON objects. Take the following example: - -```python -from ngclearn.commands import AdvanceState as advance -``` - -In this statement we are importing a command from ngcsimlib and aliasing it to the -word "advance". Now we will transform this into JSON for the modules file. First, -we take the top level module that we are importing from, in this case -`ngcsimlib.commands`; this the absolute path to the location of this module. Next, -we look at the name of what we are importing here: `AdvanceState`. Finally, we -look at the keyword since this import is being assigned to `advance`. We then -take these three parts and combine them into the following JSON object: - -```json - { - "absolute_path": "ngclearn.commands", - "attributes": [ - { - "name": "AdvanceState", - "keywords": [ - "advance" - ] - } - ] -} -``` - -Now there are a few additional things that this JSON formulation of an import -allows us to do. Primarily, it allows for multiple keywords for a single import -to be defined. This if we wanted to use `advance` and `adv` all we would do is -change the keyword line to `"keywords": ["advance", "adv"]`. In addition, we are able -to specify more than one attribute to import from a single top level module -such as also importing the evolve command. - -```json - { - "absolute_path": "ngcsimlib.commands", - "attributes": [ - { - "name": "AdvanceState", - "keywords": [ - "advance", - "adv" - ] - }, - { - "name": "Evolve" - } - ] -} -``` - -Now you might notice above that, when importing the evolve attribute, no -keywords were given. This means that, in order to add an evolve command to -the controller, the whole name will need to be given. There is one caveat to -this scheme though; it is case-insensitive by default, meaning that both -`Evolve` and `evolve` are valid ways to using this import. - -## Example Transformations - -Below are some additional examples to help with transitioning from python -header import statements to JSON configuration. - -> Case 1 -> Python: -> ```python -> from ngcsimlib.commands import AdvanceState as advance, Evolve, Multiclamp as mClamp -> ``` -> Json: -> ```json -> [ -> { -> "absolute_path": "ngcsimlib.commands", -> "attributes": [ -> { -> "name": "AdvanceState", -> "keywords": ["advance"] -> }, -> { -> "name": "Evolve" -> }, -> { -> "name": "Multiclamp", -> "keywords": "mClamp" -> } -> ] -> } -> ] -> ``` - -> Case 2 -> Python -> ```python -> from ngclearn.commands import AdvanceState as advance -> from ngclearn.operations import summation as summ, overwrite -> ``` -> -> Json -> ```json -> [ -> { -> "absolute_path": "ngclearn.commands", -> "attributes": [ -> { -> "name": "AdvanceState", -> "keywords": ["advance"] -> } -> ] -> }, -> { -> "absolute_path": "ngclearn.operations", -> "attributes": [ -> { -> "name": "summation", -> "keywords": ["summ"] -> }, -> { -> "name": "overwrite" -> } -> ] -> } -> ] -> ``` From 5e7ba54df9931b2fcaf3f5b00251e1984457e2ad Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 18:24:31 -0500 Subject: [PATCH 118/121] final cleanup/polish/update to docs for v3 nudge --- docs/index.rst | 1 + docs/museum/index.rst | 5 +- .../compartments.md | 0 .../compiling.md | 0 .../components.md | 0 .../{foundations => configuration}/context.md | 0 .../global_state.md | 0 docs/tutorials/configuration/index.rst | 28 ++++ .../processes.md | 0 docs/tutorials/foundations.md | 10 -- docs/tutorials/index.rst | 21 +-- docs/tutorials/intro.md | 46 +++-- docs/tutorials/model_basics/configuration.md | 50 +++--- .../model_basics/evolving_synapses.md | 124 +++++++------- docs/tutorials/model_basics/model_building.md | 157 +++++++----------- docs/tutorials/neurocog/index.rst | 37 ++--- 16 files changed, 206 insertions(+), 273 deletions(-) rename docs/tutorials/{foundations => configuration}/compartments.md (100%) rename docs/tutorials/{foundations => configuration}/compiling.md (100%) rename docs/tutorials/{foundations => configuration}/components.md (100%) rename docs/tutorials/{foundations => configuration}/context.md (100%) rename docs/tutorials/{foundations => configuration}/global_state.md (100%) create mode 100644 docs/tutorials/configuration/index.rst rename docs/tutorials/{foundations => configuration}/processes.md (100%) delete mode 100644 docs/tutorials/foundations.md diff --git a/docs/index.rst b/docs/index.rst index 7ab63e25..1b710677 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,7 @@ Welcome to ngc-learn's documentation! tutorials/intro tutorials/theory + tutorials/configuration/index tutorials/index tutorials/neurocog/index diff --git a/docs/museum/index.rst b/docs/museum/index.rst index 8b0ce9d3..cab00d0f 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -5,7 +5,10 @@ Model Exhibits ============== -Models are presented in ngc-learn's model museum in the form of "exhibits", which are effectively model-specific walkthroughs and analyses, based on the relevant, referenced publicly available ngc-learn simulation code. +Models are presented in ngc-learn's model museum in the form of "exhibits", which are effectively model-specific +walkthroughs and analyses, based on the relevant, referenced publicly available ngc-learn simulation code. (Note that +there are more model exhibits in the actual `museum repository `_ than the number +of detailed walkthroughs presented in the table of contents below.) .. toctree:: :maxdepth: 1 diff --git a/docs/tutorials/foundations/compartments.md b/docs/tutorials/configuration/compartments.md similarity index 100% rename from docs/tutorials/foundations/compartments.md rename to docs/tutorials/configuration/compartments.md diff --git a/docs/tutorials/foundations/compiling.md b/docs/tutorials/configuration/compiling.md similarity index 100% rename from docs/tutorials/foundations/compiling.md rename to docs/tutorials/configuration/compiling.md diff --git a/docs/tutorials/foundations/components.md b/docs/tutorials/configuration/components.md similarity index 100% rename from docs/tutorials/foundations/components.md rename to docs/tutorials/configuration/components.md diff --git a/docs/tutorials/foundations/context.md b/docs/tutorials/configuration/context.md similarity index 100% rename from docs/tutorials/foundations/context.md rename to docs/tutorials/configuration/context.md diff --git a/docs/tutorials/foundations/global_state.md b/docs/tutorials/configuration/global_state.md similarity index 100% rename from docs/tutorials/foundations/global_state.md rename to docs/tutorials/configuration/global_state.md diff --git a/docs/tutorials/configuration/index.rst b/docs/tutorials/configuration/index.rst new file mode 100644 index 00000000..e9ccc8ca --- /dev/null +++ b/docs/tutorials/configuration/index.rst @@ -0,0 +1,28 @@ +.. ngc-learn documentation master file, created by + sphinx-quickstart on Wed Apr 20 02:52:17 2022. + Note - This file needs to at least contain a root `toctree` directive. + +Configuration Basics +==================== + +This set of guides provide information about the fundamental building blocks that characterize the NGC-Learn as well +as the operation of NGC-Learn's back-end, NGC-Sim-Lib. +For end-users (experimentalists, engineers), the sections under "Building Blocks" will be most informative. For +developers, the sections under "Development Information" are recommended, particularly for advanced use-cases and +low-level development. + +.. toctree:: + :maxdepth: 2 + :caption: Building Blocks + + context + components + compartments + processes + +.. toctree:: + :maxdepth: 2 + :caption: Development Information + + global_state + compiling \ No newline at end of file diff --git a/docs/tutorials/foundations/processes.md b/docs/tutorials/configuration/processes.md similarity index 100% rename from docs/tutorials/foundations/processes.md rename to docs/tutorials/configuration/processes.md diff --git a/docs/tutorials/foundations.md b/docs/tutorials/foundations.md deleted file mode 100644 index 56b1fe49..00000000 --- a/docs/tutorials/foundations.md +++ /dev/null @@ -1,10 +0,0 @@ -# Foundational Elements - -In this set of tutorials/walkthroughs, we go through some of the core elements and mechanisms underlying NGC-Learn in order understand how its simulation scheme (and the nodes-and-cables system) works and to help in writing your own custom elements. - -The foundational walkthroughs are organized as follows: -1. [Using Model Contexts](../tutorials/foundations/context.md): This lesson goes the fundamentals of the primary simulation construct you need to set up models, the (simulation) context. -2. [Understanding Processes](../tutorials/foundations/processes.md): This lesson will walk you through the basics of a command -- an essential part of building a simulation controller in ngc-learn and ngcsimlib -- and offer some useful points for designing new ones. - diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 06ee94f1..5f02878c 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -2,28 +2,15 @@ sphinx-quickstart on Wed Apr 20 02:52:17 2022. Note - This file needs to at least contain a root `toctree` directive. -Tutorial Contents -================= +Modeling Basics +=============== Lessons/tutorials go through the very basics of constructing a dynamical system in NGC-Learn, core elements and tools of neurocognitive modeling using NGC-Learn's in-built components and simulation tools, and finally providing foundational insights into how NGC-Learn and its backend, NGC-Sim-Lib, work (particularly with respect to configuration). .. toctree:: - :maxdepth: 1 - :caption: I. Modeling Basics + :maxdepth: 2 + :caption: Table of Contents model_basics/configuration - model_basics/json_modules model_basics/model_building model_basics/evolving_synapses - -.. toctree:: - :maxdepth: 1 - :caption: II. NGC-Learn/Sim-Lib Foundations - - foundations - foundations/global_state - foundations/components - foundations/compartments - foundations/context - foundations/processes - foundations/compiling diff --git a/docs/tutorials/intro.md b/docs/tutorials/intro.md index b32cd184..e835819c 100755 --- a/docs/tutorials/intro.md +++ b/docs/tutorials/intro.md @@ -1,33 +1,27 @@ # Introduction -NGC-Learn is a general-purpose library for modeling complex dynamical systems, particularly those that are useful for computational neuroscience, neuroscience-motivated artificial intelligence (NeuroAI), and brain-inspired computing. +NGC-Learn is a general-purpose library for modeling complex dynamical systems, particularly those that are useful for +computational neuroscience, neuroscience-motivated artificial intelligence (NeuroAI), and brain-inspired computing. -While the library is designed to provide flexibility on the experimenter/designer side -- allowing one to develop their own dynamics and evolutionary processes -- at its foundation are a few standard components. These are basic modeling nodes for simulating some common biophysical systems computationally, which are useful to know when getting started and for quickly building some classical/historical models. If you are interested in knowing some of the neurophysiological theory behind NGC-Learn's design philosophy, [this section](../tutorials/theory) might be of interest. +While the library is designed to provide flexibility on the experimenter/designer side -- allowing one to develop their +own dynamics and evolutionary processes -- at its foundation are a few standard components. These are basic modeling +nodes for simulating some common biophysical systems computationally, which are useful to know when getting started and +for quickly building some classical/historical models. If you are interested in knowing some of the neurophysiological +theory behind NGC-Learn's design philosophy, [this section](../tutorials/theory) might be of interest. -Specifically, to make best use of NGC-Learn, it is important to get the hang of its "nodes-and-cables system" (the historical name for its backend engine) in order to build simulation objects. This set of tutorials will walk you through, step-by-step, the key aspects of the library that you will need to know so that you can build -and run simulations of computational biophysical models. In addition, we provide walkthroughs of some of the central mechanisms underlying NGC-Sim-Lib, the simulation dependency library that drives NGC-Learn; these lessons are particularly useful for not only understanding why and how things are done by NGC-Learn's simulation backend engine but also for those who want to design new, custom extensions of NGC-Learn either for their own research or to help contribute to the development of the main library. +Specifically, to make best use of NGC-Learn, it is important to get the hang of its "nodes-and-cables system" (the +historical name for its backend engine) in order to build simulation objects. This set of tutorials will walk you +through, step-by-step, the key aspects of the library that you will need to know so that you can build +and run simulations of computational biophysical models. In addition, we provide walkthroughs of some of the central +mechanisms underlying NGC-Sim-Lib, the simulation dependency +library that drives NGC-Learn; these lessons are particularly useful for not only understanding why and how things are +done by NGC-Learn's simulation backend engine but also for those who want to design new, custom extensions of NGC-Learn +either for their own research or to help contribute to the development of the main library. ## Organization of Tutorials -The core tutorials and lessons for using NGC-Learn can be found [here, in the tutorial table of contents](../tutorials/index.rst) which essentially go through: the basic configuration and use of NGC-Learn and NGC-Sim-Lib to construct simulations of dynamical systems, the essentials of neurocognitive modeling (such as building and analyzing models of neuronal dynamics and synaptic plasticity), as well as the coverage of some key foundational ideas/tools worth knowing about NGC-Learn (and its backend, NGC-Sim-Lib) particularly to facilitate easier debugging, experimental configuration, and advanced modeling tools. - - +The core tutorials and usage lessons for using NGC-Learn can be found [here, in the modeling basics table of contents](../tutorials/index.rst) which essentially go through: the basic configuration and use of NGC-Learn (and NGC-Sim-Lib) to +construct simulations of basic dynamical systems. +More advanced tutorials related to the essentials of neurocognitive modeling -- such as building and analyzing +neuroscience models of neuronal dynamics and synaptic plasticity -- can be found [here, in the neurocognitive modeling +table of contents](../tutorials/neurocog/index.rst). diff --git a/docs/tutorials/model_basics/configuration.md b/docs/tutorials/model_basics/configuration.md index 8646ea93..37de90b0 100644 --- a/docs/tutorials/model_basics/configuration.md +++ b/docs/tutorials/model_basics/configuration.md @@ -2,28 +2,22 @@ ## Basics -There are various global configurations that can be made to NGC-Sim-Lib, the -systems simulation backend for NGC-Learn. The primary built in use for a -configuration file is to modify the built-in logger. Generally to control the -configuration running any script with the flag -`--config="path/to/your/config.json`. +There are various global configurations that can be made to NGC-Sim-Lib, the systems simulation backend for NGC-Learn. +The primary use-case for a configuration file is to modify the library's built-in logger. Generally to control the +configuration, run any script (that uses NGC-Learn) with the flag `--config="path/to/your/config.json`. -The `config.json` file contains one large json object with sections set up for -different parts of the configuration, broke up into sub-objects. There is no -limit to the size or the number of these objects, meaning that the user is free -to define and use them as they so choose. +The `config.json` file contains one large JSON object with sections set up for different parts of the configuration, +broken up into sub-objects. There is no limit to the size or to the number of these objects; this means that the user +is free to define and use them as they so choose. ### Logging -The logging configuration mechanism sets up and controls the instance of the -python logger built into ngcsimlib. This mechanism (or JSON section) has three -values found within it. Specifically, `logging_level`, `logging_file`, -and `hide_console`. The logging levels are the same ones built into the python -logger and the value words used are either the standard Python string -representation of the level or the numeric equivalent. The `logging file`, if -defined, is a file that the logger will append all logging messages to for a -more permanent history of all messages. Finally, `hide console`, if set to true, -will hide all logging output to the console. +The logging configuration mechanism sets up and controls the instance of the Python logger built into ngcsimlib. This +mechanism (or JSON section) has three values found within it. Specifically, `logging_level`, `logging_file`, and +`hide_console`. The logging levels are the same ones built into the Python logger and the value words used are either +the standard Python string representation of the level (or the numeric equivalent). The `logging file`, if defined, is +a file that the logger will append all logging messages to in order to facilitate a more permanent history of all +messages. Finally, `hide console`, if set to true, will hide all the logging output to the console. > Default Config > ```json @@ -48,10 +42,9 @@ will hide all logging output to the console. ## Using a Configuration -To use a configuration, there are a few options. The first option is to simply -use the configuration as a python dictionary. This is done by importing -the `get_config` method from `ngclearn` and providing the name of -the configuration section to the method. +To use a configuration, there are a few options. The first option is to simply use the configuration as a Python +dictionary. This is done by importing the `get_config` method from `ngclearn` and providing the name of the +configuration section to the method. > Example get_config >```python @@ -61,14 +54,11 @@ the configuration section to the method. >level = loggerConfig['logging_level'] >``` -The other way you can access a configuration is through a provided namespace. -This makes use of python's `SimpleNamespace` to map all the dictionary's key -values to properties of an object to be used. One important note about -namespaces is that, unlike a python dictionary where the `get` method can be -provided a default value for missing keys, namespaces do not have this -functionality. Therefore, if keys are missing it has the potential to cause -errors. Below is an example of how one could use the namespace for logging -configuration. +The other way you can access a configuration is through a provided namespace. This makes use of Python's +`SimpleNamespace` to map all the dictionary's key values to properties of an object that is to be used. One +important note about namespaces is that, unlike a Python dictionary where the `get` method can be provided a default +value for missing keys, namespaces do not have this functionality. Therefore, missing keys create the potential +to cause errors. Below is an example of how one could use the namespace for a logging configuration. > Example provide_namespace > ```python diff --git a/docs/tutorials/model_basics/evolving_synapses.md b/docs/tutorials/model_basics/evolving_synapses.md index 7b194bb1..e68ef1de 100755 --- a/docs/tutorials/model_basics/evolving_synapses.md +++ b/docs/tutorials/model_basics/evolving_synapses.md @@ -1,27 +1,24 @@ -# Lesson 4: Evolving Synaptic Efficacies +# Lesson 3: Evolving Synaptic Efficacies -In this tutorial, we will extend a controller with three components, -two cell components connected with a synaptic cable component, to incorporate a -basic a two-factor Hebbian adjustment process. +In this tutorial, we will extend a model context/controller with three components, two cell components connected with a +synaptic cable component, to incorporate a basic a two-factor Hebbian adjustment process. ## Adding a Learnable Synapse to a Multi-Component System -Let us start by building a controller similar to previous lessons with the one -exception that now we will trigger the synaptic connection between `a` and `b` -to adapt via a simple 2-factor Hebbian rule. This Hebbian rule will require us -to wire the output compartment of `a` to the pre-synaptic compartment of the -synapse `Wab` and the output compartment of `b` to the post-synaptic -compartment of `Wab`. This will wire in the two relevant factors needed to +Create a Python script/file named `run_lesson3.py` to place/write your Python code below into. +Let us start by building a controller/model-context similar to previous lessons with the one exception that now we will +trigger the synaptic connection between `a` and `b` to adapt via a simple 2-factor Hebbian rule. This Hebbian rule will +require us to wire the output compartment of `a` to the pre-synaptic compartment of the synapse `Wab` and the output +compartment of `b` to the post-synaptic compartment of `Wab`. This will wire in the two relevant factors needed to compute a simple Hebbian adjustment. We do this specifically as follows: ```python from jax import numpy as jnp, random, jit -from ngcsimlib.context import Context -from ngclearn.utils import JaxProcess +from ngclearn import Context, MethodProcess from ngclearn.components import HebbianSynapse, RateCell -import ngclearn.utils.weight_distribution as dist +from ngclearn.utils.distribution_generator import DistributionGenerator as dist ## create seeding keys dkey = random.PRNGKey(1234) @@ -29,63 +26,57 @@ dkey, *subkeys = random.split(dkey, 6) ## create simple system with only one F-N cell with Context("Circuit") as circuit: - a = RateCell(name="a", n_units=1, tau_m=0., - act_fx="identity", key=subkeys[0]) - b = RateCell(name="b", n_units=1, tau_m=0., - act_fx="identity", key=subkeys[1]) - - Wab = HebbianSynapse(name="Wab", shape=(1, 1), eta=1., - sign_value=-1., weight_init=dist.constant(value=1.), - w_bound=0., key=subkeys[3]) - - # wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab` - Wab.inputs << a.zF - # wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b` - b.j << Wab.outputs - - # wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab` - Wab.pre << a.zF - # wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab` - Wab.post << b.zF - - ## create and compile core simulation commands - evolve_process = (JaxProcess() - >> a.evolve) - circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve") - - advance_process = (JaxProcess() - >> a.advance_state) - circuit.wrap_and_add_command(jit(advance_process.pure), name="advance") - - reset_process = (JaxProcess() - >> a.reset) - circuit.wrap_and_add_command(jit(reset_process.pure), name="reset") - - ## set up non-compiled utility commands - @Context.dynamicCommand - def clamp(x): + a = RateCell(name="a", n_units=1, tau_m=0., act_fx="identity", key=subkeys[0]) + b = RateCell(name="b", n_units=1, tau_m=0., act_fx="identity", key=subkeys[1]) + + Wab = HebbianSynapse( + name="Wab", shape=(1, 1), eta=1., sign_value=-1., weight_init=dist.constant(value=1.), + w_bound=0., key=subkeys[3] + ) + + # wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab` + a.zF >> Wab.inputs + # wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b` + Wab.outputs >> b.j + + # wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab` + a.zF >> Wab.pre + # wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab` + b.zF >> Wab.post + + ## create and compile core simulation commands + evolve = (MethodProcess("evolve") + >> a.evolve) + + advance = (MethodProcess("advance") + >> a.advance_state) + + reset = (MethodProcess("reset") + >> a.reset) + +## set up non-compiled utility commands +def clamp(x): a.j.set(x) ``` -Now with our simple system above created, we will now run a simple sequence -of one-dimensional "spike" data through it and evolve the synapse every time -step like so: +Now with our simple system above created, we will now run a simple sequence of one-dimensional "spike" data through it +and evolve the synapse every time step like so: ```python ## run some data through the dynamical system x_seq = jnp.asarray([[1, 1, 0, 0, 1]], dtype=jnp.float32) -circuit.reset() +reset.run() print("{}: Wab = {}".format(-1, Wab.weights.value)) for ts in range(x_seq.shape[1]): x_t = jnp.expand_dims(x_seq[0,ts], axis=0) ## get data at time t - circuit.clamp(x_t) - circuit.advance(t=ts*1., dt=1.) - circuit.evolve(t=ts*1., dt=1.) - print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.value)) + clamp(x_t) + advance.run(t=ts*1., dt=1.) + evolve.run(t=ts*1., dt=1.) + print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.get())) ``` -Your code should produce the same output (towards the bottom): +After running `run_lesson3.py`, your code should produce (printed to I/O) the same output as below: ```console -1: Wab = [[1.]] @@ -96,14 +87,11 @@ Your code should produce the same output (towards the bottom): 4: input = [1.] ~> Wab = [[8.]] ``` -Notice that for every non-spike (a value of `0`), the synaptic value remains -the same (because the product of a pre-synaptic value of `0` with a post-synaptic -value of anything -- in this case, also a `0` -- is simply `0`, meaning no -change will be applied to the synapse). For every spike (a value of `1`), we -get a synaptic change equal to `dW = input * (Wab * input)`; so for the -first time-step, the weight will change according to -`W = W + eta * dW = W + dW` and `dW = 1 * (1 * 1) = 1`, whereas, for the -second time-step, `W` will be increased by `dW = 1 * (2 * 1) = 2` (yielding a - new synaptic strength of `W = 4`). - -You have now created your first plastic, evolving neuronal system. +Notice that for every non-spike (a value of `0`), the synaptic value remains the same (because the product of a +pre-synaptic value of `0` with a post-synaptic value of anything -- in this case, also a `0` -- is simply `0`, meaning +that no change will be applied to the synapse). For every spike (a value of `1`), we get a synaptic change equal to +`dW = input * (Wab * input)`; so for the first time-step, the weight will change according to +`W = W + eta * dW = W + dW` and `dW = 1 * (1 * 1) = 1`, whereas, for the second time-step, `W` will be increased by +`dW = 1 * (2 * 1) = 2` (yielding a new synaptic strength of `W = 4`). + +As per the above, you have now created your first plastic, evolving neuronal system! diff --git a/docs/tutorials/model_basics/model_building.md b/docs/tutorials/model_basics/model_building.md index 2f431e5c..34c1be48 100755 --- a/docs/tutorials/model_basics/model_building.md +++ b/docs/tutorials/model_basics/model_building.md @@ -1,19 +1,19 @@ -# Lesson 3: Building a Model +# Lesson 2: Building a Model -In this tutorial, we will build a simple model made up of three components: -two simple graded cells that are connected by one synaptic cable. +In this tutorial, we will build a simple model made up of three components: two simple graded cells that are connected +by a single synaptic cable. ## Instantiating the Dynamical System as a Context -While building our dynamical system we will set up a Context and then add the three different components to it. +Create a file named `run_lesson2.py` to place/write your Python code below into. +While building our dynamical system we will set up a `Context` and then add the three different components to it, +like so: ```python from jax import numpy as jnp, random -from ngclearn import Context -from ngclearn.utils import JaxProcess -from ngcsimlib.compilers.process import Process +from ngclearn import Context, MethodProcess from ngclearn.components import RateCell, HebbianSynapse -import ngclearn.utils.weight_distribution as dist +from ngclearn.utils.distribution_generator import DistributionGenerator as dist ## create seeding keys dkey = random.PRNGKey(1234) @@ -21,99 +21,69 @@ dkey, *subkeys = random.split(dkey, 4) ## create simple dynamical system: a --> w_ab --> b with Context("model") as model: - a = RateCell(name="a", n_units=1, tau_m=0., - act_fx="identity", key=subkeys[0]) - b = RateCell(name="b", n_units=1, tau_m=20., - act_fx="identity", key=subkeys[1]) - Wab = HebbianSynapse(name="Wab", shape=(1, 1), - weight_init=dist.constant(value=1.), key=subkeys[2]) + a = RateCell(name="a", n_units=1, tau_m=0., act_fx="identity", key=subkeys[0]) + b = RateCell(name="b", n_units=1, tau_m=20., act_fx="identity", key=subkeys[1]) + Wab = HebbianSynapse(name="Wab", shape=(1, 1), weight_init=dist.constant(value=1.), key=subkeys[2]) ``` -Next, we will want to wire together the three components we have embedded into -our model, connecting `a` to node `b` through synaptic cable `Wab`. In -other words, this means that the output compartment of `a` must be wired to the -input compartment of transformation `Wab` and the output compartment of `Wab` -must be wired to the input compartment of `b`. In code, this is done as follows: +Next, we will want to wire together the three components we have embedded into our model, connecting `a` to node `b` +through synaptic cable `Wab`. In other words, this means that the output compartment of `a` (which, if one checks +the documentation for `a`, turns out to be `.zF`) must be wired to the input compartment of transformation `Wab` +(i.e., `.inputs`) and the output compartment of `Wab` (i.e., `.outputs`) must be wired to the input compartment +of `b` (i.e., `.j`). In code, this is done (within the `Context`-block) as follows: ```python - ## wire a to w_ab and wire w_ab to b - Wab.inputs << a.zF - b.j << Wab.outputs + ## wire a to w_ab and wire w_ab to b (a -> Wab -> b) + a.zF >> Wab.inputs + Wab.outputs >> b.j ``` -Finally, to make our dynamical system do something for each step of simulated -time, we must append a few basic processes -(see [Understanding Processes](../foundations/processes.md)) -to the context. -The commands we will want, as implied by our JSON configuration that we put -together at the start of this tutorial, include a `reset` (which will -initialize the compartments within each node to their resting values, -i.e., generally zero, if they have them -- this will only end up affecting -nodes `a` and `b` since a basic synapse component like `Wab` does not have a -base/resting value), an `advance` (which moves all the nodes one step -forward in time according to their compartments' ODEs), and `clamp` (which will -allow us to insert data into particular nodes). -This is simply done with the use of the following convenience function calls: - - - +Finally, to make our dynamical system do something for each step of simulated time, we must append a few basic +processes (see [Understanding Processes](../configuration/processes.md)) to the context. +The commands that we will (in general) want will include a `reset` (which will initialize the compartments within +each node to their "resting" values, i.e., generally zero, if they have them), an `advance` (which moves all the +nodes one step forward in time according to their compartments' differential equations/internal dynamics), and +`clamp` (which will allow us to insert data into particular nodes). +This is simply done by writing the following next (within the `Context`-block): ```python ## configure desired commands for simulation object - reset_process = (JaxProcess() - >> a.reset - >> Wab.reset - >> b.reset) - model.wrap_and_add_command(jit(reset_process.pure), name="reset") + reset = (MethodProcess("reset") + >> a.reset + >> Wab.reset + >> b.reset) - advance_process = (JaxProcess() - >> a.advance_state - >> Wab.advance_state - >> b.advance_state) - model.wrap_and_add_command(jit(advance_process.pure), name="advance") + advance = (MethodProcess("advance") + >> a.advance_state + >> Wab.advance_state + >> b.advance_state) - ## set up clamp as a non-compiled utility commands - @Context.dynamicCommand - def clamp(x): - a.j.set(x) +## set up clamp as a non-compiled utility commands (outside the context-block) +def clamp(x): + a.j.set(x) ## injects value/tensor x into compartment .j of component a ``` -## Running the Dynamical System's Controller +## Running the Dynamical System -With our simple 3-component dynamical system built, we may now run it on a -simple sequence of one-dimensional real-valued numbers: +With our simple 3-component dynamical system built, we may now apply and run it on a simple sequence of +one-dimensional real-valued numbers: ```python ## run some data through our simple dynamical system x_seq = jnp.asarray([[1., 2., 3., 4., 5.]], dtype=jnp.float32) -model.reset() +reset.run() for ts in range(x_seq.shape[1]): x_t = jnp.expand_dims(x_seq[0, ts], axis=0) ## get data at time ts - model.clamp_data(x_t) - model.advance(t=ts * 1., dt=1.) + clamp(x_t) + advance.run(t=ts * 1., dt=1.) ## naively extract simple statistics at time ts and print them to I/O - a_out = a.zF - b_out = b.zF + a_out = a.zF.get() + b_out = b.zF.get() print(" {}: a.zF = {} ~> b.zF = {}".format(ts, a_out, b_out)) ``` -and, assuming you place your code above in a Python script -(e.g., `run_lesson2.py`), we should obtain output in your terminal as below: +and, when running your Python script (i.e., `run_lesson2.py`), we should obtain output in your terminal as below: ```console $ python run_lesson2.py @@ -124,24 +94,17 @@ $ python run_lesson2.py 4: a.zF = [5.] ~> b.zF = [[0.75]] ``` -The simple 3-component system simulated above merely transforms the input -sequence into another time-evolving series. For the curious, in your code above, -you modeled a very simple non-leaky integration of cell `b` injected with some -value produced by `a` (since `Wab = 1`, the synapses had no effect and merely -copies the value along). While node `a` is always clamped to a value as per the -clamp command call we constructed and call above (even though its time constant -was `tau_m = 0` ms, meaning that it reduces to a stateless "feedforward" cell), -b had a time constant you set to `tau_m = 20` ms. This means, as can be confirmed -by inspecting the API for `RateCell`, with your integration time constant -`dt = 1` ms: - -1. at time step `ts = 0`, the value clamped to `a`, i.e., `1`, was multiplied by - `1/20 = 0.05` and then added `b`'s internal state (which started at the value - of `0` through the reset command called before the for-loop); -2. at step `ts = 1`, the value clamped to `a`, i.e., `2`, was multiplied by - `0.05` (yielding `0.1`) and then added to `b`'s current state -- meaning that - the new state becomes `0.05 + 0.1 = 0.15`; -3. at `ts = 2`, a value `3` is clamped to `a`, which is then multiplied by `0.05` - to yield `0.15` and then added to `b`'s current state -- meaning that the new - state is `0.15 + 0.15 = 0.3` - and so on and so forth (`b` acts like a non-decaying recurrently additive state). +The simple 3-component system simulated above merely transforms the input sequence into another time-evolving series. +For the curious, in your code above, you modeled a very simple non-leaky integration of cell `b` injected with some +value produced by `a` (since `Wab = 1`, the synapses had no effect and merely copies the value along). While node +`a` is always clamped to a value as per the clamp command call we constructed and call above (even though its +time constant was `tau_m = 0` ms, meaning that it reduces to a stateless "feedforward" cell), `b` had a time constant +you set to `tau_m = 20` ms. This means, as can be confirmed by inspecting the API for `RateCell`, with your integration time constant `dt = 1` ms: + +1. at time step `ts = 0`, the value clamped to `a`, i.e., `1`, was multiplied by `1/20 = 0.05` and then added + `b`'s internal state (which started at the value of `0` through the reset command called before the for-loop); +2. at step `ts = 1`, the value clamped to `a`, i.e., `2`, was multiplied by `0.05` (yielding `0.1`) and then added + to `b`'s current state -- meaning that the new state becomes `0.05 + 0.1 = 0.15`; +3. at `ts = 2`, a value `3` is clamped to `a`, which is then multiplied by `0.05` to yield `0.15` and then added to + `b`'s current state -- meaning that the new state is `0.15 + 0.15 = 0.3` and so on and so forth (`b` acts like a + non-decaying recurrently additive state). diff --git a/docs/tutorials/neurocog/index.rst b/docs/tutorials/neurocog/index.rst index 0f18f849..b702a535 100644 --- a/docs/tutorials/neurocog/index.rst +++ b/docs/tutorials/neurocog/index.rst @@ -5,31 +5,20 @@ Neurocognitive Modeling Lessons =============================== -A central motivation for using ngc-learn is to flexibly build computational -models of neuronal information processing, dynamics, and credit -assignment (as well as design one's own custom instantiations of their -mathematical formulations and ideas). In this set of tutorials, we will go -through the central basics of using ngc-learn's in-built biophysical components, -also called "cells" and "synapses", to craft and simulate adaptive neural systems -and biophysical computational models. +A central motivation for using ngc-learn is to flexibly build computational models of neuronal information processing, +dynamics, and credit assignment (as well as design custom instantiations of one's own mathematical formulations and +ideas). In this set of tutorials, we will go through the central basics of using ngc-learn's in-built biophysical +components, also called "cells" and "synapses", to craft and simulate adaptive neural systems and biophysical +computational models. -Usefully, ngc-learn starts with a collection of cells -- those that are partitioned -into those that are graded / real-valued (`ngclearn.components.neurons.graded`) -and those that spike (`ngclearn.components.neurons.spiking`). In addition, -ngc-learn supports another collection called synapses -- generally, those that -adapt (or "learn") with biological credit assignment building blocks -(such as those in `ngclearn.components.synapses.hebbian`) such as -spike-timing-dependent plasticity and multi-factor rules. With the in-built, -standard cells and synapses in these two -core collections, you can readily construct a wide variety of models, recovering -many classical ones previously proposed in computational neuroscience -and brain-inspired computing researach (many of these kinds of models are available -for external download in the `Model Museum `_). - -While the reader is free to jump into any one self-contained tutorial in any -order based on their needs, we organize, within each topic, the lessons starting -from more basic, foundational modeling modules and library tools and sequentially -work towards more advanced concepts. +Usefully, ngc-learn starts with a collection of cells -- those that are partitioned into those that are graded / +real-valued (`ngclearn.components.neurons.graded`) and those that spike (`ngclearn.components.neurons.spiking`). In +addition, ngc-learn supports another collection called synapses -- generally, those that adapt (or "learn") with +biological credit assignment building blocks (such as those in `ngclearn.components.synapses.hebbian`) such as +spike-timing-dependent plasticity and multi-factor rules. With the in-built, standard cells and synapses in these two +core collections, you can readily construct a wide variety of models, recovering many classical ones previously +proposed in computational neuroscience and brain-inspired computing research (many of these kinds of models are +available for external download in the `Model Museum `_). .. toctree:: :maxdepth: 1 From 68f44a30ddf12b33aee5c7280ed7d5ca940dfcec Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 18:48:01 -0500 Subject: [PATCH 119/121] updates to museum doc for v3 --- docs/museum/event_stdp_patches.md | 13 ++ docs/museum/harmonium.md | 203 ++++++++++++++++++++++-------- docs/museum/index.rst | 1 + docs/museum/pc_rao_ballard1999.md | 12 +- docs/museum/pcn_discrim.md | 2 +- docs/museum/rl_snn.md | 26 ++-- docs/museum/sindy.md | 12 +- docs/museum/snn_bfa.md | 4 +- docs/museum/snn_dc.md | 2 +- docs/museum/sparse_coding.md | 2 +- 10 files changed, 187 insertions(+), 90 deletions(-) create mode 100644 docs/museum/event_stdp_patches.md diff --git a/docs/museum/event_stdp_patches.md b/docs/museum/event_stdp_patches.md new file mode 100644 index 00000000..f0759e27 --- /dev/null +++ b/docs/museum/event_stdp_patches.md @@ -0,0 +1,13 @@ +# Event-based Spike-Timing-Dependent Plasticity (Tavanaei et al.; 2018) + +In this exhibit, we create, simulate, and visualize the internally acquired receptive fields of the spiking neural +network (SNN) trained via event-based spike-timing-dependent plasticity (EV-STDP) over image patches. This +reproduces the SNN model originally proposed in (Tavanaei et al., 2018) [1]. + +The model code for this exhibit can be found +[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/evstdp_patches). + + +## References +[1] Tavanaei, Amirhossein, Timothée Masquelier, and Anthony Maida. "Representation learning using event-based +STDP." Neural Networks 105 (2018): 294-303. \ No newline at end of file diff --git a/docs/museum/harmonium.md b/docs/museum/harmonium.md index ab55209e..a197d409 100644 --- a/docs/museum/harmonium.md +++ b/docs/museum/harmonium.md @@ -1,23 +1,32 @@ -# Harmoniums and Contrastive Divergence +# Harmoniums and Contrastive Divergence (Hinton; 1999) - -In NGC-Learn, it is possible to construct other forms of learning from the very base learning/plasticity components already in-built into the base library. Notably, a class of learning and inference systems that adapt through a process known as contrastive Hebbian learning (CHL) can be constructed and simulated with ngc-learn. +In NGC-Learn, it is possible to construct other forms of learning from the very base learning/plasticity components +already in-built into the base library. Notably, a class of learning and inference systems that adapt through a process +known as contrastive Hebbian learning (CHL) can be constructed and simulated with ngc-learn. -In this walkthrough, we will design a simple Harmonium, also known as the restricted Boltzmann machine (RBM). We will specifically focus on learning its synaptic connections with an algorithmic recipe known -as contrastive divergence (CD), which can be considered to be a stochastic form of CHL. After going through this exhibit, you will: +In this walkthrough, we will design a simple Harmonium, also known as the restricted Boltzmann machine (RBM). We will +specifically focus on learning its synaptic connections with an algorithmic recipe known as contrastive divergence (CD), +which can be considered to be a stochastic form of CHL. After going through this exhibit, you will: -1. Learn how to construct an `NGCGraph` that emulates the structure of an RBM and adapt the NGC settling process to calculate approximate synaptic weight gradients in accordance to contrastive divergence. -2. Simulate fantasized image samples using the block Gibbs sampler implicitly defined by the negative phase graph. +1. Learn how to construct an `NGCGraph` that emulates the structure of an RBM and adapt the NGC settling process to + calculate approximate synaptic weight gradients in accordance to contrastive divergence. +2. Simulate fantasized image samples using the block Gibbs sampler implicitly defined by the negative phase graph. Note that the folders of interest to this walkthrough are: -+ `ngc-museum/exhibits/harmonium/`: this contains the necessary simulation scripts (which can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/harmonium)); ++ `ngc-museum/exhibits/harmonium/`: this contains the necessary simulation scripts (which can be found + [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/harmonium)); + `ngc-museum/data/mnist/`: this contains the zipped copy of the MNIST digit image arrays ## On the Harmonium Probabilistic Graphical Model -A harmonium is a generative model implemented as a stochastic, two-layer neural system (a type of probabilistic graphic model; PGM) that attempts to learn a probability distribution over sensory input $\mathbf{x}$, i.e., the goal of a harmonium is to learn $p(\mathbf{x})$, the underlying probability/likelihood of a given (training) dataset. Fundamentally, the approach to estimating $p(\mathbf{x})$ that carried out by a harmonium is through the optimization of an energy function $E(\mathbf{x})$ (a concept motivated by statistical mechanics), where the system searches for an internal configuration, i.e., the values of its synapses, that assigns low energy (values) to sample patterns that come from the true data distribution $p(\mathbf{x})$ and high energy (values) to patterns that do not (or those that do not come from the training dataset). +A harmonium is a generative model implemented as a stochastic, two-layer neural system (a type of probabilistic graphic +model; PGM) that attempts to learn a probability distribution over sensory input $\mathbf{x}$, i.e., the goal of a +harmonium is to learn $p(\mathbf{x})$, the underlying probability/likelihood of a given (training) dataset. +Fundamentally, the approach to estimating $p(\mathbf{x})$ that carried out by a harmonium is through the optimization +of an energy function $E(\mathbf{x})$ (a concept motivated by statistical mechanics), where the system searches for an +internal configuration, i.e., the values of its synapses, that assigns low energy (values) to sample patterns that come +from the true data distribution $p(\mathbf{x})$ and high energy (values) to patterns that do not (or those that do not +come from the training dataset). ```{eval-rst} .. table:: @@ -30,7 +39,13 @@ A harmonium is a generative model implemented as a stochastic, two-layer neural +-----------------------------------------------------------------+ ``` -The most common, simplest harmonium is one where input nodes (one per dimension of the data observation space) are modeled as binary/Boolean sensors -- or "visible units" $\mathbf{z}^0$ (observed variables) that are clamped to actual data patterns -- connected to a layer of (stochastic) binary latent feature detectors -- or "hidden units" $\mathbf{z}^1$ (unobserved or latent variables). Notably, the synaptic connections between the latent and visible units are symmetric. Furthermore, as a result of a key restriction imposed on the harmonium's network structure, i.e., no lateral connections between the neurons/units within $\mathbf{z}^0$ as well as those within $\mathbf{z}^1$, computing the latent and visible states is as straightforward as the following: +The most common, simplest harmonium is one where input nodes (one per dimension of the data observation space) are +modeled as binary/Boolean sensors -- or "visible units" $\mathbf{z}^0$ (observed variables) that are clamped to actual +data patterns -- connected to a layer of (stochastic) binary latent feature detectors -- or "hidden units" +$\mathbf{z}^1$ (unobserved or latent variables). Notably, the synaptic connections between the latent and visible units +are symmetric. Furthermore, as a result of a key restriction imposed on the harmonium's network structure, i.e., no +lateral connections between the neurons/units within $\mathbf{z}^0$ as well as those within $\mathbf{z}^1$, computing +the latent and visible states is as straightforward as the following: $$ p(\mathbf{z}^1 | \mathbf{z}^0) &= sigmoid(\mathbf{W} \cdot \mathbf{z}^0 + \mathbf{b}), @@ -40,49 +55,76 @@ p(\mathbf{z}^0 | \mathbf{z}^1) &= sigmoid(\mathbf{W}^T \cdot \mathbf{z}^1 + \mat $$ where $\mathbf{c}$ is the visible bias vector, $\mathbf{b}$ is the latent bias vector, -and $\mathbf{W}$ is the synaptic weight matrix that connects $\mathbf{z}^0$ to $\mathbf{z}^1$ (and its transpose $\mathbf{W}^T$ is used to make predictions of the input itself). Note that $\cdot$ means matrix/vector multiplication and $\sim$ denotes that we would sample from a probability (vector). In the above harmonium's case, samples will be drawn treating the conditionals such as $p(\mathbf{z}^1 | \mathbf{z}^0)$ as multivariate Bernoulli distributions. +and $\mathbf{W}$ is the synaptic weight matrix that connects $\mathbf{z}^0$ to $\mathbf{z}^1$ (and its transpose +$\mathbf{W}^T$ is used to make predictions of the input itself). Note that $\cdot$ means matrix/vector multiplication +and $\sim$ denotes that we would sample from a probability (vector). In the above harmonium's case, samples will be +drawn treating the conditionals such as $p(\mathbf{z}^1 | \mathbf{z}^0)$ as multivariate Bernoulli distributions. $\mathbf{z}^0$ would typically be clamped/set to the actual sensory input data $\mathbf{x}$. -The energy function of the harmonium's joint configuration $(\mathbf{z}^0,\mathbf{z}^1)$ (similar to that of a Hopfield network) is specified as follows: +The energy function of the harmonium's joint configuration $(\mathbf{z}^0,\mathbf{z}^1)$ (similar to that of a Hopfield +network) is specified as follows: $$ E(\mathbf{z}^0,\mathbf{z}^1) = -\sum_i \mathbf{c}_i \mathbf{z}^0_i - \sum_j \mathbf{b}_j \mathbf{z}^1_j - \sum_i \sum_j \mathbf{z}^0_i \mathbf{W}_{ij} \mathbf{z}^1_j . $$ -Notice that, in the equation above, we sum over vector dimension indices, e.g., $\mathbf{z}^0_i$ retrieves the $i$th scalar element of (vector) $\mathbf{z}^0$ while $\mathbf{W}_{ij}$ retrieves the scalar element at position $(i,j)$ within matrix $\mathbf{W}$. With this energy function, one can write out the probability that a harmonium PGM assigns to a data point as: +Notice that, in the equation above, we sum over vector dimension indices, e.g., $\mathbf{z}^0_i$ retrieves the $i$th +scalar element of (vector) $\mathbf{z}^0$ while $\mathbf{W}_{ij}$ retrieves the scalar element at position $(i,j)$ +within matrix $\mathbf{W}$. With this energy function, one can write out the probability that a harmonium PGM assigns +to a data point as: $$ p(\mathbf{z}^0 = \mathbf{x}) = \frac{1}{Z} \exp( -E(\mathbf{z}^0,\mathbf{z}^1) ) $$ -where $Z$ is the normalizing constant (or, in statistical mechanics, the partition function) needed to obtain proper probability values[^1]. -When one works through the derivation of the gradient of the log probability $\log p(\mathbf{x})$ with respect to the synapses such as $\mathbf{W}$, they get a (contrastive) Hebbian-like update rule as follows: +where $Z$ is the normalizing constant (or, in statistical mechanics, the partition function) needed to obtain +proper probability values[^1]. +When one works through the derivation of the gradient of the log probability $\log p(\mathbf{x})$ with respect to the +synapses such as $\mathbf{W}$, they get a (contrastive) Hebbian-like update rule as follows: $$ \Delta \mathbf{W}_{ij} = <\mathbf{z}^0_i \mathbf{z}^1_j>_{data} - <\mathbf{z}^0_i \mathbf{z}^1_j>_{model} $$ -where the angle brackets $< >$ tell us that we need to take the expectation of the values within the brackets under a certain distribution (such as the data distribution denoted by the subscript $data$). The above rule can also be considered to be a stochastic form of a general recipe known as contrastive Hebbian learning (CHL) [4]. +where the angle brackets $< >$ tell us that we need to take the expectation of the values within the brackets under a +certain distribution (such as the data distribution denoted by the subscript $data$). The above rule can also be +considered to be a stochastic form of a general recipe known as contrastive Hebbian learning (CHL) [4]. Technically, to compute the update above, obtaining the first term -$<\mathbf{z}^0_i \mathbf{z}^1_j>_{data}$ is easy since we only need to take the product of a data point and its corresponding latent state under the harmonium. However, obtaining the second term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ is very costly, since we would need to -initialize the value of $\mathbf{z}^0$ to a random initial state and then run a (block) Gibbs sampler for many iterations to accurately approximate the second term. Fortunately, it was shown in work such as [3], that learning a harmonium is still possible by replacing the term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ with $<\mathbf{z}^0_i \mathbf{z}^1_j>_{recon}$, which is simply computed by using the -first term's latent state $\mathbf{z}^1$ to reconstruct the input and then using this reconstruction once more in order to obtain its corresponding binary latent state. This is known as "contrastive divergence" (CD-1), and, although this approximation has been shown to not actual follow the gradient of any known objective function, it works well in practice when learning a harmonium-based generative model. Finally, the vectorized form of the CD-1 update is: +$<\mathbf{z}^0_i \mathbf{z}^1_j>_{data}$ is easy since we only need to take the product of a data point and its +corresponding latent state under the harmonium. However, obtaining the second term +$<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ is very costly, since we would need to +initialize the value of $\mathbf{z}^0$ to a random initial state and then run a (block) Gibbs sampler for many +iterations to accurately approximate the second term. Fortunately, it was shown in work such as [3], that learning a +harmonium is still possible by replacing the term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ with +$<\mathbf{z}^0_i \mathbf{z}^1_j>_{recon}$, which is simply computed by using the first term's latent state +$\mathbf{z}^1$ to reconstruct the input and then using this reconstruction once more in order to obtain its +corresponding binary latent state. This is known as "contrastive divergence" (CD-1), and, although this approximation +has been shown to not actual follow the gradient of any known objective function, it works well in practice when +learning a harmonium-based generative model. Finally, the vectorized form of the CD-1 update is: $$ \Delta \mathbf{W} = \Big[ (\mathbf{z}^0_{pos})^T \cdot \mathbf{z}^1_{pos} \Big] - \Big[ (\mathbf{z}^0_{neg})^T \cdot \mathbf{z}^1_{neg} \Big] $$ -where the first term (in brackets) is labeled as the "positive phase" (or the positive, data-dependent statistics -- where $\mathbf{z}^0_{pos}$ denotes the positive phase sample of $\mathbf{z}^0$) while the second term is labeled as the "negative phase" (or the negative, data-independent statistics -- where $\mathbf{z}^0_{neg}$ denotes the negative phase sample of $\mathbf{z}^0$). Note that simpler rules of a similar form can be worked out for the latent/visible bias vectors as well. +where the first term (in brackets) is labeled as the "positive phase" (or the positive, data-dependent statistics -- +where $\mathbf{z}^0_{pos}$ denotes the positive phase sample of $\mathbf{z}^0$) while the second term is labeled as the +"negative phase" (or the negative, data-independent statistics -- where $\mathbf{z}^0_{neg}$ denotes the negative phase +sample of $\mathbf{z}^0$). Note that simpler rules of a similar form can be worked out for the latent/visible bias +vectors as well. -In NGC-Learn, to simulate the above harmonium PGM and its CD-1 update, we will model the positive and negative phases as simulated co-models, each responsible for producing the relevant statistics that we will require in order to adjust synapses. Additionally, we will find that we can further re-purpose the created co-models to construct a block Gibbs sampler for confabulating "fantasized" +In NGC-Learn, to simulate the above harmonium PGM and its CD-1 update, we will model the positive and negative phases +as simulated co-models, each responsible for producing the relevant statistics that we will require in order to adjust +synapses. Additionally, we will find that we can further re-purpose the created co-models to construct a block Gibbs +sampler for confabulating "fantasized" data patterns from a harmonium that has been fit to data. ## Boltzmann Machines: Positive and Negative Co-Models -We begin by first specifying the structure of the harmonium system that we would like to simulate. In NGC shorthand, the above positive and negative phase graphs would simply be (under one complete generative model): +We begin by first specifying the structure of the harmonium system that we would like to simulate. In NGC shorthand, +the above positive and negative phase graphs would simply be (under one complete generative model): ``` z0 -(z0-z1)-> z1 @@ -90,7 +132,11 @@ z1 -(z1-z0) -> z0 Note: z1-z0 = (z0-z1)^T (transpose-tied synapses) ``` -In order to construct the desired harmonium, particularly the structure needed to implement CD-1, we will need to break up the model into its key "phases", i.e., a positive phase and a negative phase. We will model each phase as its own simulated nodes-and-cables structure within one single model context, allowing us to craft a general approach that permits a CD-based learning. Notably, we will use the negative-phase co-model to emulate the crucial MCMC sampling step to synthesize data from the trained RBM. +In order to construct the desired harmonium, particularly the structure needed to implement CD-1, we will need to break +up the model into its key "phases", i.e., a positive phase and a negative phase. We will model each phase as its own +simulated nodes-and-cables structure within one single model context, allowing us to craft a general approach that +permits a CD-based learning. Notably, we will use the negative-phase co-model to emulate the crucial MCMC sampling step +to synthesize data from the trained RBM. Building the positive phase of our harmonium can be done as follows: @@ -109,7 +155,9 @@ with Context("Circuit") as self.circuit: self.W1.outputs >> self.z1.inputs ``` -To gather the rest of the statistics that we require, we will need to build the negative phase of our model (which is responsible for "dreaming up" or "confabulating" data samples from its internal model of the world). Constructing the negative-phase co-model, under the same model `Context` above can be done as follows: +To gather the rest of the statistics that we require, we will need to build the negative phase of our model (which is +responsible for "dreaming up" or "confabulating" data samples from its internal model of the world). Constructing the +negative-phase co-model, under the same model `Context` above can be done as follows: ```python ## set up negative-phase graph @@ -135,9 +183,12 @@ To gather the rest of the statistics that we require, we will need to build the self.V1.outputs >> self.z1neg.inputs ``` -The above chunk of code effectively sets up the propagation of information from the latent neurons `z1` back down to `z0` (obtaining the negative phase values of `z0`, i.e., `z0neg`) and then the propagation of the reconstructed values back up to `z1` one last time (obtaining the negative phase values of `z1`, i.e., `z0neg`). +The above chunk of code effectively sets up the propagation of information from the latent neurons `z1` back down to +`z0` (obtaining the negative phase values of `z0`, i.e., `z0neg`) and then the propagation of the reconstructed values +back up to `z1` one last time (obtaining the negative phase values of `z1`, i.e., `z0neg`). -To build a CHL-based form of plasticity, allowing us to build the CD-1 learning process, we will then need to wire up a set of 2-factor Hebbian rules like so: +To build a CHL-based form of plasticity, allowing us to build the CD-1 learning process, we will then need to wire up a +set of 2-factor Hebbian rules like so: ```python ## set up contrastive Hebbian learning rule (pos-stats - neg-stats) @@ -147,23 +198,39 @@ To build a CHL-based form of plasticity, allowing us to build the CD-1 learning self.z1neg.p >> self.V1.post ## negative-phase pre-synaptic term ``` -the results of these two Hebbian rules are then used in an exhibit-specific function (`_update_via_CHL()`) written in the [`Harmonium` class](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py). -While we observe that our "negative phase" co-model allows us to emulate the CD learning recipe[^2], technically, the negative phase of a harmonium should be run for a very high value of steps (approaching infinity) in order to obtain a proper sample from the PGM's equilibrium/steady state distribution. However, this would be extremely costly to simulate and, as early studies [3] observed, often only a few or even a single step of this Markov chain proved to work quite well, approximating the contrastive divergence objective (the learning algorithm's namesake) instead of direct maximum likelihood. +the results of these two Hebbian rules are then used in an exhibit-specific function (`_update_via_CHL()`) written in +the [`Harmonium` class](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py). +While we observe that our "negative phase" co-model allows us to emulate the CD learning recipe[^2], technically, the +negative phase of a harmonium should be run for a very high value of steps (approaching infinity) in order to obtain a +proper sample from the PGM's equilibrium/steady state distribution. However, this would be extremely costly to simulate +and, as early studies [3] observed, often only a few or even a single step of this Markov chain proved to work quite +well, approximating the contrastive divergence objective (the learning algorithm's namesake) instead of direct +maximum likelihood. -Note that the full code, containing the snippets above, can be found in the Model Museum `Harmonium` model structure class. One could further generalize our CD-1 framework to variations, such as "persistent" CD (where we, instead of running `z1` back down through `E1` synapses, we inject random noise instead (to sample the harmonium's latent prior), or even an algorithm known as parallel tempering, where we would maintain multiple co-models and draw samples from all of them to obtain negative-phase statistics. +Note that the full code, containing the snippets above, can be found in the Model Museum `Harmonium` model structure +class. One could further generalize our CD-1 framework to variations, such as "persistent" CD (where we, instead of +running `z1` back down through `E1` synapses, we inject random noise instead (to sample the harmonium's latent prior), +or even an algorithm known as parallel tempering, where we would maintain multiple co-models and draw samples from +all of them to obtain negative-phase statistics. -Finally, within the `Harmonium` class, we have written a routine for drawing samples from the model directly, i.e., we implement a block Gibbs sampler in order synthesize data from the RBM's current set of parameters. +Finally, within the `Harmonium` class, we have written a routine for drawing samples from the model directly, i.e., we +implement a block Gibbs sampler in order synthesize data from the RBM's current set of parameters. ## Using the Harmonium to Dream Up Handwritten Digits -We finally take the harmonium that we have constructed above and fit it to some MNIST digits. Specifically, we will leverage the [Harmonium](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py), model in the Model Museum since it implements all of the above core mechanisms (and more) internally. In the script `sim_harmonium.py`, you will find a general training that will fit our harmonium to the MNIST database (unzip the file `mnist.zip` in the `ngc-museum/exhibits/data/` directory if you have not already) by cycling through it several times, saving the final +We finally take the harmonium that we have constructed above and fit it to some MNIST digits. Specifically, we will +leverage the [Harmonium](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py), model in the Model Museum since it implements all of the above core mechanisms (and +more) internally. In the script `sim_harmonium.py`, you will find a general training that will fit our harmonium to +the MNIST database (unzip the file `mnist.zip` in the `ngc-museum/exhibits/data/` directory if you have not already) +by cycling through it several times, saving the final (best) resulting to disk within the `exp/` sub-directory. Go ahead and execute the training process as follows: ```console $ python sim_harmonium.py ``` -which will fit/adapt your harmonium to MNIST. This should produce per-training iteration output, printed to I/O, similar to the following: +which will fit/adapt your harmonium to MNIST. This should produce per-training iteration output, printed to I/O, +similar to the following: ```console --- Initial RBM Synaptic Stats --- @@ -201,26 +268,40 @@ b1: min -7.5815 ; max 0.2337 mu -2.3395 ; norm 53.3993 c0: min -11.6316 ; max -2.4227 mu -5.3259 ; norm 161.5646 ``` -You will find, after the training script has finished executing, several outputs in the `exp/filters/` model sub-directory that is created for you. Concretely, you will find a grid-plot of the (first `100` of the) harmonium's acquired filters (or "receptive fields"), much as we did for the sparse coding exhibit, that will look similar to the following: +You will find, after the training script has finished executing, several outputs in the `exp/filters/` model +sub-directory that is created for you. Concretely, you will find a grid-plot of the (first `100` of the) harmonium's +acquired filters (or "receptive fields"), much as we did for the sparse coding exhibit, that will look similar to +the following: -Interestingly enough, we see that our harmonium/RBM has extracted what appears to be rough stroke features, which is what it uses when sampling its binary latent feature detectors to compose final synthesized image patterns (each binary feature detector serves as Boolean function that emits a decision of `1` if the filter is to be used and a `0` if not). In particular, we remark notice that the filters that our harmonium has acquired are a bit more prominent due to the fact our exhibit employs some weight decay (specifically, Gaussian/L2 decay -- with intensity `l2_lambda=0.01` -- to the `W1` synaptic matrix of our RBM). -Weight decay of this form is particularly useful to not only mitigate against the harmonium overfitting to its training data but also to ensure that the Markov chain inherent to its negative-phase mixes more effectively [5] (which ensures better-quality samples from the block Gibbs sampler, which we will use next). +Interestingly enough, we see that our harmonium/RBM has extracted what appears to be rough stroke features, which is +what it uses when sampling its binary latent feature detectors to compose final synthesized image patterns (each +binary feature detector serves as Boolean function that emits a decision of `1` if the filter is to be used and a `0` +if not). In particular, we remark notice that the filters that our harmonium has acquired are a bit more prominent due +to the fact our exhibit employs some weight decay (specifically, Gaussian/L2 decay -- with intensity +`l2_lambda=0.01` -- to the `W1` synaptic matrix of our RBM). +Weight decay of this form is particularly useful to not only mitigate against the harmonium overfitting to its training +data but also to ensure that the Markov chain inherent to its negative-phase mixes more effectively [5] (which ensures +better-quality samples from the block Gibbs sampler, which we will use next). -Finally, you will also find in the `exp/filters/` model sub-folder another grid-plot containing some (about `100`) of the RBM's reconstructions of held-out development data. This plot should look similar to the one below: +Finally, you will also find in the `exp/filters/` model sub-folder another grid-plot containing some (about `100`) of +the RBM's reconstructions of held-out development data. This plot should look similar to the one below: ### Sampling the Harmonium -Once the training process has completed, you can then run the following to sample from trained model using block Gibbs sampling: +Once the training process has completed, you can then run the following to sample from trained model using block Gibbs +sampling: ```console $ python sample_harmonium.py ``` -which will take your trained harmonium's negative-phase co-model and use it to synthesize some digit patterns. You should see inside the `exp/samples/` sub-directory three sample-image grids (i.e., `samples_0.jpg`, `samples_1.jpg`, and `samples_2.jpg`) similar to what is shown below: +which will take your trained harmonium's negative-phase co-model and use it to synthesize some digit patterns. You +should see inside the `exp/samples/` sub-directory three sample-image grids (i.e., `samples_0.jpg`, `samples_1.jpg`, +and `samples_2.jpg`) similar to what is shown below: ```{eval-rst} .. image:: ../images/museum/harmonium/samples_0.jpg @@ -231,7 +312,9 @@ which will take your trained harmonium's negative-phase co-model and use it to s :width: 30% ``` -Furthermore, you will see three corresponding GIFs that have been generated for you that visualize how each of the three simulated sampling Markov chains change with time (i.e., these are the files: `markov_chain_0.gif`, `markov_chain_1.gif`, and `markov_chain_2.gif`). +Furthermore, you will see three corresponding GIFs that have been generated for you that visualize how each of the +three simulated sampling Markov chains change with time (i.e., these are the files: `markov_chain_0.gif`, +`markov_chain_1.gif`, and `markov_chain_2.gif`). -It is important to understand that the three grids of samples shown above come from particular points in the block Gibbs sampling process. -(Note that one reads these sample grid plots left-column to right-column, and top-row to bottom-row; this way of reading the plot follows the ordering of samples extracted from the specific Markov chain sequence.) -Note that, although each chain is run for many total steps, the `sample_harmonium.py` script "thins" out each Markov chain by only pulling out a fantasized pattern every `20` steps (further "burning" in each chain before collecting samples). Each chain is merely initialized with random Bernoulli noise. Note that higher-quality samples can be obtained if one modifies the earlier harmonium to learn with persistent CD or parallel tempering. +It is important to understand that the three grids of samples shown above come from particular points in the block +Gibbs sampling process. +(Note that one reads these sample grid plots left-column to right-column, and top-row to bottom-row; this way of +reading the plot follows the ordering of samples extracted from the specific Markov chain sequence.) +Note that, although each chain is run for many total steps, the `sample_harmonium.py` script "thins" out each Markov +chain by only pulling out a fantasized pattern every `20` steps (further "burning" in each chain before collecting +samples). Each chain is merely initialized with random Bernoulli noise. Note that higher-quality samples can be +obtained if one modifies the earlier harmonium to learn with persistent CD or parallel tempering. ### Final Notes -The harmonium that we have built in this exhibit is a classical Bernoulli harmonium/RBM, which is a neural PGM that assumes that the input data features are binary in nature. If one wants to model data that is continuous/real-valued, then the harmonium model above would need to be extended to utilize visible units that follow a continuous distribution; for instance, if one modeled a multivariate Gaussian distribution, this would yield a Gaussian restricted Boltzmann machine (GRBM). +The harmonium that we have built in this exhibit is a classical Bernoulli harmonium/RBM, which is a neural PGM that +assumes that the input data features are binary in nature. If one wants to model data that is continuous/real-valued, +then the harmonium model above would need to be extended to utilize visible units that follow a continuous +distribution; for instance, if one modeled a multivariate Gaussian distribution, this would yield a Gaussian restricted +Boltzmann machine (GRBM). ## References -[1] Smolensky, P. "Information Processing in Dynamical Systems: Foundations of Harmony Theory" (Chapter 6). Parallel distributed processing: explorations in the microstructure of cognition 1 (1986).
-[2] Geoffrey Hinton. Products of Experts. International conference on artificial neural networks (1999).
-[3] Hinton, Geoffrey E. "Training products of experts by maximizing contrastive likelihood." Technical Report, Gatsby computational neuroscience unit (1999).
-[4] Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models. Morgan Kaufmann, 1991. 10-17.
-[5] Hinton, Geoffrey E. "A practical guide to training restricted Boltzmann machines." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 599-619. +[1] Smolensky, P. "Information Processing in Dynamical Systems: Foundations of Harmony Theory" (Chapter 6). Parallel +distributed processing: explorations in the microstructure of cognition 1 (1986).
+[2] Hinton, Geoffrey. Products of Experts. International conference on artificial neural networks (1999).
+[3] Hinton, Geoffrey E. "Training products of experts by maximizing contrastive likelihood." Technical Report, Gatsby +computational neuroscience unit (1999).
+[4] Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models. Morgan +Kaufmann, 1991. 10-17.
+[5] Hinton, Geoffrey E. "A practical guide to training restricted Boltzmann machines." Neural networks: Tricks of the +trade. Springer, Berlin, Heidelberg, 2012. 599-619. -[^1]: In fact, it is intractable to compute the partition function $Z$ for any reasonably-sized harmonium; fortunately, we will not need to calculate $Z$ in order to learn and sample from a Harmonium. -[^2]: In general, CD-1 means contrastive divergence where the negative phase is only run for one single step, i.e., `K=1`. The more general form of CD is known as CD-K, the K-step CD algorithm where `K > 1`. (Sometimes, CD-1 is just referred to as just "CD".) +[^1]: In fact, it is intractable to compute the partition function $Z$ for any reasonably-sized harmonium; fortunately, +we will not need to calculate $Z$ in order to learn and sample from a Harmonium. +[^2]: In general, CD-1 means contrastive divergence where the negative phase is only run for one single step, i.e., +`K=1`. The more general form of CD is known as CD-K, the K-step CD algorithm where `K > 1`. (Sometimes, CD-1 is just +referred to as just "CD".) diff --git a/docs/museum/index.rst b/docs/museum/index.rst index cab00d0f..a0b58557 100644 --- a/docs/museum/index.rst +++ b/docs/museum/index.rst @@ -17,6 +17,7 @@ of detailed walkthroughs presented in the table of contents below.) sparse_coding pc_rao_ballard1999 snn_dc + event_stdp_patches rl_snn .. toctree:: diff --git a/docs/museum/pc_rao_ballard1999.md b/docs/museum/pc_rao_ballard1999.md index c6de5a76..cfa6b921 100644 --- a/docs/museum/pc_rao_ballard1999.md +++ b/docs/museum/pc_rao_ballard1999.md @@ -1,12 +1,12 @@ -# Hierarchical Predictive Coding (Rao & Ballard) +# Hierarchical Predictive Coding (Rao & Ballard; 1999) -In this exhibit, we create, simulate, and visualize the -internally acquired receptive fields of the predictive coding model originally proposed in (Rao & Ballard, 1999) [1]. +In this exhibit, we create, simulate, and visualize the internally acquired receptive fields of the predictive coding +model originally proposed in (Rao & Ballard, 1999) [1]. -The model code for this -exhibit can be found +The model code for this exhibit can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon). ## References -[1] Rao, Rajesh PN, and Dana H. Ballard. "Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87. \ No newline at end of file +[1] Rao, Rajesh PN, and Dana H. Ballard. "Predictive coding in the visual cortex: a functional interpretation of +some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87. \ No newline at end of file diff --git a/docs/museum/pcn_discrim.md b/docs/museum/pcn_discrim.md index 85cc3756..66d9642b 100644 --- a/docs/museum/pcn_discrim.md +++ b/docs/museum/pcn_discrim.md @@ -1,4 +1,4 @@ -# Discriminative Predictive Coding +# Discriminative Predictive Coding (Whittington & Bogacz; 2017) In this exhibit, we will see how a classifier can be created based on predictive coding. This exhibit model effectively reproduces some of the results diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md index 24d11412..6e1c2876 100644 --- a/docs/museum/rl_snn.md +++ b/docs/museum/rl_snn.md @@ -1,13 +1,10 @@ -# Reinforcement Learning through a Spiking Controller +# Reinforcement Learning through a Spiking Controller (Chevtchenko et al.; 2020) -In this exhibit, we will see how to construct a simple biophysical model for -reinforcement learning with a spiking neural network and modulated -spike-timing-dependent plasticity. -This model incorporates a mechanisms from several different models, including -the constrained RL-centric SNN of [1] as well as the simplifications -made with respect to the model of [2]. The model code for this -exhibit can be found -[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn). +In this exhibit, we will see how to construct a simple biophysical model for reinforcement learning with a spiking +neural network and modulated spike-timing-dependent plasticity. +This model incorporates a mechanisms from several different models, including the constrained RL-centric SNN of +[1] as well as some simplifications of the structures used within the SNN of [2]. The model code for this +exhibit can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn). ## Modeling Operant Conditioning through Modulation @@ -123,10 +120,7 @@ RL-SNN model: ## References -[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse -and delayed rewards with a multilayer spiking neural network." 2020 International -Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
-[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit -recognition using spike-timing-dependent plasticity." Frontiers in computational -neuroscience 9 (2015): 99. - +[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse and delayed rewards with a multilayer +spiking neural network." 2020 International Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit recognition using spike-timing-dependent +plasticity." Frontiers in computational neuroscience 9 (2015): 99. diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md index 04426d70..9245382d 100644 --- a/docs/museum/sindy.md +++ b/docs/museum/sindy.md @@ -1,14 +1,4 @@ - - -# Sparse Identification of Non-linear Dynamical Systems (SINDy) +# Sparse Identification of Non-linear Dynamical Systems (SINDy; Brunton et al.; 2016) In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will: diff --git a/docs/museum/snn_bfa.md b/docs/museum/snn_bfa.md index 3e62dec1..2a83795b 100644 --- a/docs/museum/snn_bfa.md +++ b/docs/museum/snn_bfa.md @@ -1,8 +1,8 @@ -# Spiking Neural Networks: Learning with Broadcast Feedback Alignment +# Spiking Neural Networks: Learning with Broadcast Feedback Alignment (Samadi et al.; 2017) In this exhibit, we will see how one can train a spiking neural network model using surrogate functions and a credit assignment scheme called broadcast -feedback alignment (BFA) [1]. +feedback alignment (BFA) [1]. This exhibit model effectively reproduces some of the results reported (Samadi et al., 2017) [1]. The model code for this exhibit can be found diff --git a/docs/museum/snn_dc.md b/docs/museum/snn_dc.md index b7a5af9e..4bb3b32a 100755 --- a/docs/museum/snn_dc.md +++ b/docs/museum/snn_dc.md @@ -1,4 +1,4 @@ -# The Diehl and Cook Spiking Neuronal Network +# The Diehl and Cook Spiking Neuronal Network (Diehl & Cook; 2015) In this exhibit, we will see how a spiking neural network model that adapts its synaptic efficacies via spike-timing-dependent plasticity can be created. diff --git a/docs/museum/sparse_coding.md b/docs/museum/sparse_coding.md index 929fd4c0..3802a713 100755 --- a/docs/museum/sparse_coding.md +++ b/docs/museum/sparse_coding.md @@ -1,4 +1,4 @@ -# Sparse Coding and Iterative Thresholding +# Sparse Coding and Iterative Thresholding (Olshausen & Field; 1996) In this exhibit, we create, simulate, and visualize the internally acquired filters/atoms of variants of a sparse coding system based on the classical model proposed by (Olshausen & Field, 1996) [1]. After going through this demonstration, you will: From 9035e4a50e5d631c1b9113ac0aa8b872b6578e33 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 19:13:52 -0500 Subject: [PATCH 120/121] nudged citation file --- CITATION.cff | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index e7243de9..ee939cc7 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,10 +6,12 @@ authors: orcid: https://orcid.org/0000-0002-2590-1310 - family-names: Gebhardt given-names: William + orcid: https://orcid.org/0009-0008-7456-6556 - family-names: Mali given-names: Ankur + orcid: https://orcid.org/0000-0001-5813-3584 title: "ngc-learn" -version: 1.0.0 +version: 3.0.0 identifiers: - type: doi value: 10.5281/zenodo.6605728 From 061c38932d28c9686a88af39f8fb4971448a9c94 Mon Sep 17 00:00:00 2001 From: Alexander Ororbia Date: Fri, 5 Dec 2025 19:17:06 -0500 Subject: [PATCH 121/121] minor nudge to docs/files to point to v3 --- docs/installation.md | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index 75cf5c21..01474aa1 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -5,7 +5,7 @@ Setup: NGC-Learn, in its entirety (including its supporting utility sub-packages), requires that you ensure that you have installed the following base dependencies in your system. Note that this library was developed and tested on Ubuntu 22.04 (with much earlier versions on Ubuntu 18.04/20.04). Specifically, NGC-Learn requires: * Python (>=3.10) -* ngcsimlib (>=2.0.0), (official page) +* ngcsimlib (>=3.0.0), (official page) * NumPy (>=1.22.0) * SciPy (>=1.7.0) * JAX (>= 0.4.28; and jaxlib>=0.4.28) diff --git a/requirements.txt b/requirements.txt index 1911c0ee..1e871c1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,6 @@ matplotlib>=3.10.1 # patchify # patchify has issues with pip installation jax>=0.4.28 jaxlib>=0.4.28 -ngcsimlib>=2.0.0 +ngcsimlib>=3.0.0 imageio>=2.37.0 pandas>=2.2.3