1- from ngclearn import resolver , Component , Compartment
1+ # %%
2+
23from ngclearn .components .jaxComponent import JaxComponent
34from jax import numpy as jnp , jit
45from ngclearn .utils import tensorstats
56from ngclearn .utils .model_utils import sigmoid , d_sigmoid
6- from ngcsimlib .compilers .process import transition
7+
8+ from ngcsimlib .logger import info
9+ from ngcsimlib .compartment import Compartment
10+ from ngcsimlib .parser import compilable
711
812class BernoulliErrorCell (JaxComponent ): ## Rate-coded/real-valued error unit/cell
913 """
@@ -59,14 +63,20 @@ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None,
5963 self .modulator = Compartment (restVals + 1.0 ) # to be set/consumed
6064 self .mask = Compartment (restVals + 1.0 )
6165
62- @transition (output_compartments = ["dp" , "dtarget" , "L" , "mask" ])
63- @staticmethod
64- def advance_state (dt , p , target , modulator , mask , input_logits ): ## compute Bernoulli error cell output
66+ # @transition(output_compartments=["dp", "dtarget", "L", "mask"])
67+ @compilable
68+ def advance_state (self , dt ): ## compute Bernoulli error cell output
69+ # Get the variables
70+ p = self .p .get ()
71+ target = self .target .get ()
72+ modulator = self .modulator .get ()
73+ mask = self .mask .get ()
74+
6575 # Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit
6676 # behavior of the local cost functional
6777 eps = 0.0001
6878 _p = p
69- if input_logits : ## convert from "logits" to probs via sigmoidal link function
79+ if self . input_logits : ## convert from "logits" to probs via sigmoidal link function
7080 _p = sigmoid (p )
7181 _p = jnp .clip (_p , eps , 1. - eps ) ## post-process to prevent div by 0
7282 x = target
@@ -78,7 +88,7 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
7888 log_p = jnp .log (_p ) ## ln(p)
7989 log_one_min_p = jnp .log (one_min_p ) ## ln(1 - p)
8090 L = jnp .sum (log_p * x + log_one_min_p * one_min_x ) ## Bern LL
81- if input_logits :
91+ if self . input_logits :
8292 dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p)
8393 else :
8494 dL_dp = x / (_p ) - one_min_x / one_min_p ## d(Bern LL)/dp
@@ -89,14 +99,21 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
8999 dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
90100 dtarget = dL_dx * modulator * mask
91101 mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
92- return dp , dtarget , jnp .squeeze (L ), mask
93-
94- @transition (output_compartments = ["dp" , "dtarget" , "target" , "p" , "modulator" , "L" , "mask" ])
95- @staticmethod
96- def reset (batch_size , shape ): ## reset core components/statistics
97- _shape = (batch_size , shape [0 ])
98- if len (shape ) > 1 :
99- _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ])
102+
103+ # Set state
104+ # dp, dtarget, jnp.squeeze(L), mask
105+ self .dp .set (dp )
106+ self .dtarget .set (dtarget )
107+ self .L .set (jnp .squeeze (L ))
108+ self .mask .set (mask )
109+
110+
111+ # @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
112+ @compilable
113+ def reset (self , batch_size ): ## reset core components/statistics
114+ _shape = (batch_size , self .shape [0 ])
115+ if len (self .shape ) > 1 :
116+ _shape = (batch_size , self .shape [0 ], self .shape [1 ], self .shape [2 ])
100117 restVals = jnp .zeros (_shape ) ## "rest"/reset values
101118 dp = restVals
102119 dtarget = restVals
@@ -105,7 +122,16 @@ def reset(batch_size, shape): ## reset core components/statistics
105122 modulator = restVals + 1. ## reset modulator signal
106123 L = 0. #jnp.zeros((1, 1)) ## rest loss
107124 mask = jnp .ones (_shape ) ## reset mask
108- return dp , dtarget , target , p , modulator , L , mask
125+
126+ # Set compartment
127+ self .dp .set (dp )
128+ self .dtarget .set (dtarget )
129+ self .target .set (target )
130+ self .p .set (p )
131+ self .modulator .set (modulator )
132+ self .L .set (L )
133+ self .mask .set (mask )
134+
109135
110136 @classmethod
111137 def help (cls ): ## component help function
@@ -136,11 +162,11 @@ def help(cls): ## component help function
136162 return info
137163
138164 def __repr__ (self ):
139- comps = [varname for varname in dir (self ) if Compartment . is_compartment (getattr (self , varname ))]
165+ comps = [varname for varname in dir (self ) if isinstance (getattr (self , varname ), Compartment )]
140166 maxlen = max (len (c ) for c in comps ) + 5
141167 lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
142168 for c in comps :
143- stats = tensorstats (getattr (self , c ).value )
169+ stats = tensorstats (getattr (self , c ).get () )
144170 if stats is not None :
145171 line = [f"{ k } : { v } " for k , v in stats .items ()]
146172 line = ", " .join (line )
0 commit comments