Skip to content

Commit 6b78909

Browse files
author
Alexander Ororbia
committed
tweaked/cleaned-up gaussian-error-cell
1 parent ef8627a commit 6b78909

1 file changed

Lines changed: 21 additions & 17 deletions

File tree

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
99
"""
10-
A simple (non-spiking) Gaussian error cell - this is a fixed-point solution
11-
of a mismatch signal.
10+
A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically,
11+
this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood).
1212
1313
| --- Cell Input Compartments: ---
1414
| mu - predicted value (takes in external signals)
15-
| Sigma - predicted covariance (takes in external signals)
15+
| Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2
1616
| target - desired/goal value (takes in external signals)
1717
| modulator - modulation signal (takes in optional external signals)
1818
| mask - binary/gating mask to apply to error neuron calculations
@@ -31,7 +31,8 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
3131
3232
sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
3333
Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
34-
to a constant/fixed `sigma`
34+
to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument
35+
(Default: 1)
3536
"""
3637
def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
3738
super().__init__(name, **kwargs)
@@ -67,13 +68,14 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6768
self.mask = Compartment(restVals + 1.0)
6869

6970
@staticmethod
70-
def eval_log_density(target, mu, Sigma):
71+
def _eval_log_density(target, mu, Sigma): ## Gaussian log likelihood
72+
## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix
7173
_dmu = (target - mu)
7274
log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
73-
return log_density
75+
return log_density, _dmu ## return density and raw delta
7476

7577
@compilable
76-
def advance_state(self, dt): ## compute Gaussian error cell output
78+
def advance_state(self, dt): ## compute Gaussian error cell output (fixed-point)
7779
# Get the variables
7880
mu = self.mu.get()
7981
target = self.target.get()
@@ -87,12 +89,12 @@ def advance_state(self, dt): ## compute Gaussian error cell output
8789
# but should support full log likelihood of the multivariate Gaussian with covariance of different types
8890
# TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
8991
# (using integration time constant dt)
90-
_dmu = (target - mu) # e (error unit)
91-
dmu = _dmu / Sigma
92-
dtarget = -dmu # reverse of e
93-
dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
94-
L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
95-
#L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
92+
93+
L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
94+
## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu)
95+
dmu = _dmu / Sigma ## obtain precision-scaled e: (target - mu)/Sigma
96+
dtarget = -dmu # reverse of e ## -(target - mu)/Sigma
97+
dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma
9698

9799
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
98100
dtarget = dtarget * modulator * mask
@@ -127,16 +129,18 @@ def batched_reset(self, batch_size):
127129
self.dmu.set(dmu)
128130
self.dtarget.set(dtarget)
129131
self.dSigma.set(dSigma)
130-
self.target.set(target)
131-
self.mu.set(mu)
132+
if not self.target.targeted:
133+
self.target.set(target)
134+
if not self.mu.targeted:
135+
self.mu.set(mu)
132136
self.modulator.set(modulator)
133137
self.L.set(L)
134138
self.mask.set(mask)
135139

136140
@classmethod
137141
def help(cls): ## component help function
138142
properties = {
139-
"cell_type": "GaussianErrorcell - computes mismatch/error signals at "
143+
"cell_type": "GaussianErrorCell - computes mismatch/error signals at "
140144
"each time step t (between a `target` and a prediction `mu`)"
141145
}
142146
compartment_props = {
@@ -147,7 +151,7 @@ def help(cls): ## component help function
147151
"modulator": "External input modulatory/scaling signal(s)",
148152
"mask": "External binary/gating mask to apply to signals"},
149153
"outputs":
150-
{"L": "Local loss value computed/embodied by this error-cell",
154+
{"L": "Local loss / free-energy value embodied by this error-cell",
151155
"dmu": "first derivative of loss w.r.t. prediction value(s)",
152156
"dSigma": "first derivative of loss w.r.t. variance/covariance value(s)",
153157
"dtarget": "first derivative of loss w.r.t. target value(s)"},

0 commit comments

Comments
 (0)