11from jax import random , numpy as jnp , jit
2- from ngclearn import resolver , Component , Compartment
2+ from ngcsimlib .compilers .process import transition
3+ from ngcsimlib .component import Component
4+ from ngcsimlib .compartment import Compartment
5+
36from .convSynapse import ConvSynapse
7+ from ngclearn .utils .weight_distribution import initialize_params
8+ from ngcsimlib .logger import info
9+ from ngclearn .utils import tensorstats
10+ import ngclearn .utils .weight_distribution as dist
411from ngclearn .components .synapses .convolution .ngcconv import (_conv_same_transpose_padding ,
512 _conv_valid_transpose_padding )
613from ngclearn .components .synapses .convolution .ngcconv import (conv2d , _calc_dX_conv ,
@@ -143,8 +150,9 @@ def _init(self, batch_size, x_size, shape, stride, padding, pad_args,
143150 self .x_delta_shape = (dx , dy )
144151
145152 @staticmethod
146- def _compute_update (sign_value , w_decay , bias_init , stride , pad_args ,
147- delta_shape , pre , post , weights ):
153+ def _compute_update (
154+ sign_value , w_decay , bias_init , stride , pad_args , delta_shape , pre , post , weights
155+ ): ## synaptic kernel adjustment calculation co-routine
148156 ## compute adjustment to filters
149157 dWeights = calc_dK_conv (pre , post , delta_shape = delta_shape ,
150158 stride_size = stride , padding = pad_args )
@@ -157,10 +165,12 @@ def _compute_update(sign_value, w_decay, bias_init, stride, pad_args,
157165 dBiases = jnp .sum (post , axis = 0 , keepdims = True ) * sign_value
158166 return dWeights , dBiases
159167
168+ @transition (output_compartments = ["opt_params" , "weights" , "biases" , "dWeights" , "dBiases" ])
160169 @staticmethod
161- def _evolve (opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init ,
162- stride , pad_args , delta_shape , pre , post , weights , biases ,
163- opt_params ):
170+ def evolve (
171+ opt , sign_value , w_decay , w_bounds , is_nonnegative , bias_init , stride , pad_args , delta_shape , pre , post ,
172+ weights , biases , opt_params
173+ ):
164174 ## calc dFilters / dBiases - update to filters and biases
165175 dWeights , dBiases = HebbianConvSynapse ._compute_update (
166176 sign_value , w_decay , bias_init , stride , pad_args , delta_shape ,
@@ -180,17 +190,11 @@ def _evolve(opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init,
180190 weights = jnp .clip (weights , - w_bounds , w_bounds )
181191 return opt_params , weights , biases , dWeights , dBiases
182192
183- @resolver (_evolve )
184- def evolve (self , opt_params , weights , biases , dWeights , dBiases ):
185- self .opt_params .set (opt_params )
186- self .weights .set (weights )
187- self .biases .set (biases )
188- self .dWeights .set (dWeights )
189- self .dBiases .set (dBiases )
190-
193+ @transition (output_compartments = ["dInputs" ])
191194 @staticmethod
192- def _backtransmit (sign_value , x_size , shape , stride , padding , x_delta_shape ,
193- antiPad , post , weights ): ## action-backpropagating routine
195+ def backtransmit (
196+ sign_value , x_size , shape , stride , padding , x_delta_shape , antiPad , post , weights
197+ ): ## action-backpropagating routine
194198 ## calc dInputs - adjustment w.r.t. input signal
195199 k_size , k_size , n_in_chan , n_out_chan = shape
196200 # antiPad = None
@@ -206,12 +210,9 @@ def _backtransmit(sign_value, x_size, shape, stride, padding, x_delta_shape,
206210 dInputs = dInputs * sign_value
207211 return dInputs
208212
209- @resolver (_backtransmit )
210- def backtransmit (self , dInputs ):
211- self .dInputs .set (dInputs )
212-
213+ @transition (output_compartments = ["inputs" , "outputs" , "pre" , "post" , "dInputs" ])
213214 @staticmethod
214- def _reset (in_shape , out_shape ):
215+ def reset (in_shape , out_shape ):
215216 preVals = jnp .zeros (in_shape )
216217 postVals = jnp .zeros (out_shape )
217218 inputs = preVals
@@ -221,14 +222,6 @@ def _reset(in_shape, out_shape):
221222 dInputs = preVals
222223 return inputs , outputs , pre , post , dInputs
223224
224- @resolver (_reset )
225- def reset (self , inputs , outputs , pre , post , dInputs ):
226- self .inputs .set (inputs )
227- self .outputs .set (outputs )
228- self .pre .set (pre )
229- self .post .set (post )
230- self .dInputs .set (dInputs )
231-
232225 @classmethod
233226 def help (cls ): ## component help function
234227 properties = {
0 commit comments