@@ -92,7 +92,7 @@ def __init__(
9292 distance_function = ("minkowski" , 2 ),
9393 label_dim = 0 , ## if > 0, then this becomes supervised LVQ(1)
9494 initial_patterns = None , ## possible class-based prototypes to init by
95- lanvegin_noise_scale = 0. , ## scale of Langevin noise to apply to updates
95+ langevin_noise_scale = 0. , ## scale of Langevin noise to apply to updates
9696 weight_init = None ,
9797 resist_scale = 1. ,
9898 p_conn = 1. ,
@@ -103,7 +103,7 @@ def __init__(
103103 name , shape , weight_init , None , resist_scale , p_conn , batch_size = batch_size , ** kwargs
104104 )
105105
106- ### Synapse and VQ hyper-parameters
106+ ### Synapse / VQ hyper-parameters
107107 self .label_dim = label_dim
108108 self .K = 1 ## number of winners (for a bmu)
109109 dist_fun , dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
@@ -120,19 +120,18 @@ def __init__(
120120 self .eta_decr = eta_decrement #0.001
121121 self .syn_decay = syn_decay
122122 self .w_bound = w_bound ## soft synaptic value bound (on magnitude)
123- self .zeta = langevin_noise_scale #0.2 #0.35 #1. # # Langevin dampening factor
123+ self .zeta = langevin_noise_scale ## Langevin dampening factor
124124
125125 ## VQ Compartment setup
126126 label_syn_init = labels_init = jnp .zeros ((1 , 1 ))
127- if self .label_dim > 0 :
127+ if self .label_dim > 0 : ## do we set up label memory matrix?
128128 label_syn_init = jnp .zeros ((label_dim , self .shape [1 ]))
129129 labels_init = jnp .zeros ((self .batch_size , self .label_dim ))
130130 self .labels = Compartment (labels_init , display_name = "Label Units" )
131131 self .pred_labels = Compartment (labels_init , display_name = "Predicted Label Values" )
132132 self .label_weights = Compartment (label_syn_init , display_name = "Label Synapses / Memory" )
133133 self .eta = Compartment (jnp .zeros ((1 , 1 )) + self .initial_eta , display_name = "Dynamic step size" )
134134 self .i_tick = Compartment (jnp .zeros ((1 , 1 )))
135- #self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
136135 self .dWeights = Compartment (self .weights .get () * 0 )
137136
138137 if initial_patterns is not None : ## preload memory synaptic matrix
@@ -145,19 +144,19 @@ def __init__(
145144 W = jnp .concat ([initX , W [:, 0 :(H - initX .shape [1 ])]], axis = 1 )
146145 W = W [:, ptrs ] ## shuffle memories
147146 self .weights .set (W )
148- if self .label_dim > 0 :
147+ if self .label_dim > 0 : ## do we preload label matrix?
149148 Wy = self .label_weights .get ()
150149 Wy = jnp .concat ([initY , Wy [:, 0 :(H - initX .shape [1 ])]], axis = 1 )
151150 Wy = Wy [:, ptrs ] ## shuffle memories
152151 self .label_weights .set (Wy )
153152 else : ## memory is exactly the set of stored patterns/templates
154153 self .weights .set (initX )
155- if self .label_dim > 0 :
154+ if self .label_dim > 0 : ## do we preload label matrix?
156155 self .label_weights .set (initY )
157156 @compilable
158157 def advance_state (self ): ## forward-inference step of VQ
159158 x_in = self .inputs .get ()
160- x_in = x_in / jnp .linalg .norm (x_in , axis = 1 , keepdims = True )
159+ x_in = x_in / jnp .linalg .norm (x_in , axis = 1 , keepdims = True ) ## NOTE: we normalize input patterns (?)
161160 self .inputs .set (x_in )
162161 W = self .weights .get ().T ## get (transposed) memory matrix
163162
0 commit comments