Skip to content

Commit ef8627a

Browse files
author
Alexander Ororbia
committed
cleaned up vq-synapse
1 parent d21fc65 commit ef8627a

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

ngclearn/components/synapses/competitive/vectorQuantizeSynapse.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292
distance_function=("minkowski", 2),
9393
label_dim=0, ## if > 0, then this becomes supervised LVQ(1)
9494
initial_patterns=None, ## possible class-based prototypes to init by
95-
lanvegin_noise_scale=0., ## scale of Langevin noise to apply to updates
95+
langevin_noise_scale=0., ## scale of Langevin noise to apply to updates
9696
weight_init=None,
9797
resist_scale=1.,
9898
p_conn=1.,
@@ -103,7 +103,7 @@ def __init__(
103103
name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs
104104
)
105105

106-
### Synapse and VQ hyper-parameters
106+
### Synapse / VQ hyper-parameters
107107
self.label_dim = label_dim
108108
self.K = 1 ## number of winners (for a bmu)
109109
dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
@@ -120,19 +120,18 @@ def __init__(
120120
self.eta_decr = eta_decrement #0.001
121121
self.syn_decay = syn_decay
122122
self.w_bound = w_bound ## soft synaptic value bound (on magnitude)
123-
self.zeta = langevin_noise_scale #0.2 #0.35 #1. ## Langevin dampening factor
123+
self.zeta = langevin_noise_scale ## Langevin dampening factor
124124

125125
## VQ Compartment setup
126126
label_syn_init = labels_init = jnp.zeros((1, 1))
127-
if self.label_dim > 0:
127+
if self.label_dim > 0: ## do we set up label memory matrix?
128128
label_syn_init = jnp.zeros((label_dim, self.shape[1]))
129129
labels_init = jnp.zeros((self.batch_size, self.label_dim))
130130
self.labels = Compartment(labels_init, display_name="Label Units")
131131
self.pred_labels = Compartment(labels_init, display_name="Predicted Label Values")
132132
self.label_weights = Compartment(label_syn_init, display_name="Label Synapses / Memory")
133133
self.eta = Compartment(jnp.zeros((1, 1)) + self.initial_eta, display_name="Dynamic step size")
134134
self.i_tick = Compartment(jnp.zeros((1, 1)))
135-
#self.bmu = Compartment(jnp.zeros((1, 1)), display_name="Best matching unit mask")
136135
self.dWeights = Compartment(self.weights.get() * 0)
137136

138137
if initial_patterns is not None: ## preload memory synaptic matrix
@@ -145,19 +144,19 @@ def __init__(
145144
W = jnp.concat([initX, W[:, 0:(H - initX.shape[1])]], axis=1)
146145
W = W[:, ptrs] ## shuffle memories
147146
self.weights.set(W)
148-
if self.label_dim > 0:
147+
if self.label_dim > 0: ## do we preload label matrix?
149148
Wy = self.label_weights.get()
150149
Wy = jnp.concat([initY, Wy[:, 0:(H - initX.shape[1])]], axis=1)
151150
Wy = Wy[:, ptrs] ## shuffle memories
152151
self.label_weights.set(Wy)
153152
else: ## memory is exactly the set of stored patterns/templates
154153
self.weights.set(initX)
155-
if self.label_dim > 0:
154+
if self.label_dim > 0: ## do we preload label matrix?
156155
self.label_weights.set(initY)
157156
@compilable
158157
def advance_state(self): ## forward-inference step of VQ
159158
x_in = self.inputs.get()
160-
x_in = x_in / jnp.linalg.norm(x_in, axis=1, keepdims=True)
159+
x_in = x_in / jnp.linalg.norm(x_in, axis=1, keepdims=True) ## NOTE: we normalize input patterns (?)
161160
self.inputs.set(x_in)
162161
W = self.weights.get().T ## get (transposed) memory matrix
163162

0 commit comments

Comments
 (0)