1- from jax import random , jit
21import numpy as np
32from ngclearn .utils import weight_distribution as dist
4- from ngclearn import Context , numpy as jnp
5- from ngclearn .components import (RateCell ,
6- HebbianSynapse ,
7- GaussianErrorCell ,
8- StaticSynapse )
9- from ngclearn .utils .model_utils import scanner
10-
3+ from ngclearn import numpy as jnp
114
5+ from jax import numpy as jnp , random , jit
6+ from ngclearn import Context , MethodProcess
7+ from ngclearn .components .synapses .hebbian .hebbianSynapse import HebbianSynapse
8+ from ngclearn .components .neurons .graded .gaussianErrorCell import GaussianErrorCell
9+ from ngcsimlib .global_state import stateManager
1210
1311class Iterative_Ridge ():
1412 """
1513 A neural circuit implementation of the iterative Ridge (L2) algorithm
16- using Hebbian learning update rule.
14+ using a Hebbian learning update rule.
1715
18- The circuit implements sparse regression through Hebbian synapses with L2 regularization.
16+ This circuit implements sparse regression through Hebbian synapses with L2 regularization.
1917
2018 The specific differential equation that characterizes this model is adding lmbda * W
2119 to the dW (the gradient of loss/energy function):
@@ -75,54 +73,43 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
7573 feature_dim = dict_dim
7674
7775 with Context (self .name ) as self .circuit :
78- self .W = HebbianSynapse ("W" , shape = (feature_dim , sys_dim ), eta = self .lr ,
79- sign_value = - 1 , weight_init = dist .constant (weight_fill ),
80- prior = ('ridge' , ridge_lmbda ), w_bound = 0. ,
81- optim_type = optim_type , key = subkeys [0 ])
76+ self .W = HebbianSynapse (
77+ "W" , shape = (feature_dim , sys_dim ), eta = self .lr , sign_value = - 1 ,
78+ weight_init = dist .constant (weight_fill ), prior = ('ridge' , ridge_lmbda ), w_bound = 0. ,
79+ optim_type = optim_type , key = subkeys [0 ]
80+ )
8281 self .err = GaussianErrorCell ("err" , n_units = sys_dim )
8382
8483 # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8584 self .W .batch_size = batch_size
8685 self .err .batch_size = batch_size
8786 # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
88- self .err . mu << self .W . outputs
89- self .W . post << self .err . dmu
87+ self .W . outputs >> self .err . mu
88+ self .err . dmu >> self .W . post
9089 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91- advance_cmd , advance_args = self .circuit .compile_by_key (self .W , ## execute prediction synapses
92- self .err , ## finally, execute error neurons
93- compile_key = "advance_state" )
94- evolve_cmd , evolve_args = self .circuit .compile_by_key (self .W , compile_key = "evolve" )
95- reset_cmd , reset_args = self .circuit .compile_by_key (self .err , self .W , compile_key = "reset" )
96- # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
97- self .dynamic ()
98-
99- def dynamic (self ): ## create dynamic commands forself.circuit
100- W , err = self .circuit .get_components ("W" , "err" )
101- self .self = W
102- self .err = err
103-
104- @Context .dynamicCommand
105- def batch_set (batch_size ):
106- self .W .batch_size = batch_size
107- self .err .batch_size = batch_size
108-
109- @Context .dynamicCommand
110- def clamps (y_scaled , X ):
111- self .W .inputs .set (X )
112- self .W .pre .set (X )
113- self .err .target .set (y_scaled )
114-
115- self .circuit .wrap_and_add_command (jit (self .circuit .evolve ), name = "evolve" )
116- self .circuit .wrap_and_add_command (jit (self .circuit .advance_state ), name = "advance" )
117- self .circuit .wrap_and_add_command (jit (self .circuit .reset ), name = "reset" )
118-
119-
120- @scanner
121- def _process (compartment_values , args ):
122- _t , _dt = args
123- compartment_values = self .circuit .advance_state (compartment_values , t = _t , dt = _dt )
124- return compartment_values , compartment_values [self .W .weights .path ]
12590
91+ advance = (MethodProcess (name = "advance_state" )
92+ >> self .W .advance_state
93+ >> self .err .advance_state )
94+ self .advance = advance
95+
96+ evolve = (MethodProcess (name = "evolve" )
97+ >> self .W .evolve )
98+ self .evolve = evolve
99+
100+ reset = (MethodProcess (name = "reset" )
101+ >> self .err .reset
102+ >> self .W .reset )
103+ self .reset = reset
104+
105+ def batch_set (self , batch_size ):
106+ self .W .batch_size = batch_size
107+ self .err .batch_size = batch_size
108+
109+ def clamp (self , y_scaled , X ):
110+ self .W .inputs .set (X )
111+ self .W .pre .set (X )
112+ self .err .target .set (y_scaled )
126113
127114 def thresholding (self , scale = 2 ):
128115 coef_old = self .coef_ #self.W.weights.value
@@ -135,21 +122,15 @@ def thresholding(self, scale=2):
135122
136123
137124 def fit (self , y , X ):
138- self .circuit . reset ()
139- self .circuit . clamps (y_scaled = y , X = X )
125+ self .reset . run ()
126+ self .clamp (y_scaled = y , X = X )
140127
141128 for i in range (self .epochs ):
142- self .circuit ._process (jnp .array ([[self .dt * i , self .dt ] for i in range (self .T )]))
143- self .circuit .evolve (t = self .T , dt = self .dt )
144-
145- self .coef_ = np .array (self .W .weights .value )
146-
147- return self .coef_ , self .err .mu .value , self .err .L .value
148-
149-
150-
151-
152-
129+ inputs = jnp .array (self .advance .pack_rows (self .T , t = lambda x : x , dt = self .dt ))
130+ stateManager .state , outputs = self .advance .scan (inputs )
131+ self .evolve .run (t = self .T , dt = self .dt )
153132
133+ self .coef_ = np .array (self .W .weights .get ())
154134
135+ return self .coef_ , self .err .mu .get (), self .err .L .get ()
155136
0 commit comments