|
| 1 | +from ngclearn.components.jaxComponent import JaxComponent |
| 2 | +from jax import numpy as jnp, jit |
| 3 | +from ngclearn import compilable, Compartment |
| 4 | + |
| 5 | + |
| 6 | +class LIFSRM(JaxComponent): ## LIF spike-response model (LIF-SRM) |
| 7 | + """ |
| 8 | + The leaky integrate-and-fire (LIF) spike-response model (SRM); this SRM computes |
| 9 | + dynamics of LIF units analytically. |
| 10 | +
|
| 11 | + | --- Cell Input Compartments: --- |
| 12 | + | current_j - electrical current input (takes in external signals) |
| 13 | + | --- Cell State Compartments: --- |
| 14 | + | v - membrane potential/voltage state |
| 15 | + | j_lowpass - internal low-pass-filtered current (maintained by this SRM) |
| 16 | + | key - JAX PRNG key |
| 17 | + | --- Cell Output Compartments: --- |
| 18 | + | s - emitted binary spikes/action potentials |
| 19 | + | t_last_spike - time-of-last-spike (output) |
| 20 | + | last_t_eval - time of last (SRM) evaluation |
| 21 | +
|
| 22 | + | References: |
| 23 | + | Gerstner, W., 1995. Time structure of the activity in neural network |
| 24 | + | models. Physical review E, 51(1), p.738. |
| 25 | +
|
| 26 | + | Pedagogical Reference: |
| 27 | + | http://www.scholarpedia.org/article/Spike-response_model |
| 28 | +
|
| 29 | + Args: |
| 30 | + name: the string name of this cell |
| 31 | +
|
| 32 | + n_units: number of cellular entities (neural population size) |
| 33 | +
|
| 34 | + tau_m: membrane time constant (ms) |
| 35 | +
|
| 36 | + thr: base value for adaptive thresholds that govern short-term |
| 37 | + plasticity (in milliVolts, or mV; default: -52. mV) |
| 38 | +
|
| 39 | + v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV) |
| 40 | +
|
| 41 | + v_reset: membrane reset potential (in mV) -- upon occurrence of a spike, |
| 42 | + a neuronal cell's membrane potential will be set to this value; |
| 43 | + (default: -60 mV) |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + name, |
| 49 | + n_units, |
| 50 | + tau_m, ## membrane time constant (ms) |
| 51 | + thr=-52., ## threshold (mV) |
| 52 | + v_rest=-65., ## membrane resting potential (mV) |
| 53 | + v_reset=-60., ## membrne reset potential (mV) |
| 54 | + batch_size=1, |
| 55 | + **kwargs |
| 56 | + ): |
| 57 | + super().__init__(name, **kwargs) |
| 58 | + ## LIF-SRM meta-parameters |
| 59 | + self.n_units = n_units |
| 60 | + self.tau_m = tau_m ## membrane time-constant |
| 61 | + self.thr = thr ## threshold (mV) |
| 62 | + self.v_rest = v_rest ## resting potential (mV) |
| 63 | + self.v_reset = v_reset ## reset potential (mV) |
| 64 | + self.batch_size = batch_size |
| 65 | + |
| 66 | + ## LIF-SRM key compartments |
| 67 | + self.current_j = Compartment(jnp.zeros((self.batch_size, self.n_units))) |
| 68 | + self.v = Compartment(jnp.full((self.batch_size, self.n_units), self.v_rest)) |
| 69 | + self.s = Compartment(jnp.zeros((self.batch_size, self.n_units))) |
| 70 | + |
| 71 | + ## analytical SRM state compartments/variables (NOTE: designed to avoid maintaining explicit spike history tensors) |
| 72 | + self.t_last_spike = Compartment(jnp.full((self.batch_size, self.n_units), -1.0)) |
| 73 | + self.j_lowpass = Compartment(jnp.zeros((self.batch_size, self.n_units))) ## integrated input trace |
| 74 | + self.last_t_eval = Compartment(jnp.zeros((self.batch_size, self.n_units))) ## tracks clock index (when evaluated) |
| 75 | + |
| 76 | + @compilable |
| 77 | + def advance_state(self, dt, t): |
| 78 | + ## pass last evaluation clock marker into kernel co-routine (to handle time jumps analytically) |
| 79 | + v_new, updated_j_trace = LIFSRM._evaluate_SRM_filter( ## apply SRM |
| 80 | + t, self.t_last_spike.get(), self.j_lowpass.get(), self.last_t_eval.get(), |
| 81 | + self.current_j.get(), self.tau_m, self.v_rest, self.v_reset, dt |
| 82 | + ) |
| 83 | + |
| 84 | + s_new = (v_new > self.thr) * 1.0 |
| 85 | + updated_t_last = jnp.where(s_new == 1.0, t, self.t_last_spike.get()) |
| 86 | + v_output = v_new * (1.0 - s_new) + s_new * self.v_reset |
| 87 | + |
| 88 | + ## update compartment states |
| 89 | + self.v.set(v_output) |
| 90 | + self.s.set(s_new) |
| 91 | + self.j_lowpass.set(updated_j_trace) |
| 92 | + self.t_last_spike.set(updated_t_last) |
| 93 | + self.last_t_eval.set(jnp.full((self.batch_size, self.n_units), t)) ## mark this time-stamp as "evaluated" |
| 94 | + |
| 95 | + @compilable |
| 96 | + def reset(self): |
| 97 | + self.current_j.set(jnp.zeros((self.batch_size, self.n_units))) |
| 98 | + self.v.set(jnp.full((self.batch_size, self.n_units), self.v_rest)) |
| 99 | + self.s.set(jnp.zeros((self.batch_size, self.n_units))) |
| 100 | + self.t_last_spike.set(jnp.full((self.batch_size, self.n_units), -1.0)) |
| 101 | + self.j_lowpass.set(jnp.zeros((self.batch_size, self.n_units))) |
| 102 | + self.last_t_eval.set(jnp.zeros((self.batch_size, self.n_units))) |
| 103 | + |
| 104 | + @staticmethod |
| 105 | + def _evaluate_SRM_filter( ## kernel co-routine |
| 106 | + t, |
| 107 | + t_last_spike, |
| 108 | + j_lowpass, |
| 109 | + last_t_eval, |
| 110 | + current_j, |
| 111 | + tau_m, |
| 112 | + v_rest, |
| 113 | + v_reset, |
| 114 | + dt |
| 115 | + ): |
| 116 | + ## applies filter-based SRM - tracks integrated voltage contributions |
| 117 | + ## calculate continuous elapsed time since this specific neuron group was last evaluated |
| 118 | + delta_t_eval = t - last_t_eval |
| 119 | + ## analytically decay historical input voltage trace over skipped time gap |
| 120 | + decayed_j_trace = j_lowpass * jnp.exp(-delta_t_eval / tau_m) |
| 121 | + ## add new incoming current pulse scaled to operate akin to single LIFCell Euler step |
| 122 | + new_j_trace = decayed_j_trace + (dt / tau_m) * current_j ## epsilon-kernel |
| 123 | + ## calc analytical spike-post-emission kernel values (self-reset mechanism) |
| 124 | + has_spiked = (t_last_spike >= 0.0) * 1.0 |
| 125 | + s_post = t - t_last_spike ## kappa kernel |
| 126 | + eta_val = has_spiked * (v_reset - v_rest) * jnp.exp(-s_post / tau_m) ## eta kernel |
| 127 | + |
| 128 | + v_total = v_rest + new_j_trace + eta_val ## sum explicit kernel terms |
| 129 | + return v_total, new_j_trace |
| 130 | + |
| 131 | + @staticmethod |
| 132 | + def predict_next_spike( ## next-spike-time prediction co-routine |
| 133 | + t_start, t_last_spike, j_lowpass, tau_m, v_rest, v_reset, thr |
| 134 | + ): |
| 135 | + ## co-routine to predict future next spike, takes advantage of an LIF-SRM's |
| 136 | + ## closed-form setup; specificaly, this function calculates a future (global) |
| 137 | + ## clock time-stamp as to when this LIF model's decaying voltage would |
| 138 | + ## cross a firing threshold (note, this does not require step-wise numerical integration) |
| 139 | + |
| 140 | + ## reconstruct base self-reset kernel magnitude (eta_val) |
| 141 | + ### based on what historical displacement remains from last discharge event |
| 142 | + has_spiked = (t_last_spike >= 0.0) * 1.0 |
| 143 | + s_post = t_start - t_last_spike |
| 144 | + eta_val = has_spiked * (v_reset - v_rest) * jnp.exp(-s_post / tau_m) |
| 145 | + ## extract combined driving force variable |
| 146 | + total_driving_trace = j_lowpass + eta_val |
| 147 | + |
| 148 | + ## define static threshold distance displacement metric |
| 149 | + thr_distance = thr - v_rest |
| 150 | + |
| 151 | + ## calc closed-form logarithmic isolation calculation for remaining segment time |
| 152 | + can_reach_thr = total_driving_trace > thr_distance |
| 153 | + ## for cases where a neuronal unit does not have enough charge to cross threshold: |
| 154 | + safe_ratio = jnp.where( ## handles division-by-zero / negative log errors |
| 155 | + can_reach_thr, total_driving_trace / jnp.maximum(thr_distance, 1e-5), 1.0 |
| 156 | + ) |
| 157 | + s_remaining = tau_m * jnp.log(safe_ratio) |
| 158 | + ## absorb into current evaluation timestamp tracking variable |
| 159 | + predicted_t = t_start + s_remaining |
| 160 | + ## safety check: if total driving force is insufficient to cross, |
| 161 | + ## then flag output as "un-triggered" (i.e.,-1.0) |
| 162 | + return jnp.where(can_reach_thr, predicted_t, -1.0) # predicted spike time(s) |
| 163 | + |
0 commit comments