1- from jax import random , numpy as jnp , jit
1+ from jax import random , numpy as jnp , jit , nn
22from functools import partial
33from ngclearn import compilable
44from ngclearn import Compartment
@@ -69,7 +69,7 @@ def __init__(
6969 name ,
7070 shape , ## determines memory matrix size
7171 eta = 0.05 , ## learning rate
72- eta_decrement = 0. ,
72+ eta_decrement = 0. , ## linear scheduled decrement over eta
7373 vigilance = 0.3 , ## vigilance parameter (rho)
7474 weight_init = None ,
7575 resist_scale = 1. ,
@@ -96,34 +96,64 @@ def __init__(
9696 self .i_tick = Compartment (jnp .zeros ((1 , 1 )))
9797 #self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
9898 self .dWeights = Compartment (self .weights .get () * 0 )
99- self .misses = Compartment (jnp .zeros ((batch_size , 1 )))
99+ self .misses = Compartment (jnp .zeros ((batch_size , 1 ))) ## marker for non-resonant patterns in a batch
100+
101+ self .weights .set (self .weights .get () * 0 )
102+ self .used = Compartment (jnp .zeros ((1 , shape [1 ]))) ## marks if memory slot used
103+
def insert(self, x, idx):  ## manual memory insertion co-routine
    """
    Write pattern `x` into memory slot `idx` of the memory matrix and
    mark that slot as occupied in the `used` mask.

    Args:
        x: input pattern, a row vector (1 x D) — assumed; TODO confirm shape
           against callers
        idx: integer index of the target memory slot (column of W)
    """
    mem = self.weights.get()
    ## build a (1 x Z) one-hot mask selecting the target slot/column
    slot_mask = jnp.expand_dims(nn.one_hot(idx, mem.shape[1]), axis=0)
    ## broadcast x down the rows, then zero out every column but the target
    ## NOTE(review): this ADDS onto a non-empty slot rather than overwriting
    ## it — confirm callers only insert into empty slots
    update = (mem * 0 + x.T) * slot_mask
    self.weights.set(mem + update)
    ## fold the new slot into the occupancy mask, clamped back to {0, 1}
    occupancy = self.used.get() + slot_mask
    self.used.set((occupancy > 0.) * 1.)
111+
def grow(self, n_memories):  ## grow out memory matrix by fixed amount
    """
    Append `n_memories` empty (all-zero) slots/columns to the memory
    matrix W, extend the `used` occupancy mask to match, reset the
    update matrix `dWeights`, and refresh the cached shape.

    Args:
        n_memories: number of fresh memory slots (columns) to append
    """
    old_W = self.weights.get()
    blank_cols = jnp.zeros((old_W.shape[0], n_memories))
    new_W = jnp.concatenate([old_W, blank_cols], axis=1)
    ## every freshly appended slot starts out unoccupied
    new_used = jnp.concatenate([self.used.get(), jnp.zeros((1, n_memories))], axis=1)
    #print("used: ", new_used.shape)
    self.used.set(new_used)
    self.weights.set(new_W)
    self.dWeights.set(new_W * 0)
    ## NOTE(review): this rebinds self.shape to a plain tuple, yet reset()
    ## calls self.shape.get() — confirm shape is not a Compartment here
    self.shape = self.weights.get().shape
100124
101- #@compilable
102- def consolidate (self , wipe_mem = False ): ## memory storage/consolidation routine
103- ## note that this co-routine needs to be non-compilable/non-jit-i-fied, as
104- ## its main purpose is to structurally alter the memory matrix W (dynamically)
105- x_in = self .inputs .get ()
106- #xprobe = self.xprobe.get()
107- xprobe = _normalize (x_in , norm_fx = self .norm_fx )
108- if wipe_mem is False :
109- W = self .weights .get () ## get current memory matrix
110- miss_mask = self .misses .get ()
111- if jnp .sum (miss_mask ) > 0 : ## for non-resonant patterns
112- r , c = jnp .nonzero (miss_mask )
113- mem = xprobe [r , :]
114- W = jnp .concat ([W , mem .T ], axis = 1 )
115- self .weights .set (W )
116- else :
117- self .weights .set (xprobe .T )
@compilable
def consolidate(self):  ## memory consolidation co-routine (for non-resonant signals)
    """
    Absorb this batch's non-resonant patterns (rows flagged in `misses`)
    into currently-empty slots of the memory matrix W, then refresh the
    `used` occupancy mask from the column norms of W.
    """
    ## count of OCCUPIED memory slots (sum over the 0/1 occupancy mask)
    ## NOTE(review): int(...) needs a concrete value; if @compilable jit-traces
    ## this method, jnp.sum() yields a tracer and int() would raise — confirm
    n_used = int(jnp.sum(self.used.get()))  ## number of used slots
    x = self.xprobe.get()  ## normalized input probes cached by advance_state
    W = self.weights.get()
    nonresonants = self.misses.get()  ## (B x 1) mask of non-resonant batch rows
    ## we project non-resonant memories to empty slots in memory W;
    ## S is a shifted identity mapping batch row i -> column (n_used + i) of W
    ## NOTE(review): assumes occupied slots form a contiguous leading block of
    ## columns — a manual insert() at an arbitrary idx would violate this; verify
    S = jnp.eye(x.shape[0], self.shape[1], k=n_used)
    dWstore = jnp.matmul((x * nonresonants).T, S)  ## resonant rows contribute zero
    W = W + dWstore  ## Hebbian update to memory
    ## re-compute the used-slot mask post-consolidation via column L2 norms
    nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
    used = (nW > 0.) * 1

    self.weights.set(W)
    self.used.set(used)
    ## else, currently discard un-absorbed/non-resonant patterns
    ## can add a function that "grows" out block matrix by a chunk (to control growth)
    ## TODO: add pruning mechanism for low-usage slots
118144
119145 @compilable
120146 def advance_state (self ): ## forward-inference step of ART2A
121147 x_in = self .inputs .get ()
122148 W = self .weights .get () ## get (transposed) memory matrix
149+ used = self .used .get ()
123150
124151 x = _normalize (x_in , norm_fx = self .norm_fx )
125152 self .xprobe .set (x )
126153 sims = jnp .matmul (x , W ) ## compute similarities (parallel dot products)
154+ ## we correct activities by masking out unused slots
155+ sims_min = jnp .amin (sims , axis = 1 , keepdims = True )
156+ sims = sims * used + (1. - used ) * (sims_min - 1. )
127157 z_winners = sims * bkwta (sims , nWTA = self .K ) ## get winner mask (hidden layer)
128158 self .outputs .set (z_winners )
129159
@@ -142,15 +172,17 @@ def evolve(self, t, dt): ## competitive Hebbian update step of ART2A
142172 wnew = (- jnp .matmul (z_winners , W .T ) + x ) * m ## B x D
143173 dW = jnp .matmul (wnew .T , hits ) ## D x Z ## adjustment matrix
144174 W = W + dW * eta ## D x Z ## do a step of Hebbian ascent
175+ nonresonants = 1. - m ## mark non-resonant patterns in batch
145176
146177 ## NOTE: is this post-weight-update normalization needed?
147- nW = jnp .linalg .norm (W , ord = 2 , axis = 0 , keepdims = True )
148- mz = (jnp .sum (hits , axis = 0 , keepdims = True ) > 0. ) * 1.
149- W = W / (nW * mz + (1. - mz ))
150-
151- self .weights .set (W ) ## memory matrix advances forward to new state
152- self .misses .set (1. - m ) ## store unused/non-resonant pattern mask
153-
178+ #nW = jnp.linalg.norm(W, ord=2, axis=0, keepdims=True)
179+ #used = (nW > 0.) * 1
180+ #mz = (jnp.sum(hits, axis=0, keepdims=True) > 0.) * 1.
181+ #W = W / (nW * mz + (1. - mz))
182+
183+ self .weights .set (W )
184+ self .misses .set (nonresonants ) ## store unused/non-resonant pattern mas
185+
154186 #tmp_key, *subkeys = random.split(self.key.get(), 3)
155187 #self.key.set(tmp_key)
156188 ## synaptic update noise
@@ -171,8 +203,8 @@ def reset(self):
171203 self .inputs .set (preVals )
172204 self .outputs .set (postVals )
173205 self .xprobe .set (preVals )
174- self .misses .set (jnp .zeros ((self .batch_size .get (), 1 )))
175-
206+ # self.misses.set(jnp.zeros((self.batch_size.get(), 1)))
207+ self . misses . set ( self . misses . get () * 0 )
176208 self .dWeights .set (jnp .zeros (self .shape .get ()))
177209
178210 @classmethod
0 commit comments