@@ -172,8 +172,7 @@ def __init__(
172172 prior = ("constant" , 0. ), w_decay = 0. , sign_value = 1. , optim_type = "sgd" , pre_wght = 1. , post_wght = 1. , p_conn = 1. ,
173173 resist_scale = 1. , batch_size = 1 , ** kwargs
174174 ):
175- super ().__init__ (name , shape , weight_init , bias_init , resist_scale ,
176- p_conn , batch_size = batch_size , ** kwargs )
175+ super ().__init__ (name , shape , weight_init , bias_init , resist_scale , p_conn , batch_size = batch_size , ** kwargs )
177176
178177 if w_decay > 0. :
179178 prior = ('l2' , w_decay )
@@ -209,13 +208,14 @@ def __init__(
209208 self .dBiases = Compartment (jnp .zeros (shape [1 ]))
210209
211210 #key, subkey = random.split(self.key.value)
212- self .opt_params = Compartment (get_opt_init_fn ( optim_type )(
213- [self .weights .get (), self .biases .get ()]
214- if bias_init else [ self . weights . get ()]) )
211+ self .opt_params = Compartment (
212+ get_opt_init_fn ( optim_type )( [self .weights .get (), self .biases .get ()] if bias_init else [ self . weights . get ()])
213+ )
215214
216215 @staticmethod
217- def _compute_update (w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght ,
218- post_wght , pre , post , weights ):
216+ def _compute_update (
217+ w_bound , is_nonnegative , sign_value , prior_type , prior_lmbda , pre_wght , post_wght , pre , post , weights
218+ ):
219219 ## calculate synaptic update values
220220 dW , db = _calc_update (
221221 pre , post , weights , w_bound , is_nonnegative = is_nonnegative ,
@@ -257,8 +257,8 @@ def evolve(self):
257257 def reset (self ): #, batch_size, shape):
258258 preVals = jnp .zeros ((self .batch_size , self .shape [0 ]))
259259 postVals = jnp .zeros ((self .batch_size , self .shape [1 ]))
260- # not self.inputs.targeted and self.inputs.set(preVals) # inputs
261- self .inputs .set (preVals )
260+ if not self .inputs .targeted :
261+ self .inputs .set (preVals )
262262 self .outputs .set (postVals ) # outputs
263263 self .pre .set (preVals ) # pre
264264 self .post .set (postVals ) # post
@@ -310,20 +310,6 @@ def help(cls): ## component help function
310310 "hyperparameters" : hyperparams }
311311 return info
312312
313- def __repr__ (self ):
314- comps = [varname for varname in dir (self ) if isinstance (getattr (self , varname ), Compartment )]
315- maxlen = max (len (c ) for c in comps ) + 5
316- lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
317- for c in comps :
318- stats = tensorstats (getattr (self , c ).get ())
319- if stats is not None :
320- line = [f"{ k } : { v } " for k , v in stats .items ()]
321- line = ", " .join (line )
322- else :
323- line = "None"
324- lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
325- return lines
326-
327313if __name__ == '__main__' :
328314 from ngcsimlib .context import Context
329315 with Context ("Bar" ) as bar :
0 commit comments