Skip to content

Commit 2fd70fc

Browse files
authored
Refactor patch synapse (#138)
* Refactor multi-patch synapse creation and initialization: refactor the _create_multi_patch_synapses function to use n_modules instead of n_sub_models, update weight initialization, and introduce weight masks for synaptic weights. * Refactor HebbianPatchedSynapse: refactor HebbianPatchedSynapse initialization and add new attributes for post_in and pre_out.
1 parent eb89be5 commit 2fd70fc

File tree

2 files changed

+74
-77
lines changed

2 files changed

+74
-77
lines changed

ngclearn/components/synapses/patched/hebbianPatchedSynapse.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def __init__(
197197
pre_wght=1., post_wght=1., p_conn=1., resist_scale=1., batch_size=1, **kwargs
198198
):
199199
super().__init__(
200-
name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale, p_conn,
201-
batch_size=batch_size, **kwargs
200+
name, shape, n_sub_models, stride_shape, weight_init, bias_init, resist_scale, p_conn,
201+
batch_size, **kwargs
202202
)
203203

204204
prior_type, prior_lmbda = prior
@@ -288,6 +288,8 @@ def reset(self):
288288
# NOTE: Quick workaround is to check if targeted is in the input or not
289289
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs
290290
self.outputs.set(postVals) # outputs
291+
self.post_in.set(postVals) # post_in
292+
self.pre_out.set(preVals) # pre_out
291293
self.pre.set(preVals) # pre
292294
self.post.set(postVals) # post
293295
self.dWeights.set(jnp.zeros(self.shape)) # dW
@@ -304,6 +306,7 @@ def help(cls): ## component help function
304306
compartment_props = {
305307
"inputs":
306308
{"inputs": "Takes in external input signal values",
309+
"post_in": "Takes in external input signal values",
307310
"pre": "Pre-synaptic statistic for Hebb rule (z_j)",
308311
"post": "Post-synaptic statistic for Hebb rule (z_i)"},
309312
"states":
@@ -314,7 +317,8 @@ def help(cls): ## component help function
314317
{"dWeights": "Synaptic weight value adjustment matrix produced at time t",
315318
"dBiases": "Synaptic bias/base-rate value adjustment vector produced at time t"},
316319
"outputs":
317-
{"outputs": "Output of synaptic transformation"},
320+
{"outputs": "Output of synaptic transformation",
321+
"pre_out": "Output of synaptic transformation"},
318322
}
319323
hyperparams = {
320324
"shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs",
@@ -351,3 +355,10 @@ def help(cls): ## component help function
351355
plt.imshow(Wab.weights.get(), cmap='gray')
352356
plt.show()
353357

358+
359+
360+
361+
362+
363+
364+

ngclearn/components/synapses/patched/patchedSynapse.py

Lines changed: 60 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,55 @@
44
from jax import random, numpy as jnp, jit
55
from ngclearn.components.jaxComponent import JaxComponent
66
from ngclearn.utils.distribution_generator import DistributionGenerator
7-
87
from ngcsimlib.logger import info
98
from ngclearn import compilable #from ngcsimlib.parser import compilable
109
from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
11-
# from ngclearn.utils.weight_distribution import initialize_params
12-
13-
14-
# def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
15-
# sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
16-
# di, dj = sub_shape
17-
# si, sj = sub_stride
18-
19-
# weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
20-
# #weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
21-
# large_weight_init = DistributionGenerator.constant(value=0.)
22-
# weights = large_weight_init(weight_shape, key[2])
23-
24-
# for i in range(n_sub_models):
25-
# start_i = i * di
26-
# end_i = (i + 1) * di + 2 * si
27-
# start_j = i * dj
28-
# end_j = (i + 1) * dj + 2 * sj
29-
30-
# shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
31-
32-
# ## FIXME: this line below might be wonky...
33-
# weights.at[start_i: end_i, start_j: end_j].set( weight_init(shape_, key[2]) )
34-
# # weights[start_i : end_i,
35-
# # start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
36-
# if si != 0:
37-
# weights.at[:si,:].set(0.) ## FIXME: this setter line might be wonky...
38-
# weights.at[-si:,:].set(0.) ## FIXME: this setter line might be wonky...
39-
# if sj != 0:
40-
# weights.at[:,:sj].set(0.) ## FIXME: this setter line might be wonky...
41-
# weights.at[:, -sj:].set(0.) ## FIXME: this setter line might be wonky...
42-
43-
# return weights
44-
45-
def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
46-
sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
47-
di, dj = sub_shape
48-
si, sj = sub_stride
49-
50-
weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
51-
# weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
52-
weights = DistributionGenerator.constant(value=0.)(weight_shape, key[2])
53-
54-
for i in range(n_sub_models):
10+
11+
def _create_multi_patch_synapses(key, shape, n_modules, module_stride=(0, 0), initialization_type=DistributionGenerator.fan_in_gaussian()):
12+
key, *subkey = random.split(key, n_modules+10)
13+
14+
module_shape = (shape[0] // n_modules, shape[1] // n_modules)
15+
di, dj = module_shape
16+
si, sj = module_stride
17+
18+
module_shape = di + (2 * si), dj + (2 * sj)
19+
20+
21+
weight_shape = ((n_modules * di) + 2 * si, (n_modules * dj) + 2 * sj)
22+
weights = jnp.zeros(weight_shape)
23+
w_masks = jnp.zeros(weight_shape)
24+
25+
for i in range(n_modules):
5526
start_i = i * di
5627
end_i = (i + 1) * di + 2 * si
5728
start_j = i * dj
5829
end_j = (i + 1) * dj + 2 * sj
5930

60-
shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
31+
shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
6132

62-
# weights[start_i : end_i,
63-
# start_j : end_j] = initialize_params(key[2],
64-
# init_kernel=weight_init,
65-
# shape=shape_,
66-
# use_numpy=True)
6733
weights = weights.at[start_i : end_i,
68-
start_j : end_j].set(weight_init(shape_, key[2]))
34+
start_j : end_j].set(initialization_type(shape_, subkey[i]))
35+
36+
w_masks = w_masks.at[start_i : end_i,
37+
start_j : end_j].set(jnp.ones(shape_))
38+
6939
if si!=0:
7040
weights = weights.at[:si,:].set(0.)
7141
weights = weights.at[-si:,:].set(0.)
42+
43+
w_masks = w_masks.at[:si,:].set(0.)
44+
w_masks = w_masks.at[-si:,:].set(0.)
45+
7246
if sj!=0:
7347
weights = weights.at[:,:sj].set(0.)
7448
weights = weights.at[:, -sj:].set(0.)
7549

76-
return weights
50+
w_masks = weights.at[:,:sj].set(0.)
51+
w_masks = weights.at[:, -sj:].set(0.)
7752

7853

54+
return weights, module_shape, w_masks
55+
7956
class PatchedSynapse(JaxComponent): ## base patched synaptic cable
8057
"""
8158
A patched dense synaptic cables that creates multiple small dense synaptic cables; no form of synaptic evolution/adaptation
@@ -114,7 +91,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
11491
bias_init: a kernel to drive initialization of biases for this synaptic cable
11592
(Default: None, which turns off/disables biases)
11693
117-
block_mask: weight mask matrix
94+
w_masks: weight mask matrix
11895
11996
pre_wght: pre-synaptic weighting factor (Default: 1.)
12097
@@ -127,8 +104,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
127104
this to < 1. will result in a sparser synaptic structure
128105
"""
129106

130-
def __init__(
131-
self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None,
107+
def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), weight_init=None, bias_init=None,
132108
resist_scale=1., p_conn=1., batch_size=1, **kwargs
133109
):
134110
super().__init__(name, **kwargs)
@@ -144,60 +120,63 @@ def __init__(
144120
tmp_key, *subkeys = random.split(self.key.get(), 4)
145121
if self.weight_init is None:
146122
info(self.name, "is using default weight initializer!")
147-
#self.weight_init = {"dist": "fan_in_gaussian"}
148123
self.weight_init = DistributionGenerator.fan_in_gaussian()
149124

150-
weights = _create_multi_patch_synapses(
151-
key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
152-
weight_init=self.weight_init
153-
)
154-
155-
self.block_mask = jnp.where(weights!=0, 1, 0)
156-
self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models)
125+
weights, self.sub_shape, self.w_masks = _create_multi_patch_synapses(
126+
key=tmp_key, shape=shape, n_modules=self.n_sub_models, module_stride=self.sub_stride,
127+
initialization_type = self.weight_init
128+
)
157129

158130
self.shape = weights.shape
159-
self.sub_shape = self.sub_shape[0]+(2*self.sub_stride[0]), self.sub_shape[1]+(2*self.sub_stride[1])
160-
131+
161132
if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
162133
mask = random.bernoulli(subkeys[1], p=p_conn, shape=self.shape)
163134
weights = weights * mask ## sparsify matrix
164135

165136
## Compartment setup
166137
preVals = jnp.zeros((self.batch_size, self.shape[0]))
167138
postVals = jnp.zeros((self.batch_size, self.shape[1]))
139+
168140
self.inputs = Compartment(preVals)
169141
self.outputs = Compartment(postVals)
170142
self.weights = Compartment(weights)
171143

144+
self.post_in = Compartment(postVals)
145+
self.pre_out = Compartment(preVals)
146+
self.weights_T = Compartment(weights.T)
147+
172148
## Set up (optional) bias values
173149
if self.bias_init is None:
174150
info(self.name, "is using default bias value of zero (no bias "
175151
"kernel provided)!")
176152
self.biases = Compartment(self.bias_init((1, self.shape[1]), subkeys[2]) if bias_init else 0.0)
177-
#elf.biases = Compartment(initialize_params(subkeys[2], bias_init, (1, self.shape[1])) if bias_init else 0.0)
178153

179154
@compilable
180155
def advance_state(self):
181156
# Get the variables
182157
inputs = self.inputs.get()
158+
post_in = self.post_in.get()
183159
weights = self.weights.get()
184160
biases = self.biases.get()
185161

186162
outputs = (jnp.matmul(inputs, weights) * self.Rscale) + biases
163+
pre_out = jnp.matmul(post_in, weights.T)
187164

188165
# Update compartment
189166
self.outputs.set(outputs)
167+
self.pre_out.set(pre_out)
190168

191169
@compilable
192170
def reset(self):
193171
preVals = jnp.zeros((self.batch_size, self.shape[0]))
194172
postVals = jnp.zeros((self.batch_size, self.shape[1]))
195-
inputs = preVals
196-
outputs = postVals
173+
197174
# BUG: the self.inputs here does not have the targeted field
198175
# NOTE: Quick workaround is to check if targeted is in the input or not
199-
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(inputs)
200-
self.outputs.set(outputs)
176+
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals)
177+
self.outputs.set(postVals)
178+
self.post_in.set(postVals)
179+
self.pre_out.set(preVals)
201180

202181
@classmethod
203182
def help(cls): ## component help function
@@ -208,13 +187,15 @@ def help(cls): ## component help function
208187
}
209188
compartment_props = {
210189
"inputs":
211-
{"inputs": "Takes in external input signal values"},
190+
{"inputs": "Takes in external input signal values",
191+
"post_in": "Takes in external input signal values"},
212192
"states":
213193
{"weights": "Synapse efficacy/strength parameter values",
214194
"biases": "Base-rate/bias parameter values",
215195
"key": "JAX PRNG key"},
216196
"outputs":
217-
{"outputs": "Output of synaptic transformation"},
197+
{"outputs": "Output of synaptic transformation",
198+
"pre_out": "Output of synaptic transformation"},
218199
}
219200
hyperparams = {
220201
"shape": "Overall shape of synaptic weight value matrix; number inputs x number outputs",
@@ -224,7 +205,7 @@ def help(cls): ## component help function
224205
"weight_init": "Initialization conditions for synaptic weight (W) values",
225206
"bias_init": "Initialization conditions for bias/base-rate (b) values",
226207
"resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
227-
"block_mask": "weight mask matrix",
208+
"w_masks": "weight mask matrix",
228209
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
229210
}
230211
info = {cls.__name__: properties,
@@ -241,3 +222,8 @@ def help(cls): ## component help function
241222
plt.imshow(Wab.weights.get(), cmap='gray')
242223
plt.show()
243224

225+
226+
227+
228+
229+

0 commit comments

Comments (0)