77
88class GaussianErrorCell (JaxComponent ): ## Rate-coded/real-valued error unit/cell
99 """
10- A simple (non-spiking) Gaussian error cell - this is a fixed-point solution
11- of a mismatch signal .
10+ A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically,
11+ this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood) .
1212
1313 | --- Cell Input Compartments: ---
1414 | mu - predicted value (takes in external signals)
15- | Sigma - predicted covariance (takes in external signals)
15+ | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2
1616 | target - desired/goal value (takes in external signals)
1717 | modulator - modulation signal (takes in optional external signals)
1818 | mask - binary/gating mask to apply to error neuron calculations
@@ -31,7 +31,8 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
3131
3232 sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
3333 Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
34- to a constant/fixed `sigma`
34+ to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument
35+ (Default: 1)
3536 """
3637 def __init__ (self , name , n_units , batch_size = 1 , sigma = 1. , shape = None , ** kwargs ):
3738 super ().__init__ (name , ** kwargs )
@@ -67,13 +68,14 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6768 self .mask = Compartment (restVals + 1.0 )
6869
6970 @staticmethod
70- def eval_log_density (target , mu , Sigma ):
71+ def _eval_log_density (target , mu , Sigma ): ## Gaussian log likelihood
72+ ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix
7173 _dmu = (target - mu )
7274 log_density = - jnp .sum (jnp .square (_dmu )) * (0.5 / Sigma )
73- return log_density
75+ return log_density , _dmu ## return density and raw delta
7476
7577 @compilable
76- def advance_state (self , dt ): ## compute Gaussian error cell output
78+ def advance_state (self , dt ): ## compute Gaussian error cell output (fixed-point)
7779 # Get the variables
7880 mu = self .mu .get ()
7981 target = self .target .get ()
@@ -87,12 +89,12 @@ def advance_state(self, dt): ## compute Gaussian error cell output
8789 # but should support full log likelihood of the multivariate Gaussian with covariance of different types
8890 # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
8991 # (using integration time constant dt)
90- _dmu = ( target - mu ) # e (error unit)
91- dmu = _dmu / Sigma
92- dtarget = - dmu # reverse of e
93- dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
94- L = - jnp . sum ( jnp . square ( _dmu )) * ( 0.5 / Sigma )
95- #L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
92+
93+ L , _dmu = GaussianErrorCell . _eval_log_density ( target , mu , Sigma ) # L = -jnp.sum(jnp.square( _dmu)) * (0.5 / Sigma)
94+ ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu)
95+ dmu = _dmu / Sigma ## obtain precision-scaled e: (target - mu)/Sigma
96+ dtarget = - dmu # reverse of e ## -(target - mu)/Sigma
97+ dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma
9698
9799 dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
98100 dtarget = dtarget * modulator * mask
@@ -127,16 +129,18 @@ def batched_reset(self, batch_size):
127129 self .dmu .set (dmu )
128130 self .dtarget .set (dtarget )
129131 self .dSigma .set (dSigma )
130- self .target .set (target )
131- self .mu .set (mu )
132+ if not self .target .targeted :
133+ self .target .set (target )
134+ if not self .mu .targeted :
135+ self .mu .set (mu )
132136 self .modulator .set (modulator )
133137 self .L .set (L )
134138 self .mask .set (mask )
135139
136140 @classmethod
137141 def help (cls ): ## component help function
138142 properties = {
139- "cell_type" : "GaussianErrorcell - computes mismatch/error signals at "
143+ "cell_type" : "GaussianErrorCell - computes mismatch/error signals at "
140144 "each time step t (between a `target` and a prediction `mu`)"
141145 }
142146 compartment_props = {
@@ -147,7 +151,7 @@ def help(cls): ## component help function
147151 "modulator" : "External input modulatory/scaling signal(s)" ,
148152 "mask" : "External binary/gating mask to apply to signals" },
149153 "outputs" :
150- {"L" : "Local loss value computed/ embodied by this error-cell" ,
154+ {"L" : "Local loss / free-energy value embodied by this error-cell" ,
151155 "dmu" : "first derivative of loss w.r.t. prediction value(s)" ,
152156 "dSigma" : "first derivative of loss w.r.t. variance/covariance value(s)" ,
153157 "dtarget" : "first derivative of loss w.r.t. target value(s)" },
0 commit comments