Skip to content

Commit 867ca15

Browse files
author
Alexander Ororbia
committed
clean-up and integration of lif/raf-SRMs
1 parent 8bb2dc7 commit 867ca15

6 files changed

Lines changed: 500 additions & 77 deletions

File tree

ngclearn/components/neurons/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,7 @@
1616
from .spiking.izhikevichCell import IzhikevichCell
1717
from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
1818
from .spiking.RAFCell import RAFCell
19+
## point to spike-response models (SRMs)
20+
from .spiking.LIFSRM import LIFSRM
21+
from .spiking.RAFSRM import RAFSRM
1922

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,9 @@
55
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
66
triangular_estimator,
77
straight_through_estimator)
8-
98
from ngclearn import compilable #from ngcsimlib.parser import compilable
109
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
1110

12-
def _dfv(t, v, params): ## voltage dynamics wrapper
13-
j, rfr, tau_m, refract_T, v_rest, g_L = params
14-
mask = (rfr >= refract_T) * 1. # get refractory mask
15-
## update voltage / membrane potential
16-
dv_dt = (v_rest - v) * g_L + (j * mask)
17-
dv_dt = dv_dt * (1. / tau_m)
18-
return dv_dt
19-
20-
21-
#@partial(jit, static_argnums=[3, 4])
22-
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
23-
### Runs homeostatic threshold update dynamics one step (via Euler integration).
24-
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
25-
#theta_plus = 0.05
26-
#_V_theta = V_theta * theta_decay + S * theta_plus
27-
theta_decay = jnp.exp(-dt/tau_theta)
28-
_v_theta = v_theta * theta_decay + s * theta_plus
29-
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
30-
return _v_theta
31-
3211

3312
class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
3413
"""
@@ -108,9 +87,24 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
10887
""" ## batch_size arg?
10988

11089
def __init__(
111-
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
112-
theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
113-
v_min=None, max_one_spike=False, key=None
90+
self,
91+
name,
92+
n_units,
93+
tau_m,
94+
resist_m=1.,
95+
thr=-52.,
96+
v_rest=-65.,
97+
v_reset=-60.,
98+
conduct_leak=1.,
99+
tau_theta=1e7,
100+
theta_plus=0.05,
101+
refract_time=5.,
102+
one_spike=False,
103+
integration_type="euler",
104+
surrogate_type="straight_through",
105+
v_min=None,
106+
max_one_spike=False,
107+
key=None
114108
):
115109
super().__init__(name, key)
116110

@@ -162,29 +156,50 @@ def __init__(
162156
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike
163157
# self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
164158

159+
@staticmethod
160+
def _dfv(t, v, params): ## voltage dynamics wrapper
161+
j, rfr, tau_m, refract_T, v_rest, g_L = params
162+
mask = (rfr >= refract_T) * 1. ## get refractory mask
163+
## update voltage / membrane potential
164+
dv_dt = (v_rest - v) * g_L + (j * mask)
165+
dv_dt = dv_dt * (1. / tau_m)
166+
return dv_dt
167+
168+
#@partial(jit, static_argnums=[3, 4])
169+
@staticmethod
170+
def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
171+
### Runs homeostatic threshold update dynamics one step (via Euler integration).
172+
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
173+
#theta_plus = 0.05
174+
#_V_theta = V_theta * theta_decay + S * theta_plus
175+
theta_decay = jnp.exp(-dt/tau_theta)
176+
_v_theta = v_theta * theta_decay + s * theta_plus
177+
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
178+
return _v_theta
179+
165180
@compilable
166181
def advance_state(self, dt, t):
167-
j = self.j.get() * self.resist_m
182+
j = self.j.get() * self.resist_m ## get current electrical current input
168183

169184
_v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
170185

186+
## perform step of ODE integration
171187
v_params = (j, self.rfr.get(), self.tau_m.get(), self.refract_T, self.v_rest, self.g_L)
172-
173-
if self.intgFlag == 1:
174-
_, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
175-
else:
176-
_, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
177-
188+
if self.intgFlag == 1: ## midpoint method
189+
_, _v = step_rk2(0., self.v.get(), LIFCell._dfv, dt, v_params)
190+
else: ## take forward Euler step
191+
_, _v = step_euler(0., self.v.get(), LIFCell._dfv, dt, v_params)
192+
## calculate spike emission and post-spike voltage-reset mechanism
178193
s = (_v > _v_thr) * 1.
179194
_rfr = (self.rfr.get() + dt) * (1. - s)
180195
_v = _v * (1. - s) + s * self.v_reset
181196

182-
raw_s = s
197+
raw_s = s ## "raw" spikes
183198

184199
if self.one_spike and not self.max_one_spike:
185200
key, skey = random.split(self.key.get(), 2)
186-
187-
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
201+
#m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: this line is not batch-able
202+
m_switch = (jnp.sum(s, axis=1, keepdims=True) > 0.).astype(jnp.float32)
188203
rS = s * random.uniform(skey, s.shape)
189204
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32)
190205
s = s * (1. - m_switch) + rS * m_switch
@@ -196,7 +211,7 @@ def advance_state(self, dt, t):
196211

197212
if self.tau_theta > 0.:
198213
## run one integration step for threshold dynamics
199-
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
214+
thr_theta = LIFCell._update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
200215
self.thr_theta.set(thr_theta)
201216

202217
## update time-of-last spike variable(s)
@@ -205,7 +220,7 @@ def advance_state(self, dt, t):
205220
if self.v_min is not None: ## ensures voltage never < v_rest
206221
_v = jnp.maximum(_v, self.v_min)
207222

208-
223+
## update internal compartment values
209224
self.v.set(_v)
210225
self.s.set(s)
211226
self.s_raw.set(raw_s)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from ngclearn.components.jaxComponent import JaxComponent
2+
from jax import numpy as jnp, jit
3+
from ngclearn import compilable, Compartment
4+
5+
6+
class LIFSRM(JaxComponent): ## LIF spike-response model (LIF-SRM)
7+
"""
8+
The leaky integrate-and-fire (LIF) spike-response model (SRM); this SRM computes
9+
dynamics of LIF units analytically.
10+
11+
| --- Cell Input Compartments: ---
12+
| current_j - electrical current input (takes in external signals)
13+
| --- Cell State Compartments: ---
14+
| v - membrane potential/voltage state
15+
| j_lowpass - internal low-pass-filtered current (maintained by this SRM)
16+
| key - JAX PRNG key
17+
| --- Cell Output Compartments: ---
18+
| s - emitted binary spikes/action potentials
19+
| t_last_spike - time-of-last-spike (output)
20+
| last_t_eval - time of last (SRM) evaluation
21+
22+
| References:
23+
| Gerstner, W., 1995. Time structure of the activity in neural network
24+
| models. Physical review E, 51(1), p.738.
25+
26+
| Pedagogical Reference:
27+
| http://www.scholarpedia.org/article/Spike-response_model
28+
29+
Args:
30+
name: the string name of this cell
31+
32+
n_units: number of cellular entities (neural population size)
33+
34+
tau_m: membrane time constant (ms)
35+
36+
thr: base value for adaptive thresholds that govern short-term
37+
plasticity (in milliVolts, or mV; default: -52. mV)
38+
39+
v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV)
40+
41+
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
42+
a neuronal cell's membrane potential will be set to this value;
43+
(default: -60 mV)
44+
"""
45+
46+
def __init__(
47+
self,
48+
name,
49+
n_units,
50+
tau_m, ## membrane time constant (ms)
51+
thr=-52., ## threshold (mV)
52+
v_rest=-65., ## membrane resting potential (mV)
53+
v_reset=-60., ## membrne reset potential (mV)
54+
batch_size=1,
55+
**kwargs
56+
):
57+
super().__init__(name, **kwargs)
58+
## LIF-SRM meta-parameters
59+
self.n_units = n_units
60+
self.tau_m = tau_m ## membrane time-constant
61+
self.thr = thr ## threshold (mV)
62+
self.v_rest = v_rest ## resting potential (mV)
63+
self.v_reset = v_reset ## reset potential (mV)
64+
self.batch_size = batch_size
65+
66+
## LIF-SRM key compartments
67+
self.current_j = Compartment(jnp.zeros((self.batch_size, self.n_units)))
68+
self.v = Compartment(jnp.full((self.batch_size, self.n_units), self.v_rest))
69+
self.s = Compartment(jnp.zeros((self.batch_size, self.n_units)))
70+
71+
## analytical SRM state compartments/variables (NOTE: designed to avoid maintaining explicit spike history tensors)
72+
self.t_last_spike = Compartment(jnp.full((self.batch_size, self.n_units), -1.0))
73+
self.j_lowpass = Compartment(jnp.zeros((self.batch_size, self.n_units))) ## integrated input trace
74+
self.last_t_eval = Compartment(jnp.zeros((self.batch_size, self.n_units))) ## tracks clock index (when evaluated)
75+
76+
@compilable
77+
def advance_state(self, dt, t):
78+
## pass last evaluation clock marker into kernel co-routine (to handle time jumps analytically)
79+
v_new, updated_j_trace = LIFSRM._evaluate_SRM_filter( ## apply SRM
80+
t, self.t_last_spike.get(), self.j_lowpass.get(), self.last_t_eval.get(),
81+
self.current_j.get(), self.tau_m, self.v_rest, self.v_reset, dt
82+
)
83+
84+
s_new = (v_new > self.thr) * 1.0
85+
updated_t_last = jnp.where(s_new == 1.0, t, self.t_last_spike.get())
86+
v_output = v_new * (1.0 - s_new) + s_new * self.v_reset
87+
88+
## update compartment states
89+
self.v.set(v_output)
90+
self.s.set(s_new)
91+
self.j_lowpass.set(updated_j_trace)
92+
self.t_last_spike.set(updated_t_last)
93+
self.last_t_eval.set(jnp.full((self.batch_size, self.n_units), t)) ## mark this time-stamp as "evaluated"
94+
95+
@compilable
96+
def reset(self):
97+
self.current_j.set(jnp.zeros((self.batch_size, self.n_units)))
98+
self.v.set(jnp.full((self.batch_size, self.n_units), self.v_rest))
99+
self.s.set(jnp.zeros((self.batch_size, self.n_units)))
100+
self.t_last_spike.set(jnp.full((self.batch_size, self.n_units), -1.0))
101+
self.j_lowpass.set(jnp.zeros((self.batch_size, self.n_units)))
102+
self.last_t_eval.set(jnp.zeros((self.batch_size, self.n_units)))
103+
104+
@staticmethod
105+
def _evaluate_SRM_filter( ## kernel co-routine
106+
t,
107+
t_last_spike,
108+
j_lowpass,
109+
last_t_eval,
110+
current_j,
111+
tau_m,
112+
v_rest,
113+
v_reset,
114+
dt
115+
):
116+
## applies filter-based SRM - tracks integrated voltage contributions
117+
## calculate continuous elapsed time since this specific neuron group was last evaluated
118+
delta_t_eval = t - last_t_eval
119+
## analytically decay historical input voltage trace over skipped time gap
120+
decayed_j_trace = j_lowpass * jnp.exp(-delta_t_eval / tau_m)
121+
## add new incoming current pulse scaled to operate akin to single LIFCell Euler step
122+
new_j_trace = decayed_j_trace + (dt / tau_m) * current_j ## epsilon-kernel
123+
## calc analytical spike-post-emission kernel values (self-reset mechanism)
124+
has_spiked = (t_last_spike >= 0.0) * 1.0
125+
s_post = t - t_last_spike ## kappa kernel
126+
eta_val = has_spiked * (v_reset - v_rest) * jnp.exp(-s_post / tau_m) ## eta kernel
127+
128+
v_total = v_rest + new_j_trace + eta_val ## sum explicit kernel terms
129+
return v_total, new_j_trace
130+
131+
@staticmethod
132+
def predict_next_spike( ## next-spike-time prediction co-routine
133+
t_start, t_last_spike, j_lowpass, tau_m, v_rest, v_reset, thr
134+
):
135+
## co-routine to predict future next spike, takes advantage of an LIF-SRM's
136+
## closed-form setup; specificaly, this function calculates a future (global)
137+
## clock time-stamp as to when this LIF model's decaying voltage would
138+
## cross a firing threshold (note, this does not require step-wise numerical integration)
139+
140+
## reconstruct base self-reset kernel magnitude (eta_val)
141+
### based on what historical displacement remains from last discharge event
142+
has_spiked = (t_last_spike >= 0.0) * 1.0
143+
s_post = t_start - t_last_spike
144+
eta_val = has_spiked * (v_reset - v_rest) * jnp.exp(-s_post / tau_m)
145+
## extract combined driving force variable
146+
total_driving_trace = j_lowpass + eta_val
147+
148+
## define static threshold distance displacement metric
149+
thr_distance = thr - v_rest
150+
151+
## calc closed-form logarithmic isolation calculation for remaining segment time
152+
can_reach_thr = total_driving_trace > thr_distance
153+
## for cases where a neuronal unit does not have enough charge to cross threshold:
154+
safe_ratio = jnp.where( ## handles division-by-zero / negative log errors
155+
can_reach_thr, total_driving_trace / jnp.maximum(thr_distance, 1e-5), 1.0
156+
)
157+
s_remaining = tau_m * jnp.log(safe_ratio)
158+
## absorb into current evaluation timestamp tracking variable
159+
predicted_t = t_start + s_remaining
160+
## safety check: if total driving force is insufficient to cross,
161+
## then flag output as "un-triggered" (i.e.,-1.0)
162+
return jnp.where(can_reach_thr, predicted_t, -1.0) # predicted spike time(s)
163+

0 commit comments

Comments
 (0)