44from jax import random , numpy as jnp , jit
55from ngclearn .components .jaxComponent import JaxComponent
66from ngclearn .utils .distribution_generator import DistributionGenerator
7-
87from ngcsimlib .logger import info
98from ngclearn import compilable #from ngcsimlib.parser import compilable
109from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
11- # from ngclearn.utils.weight_distribution import initialize_params
12-
13-
14- # def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
15- # sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
16- # di, dj = sub_shape
17- # si, sj = sub_stride
18-
19- # weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
20- # #weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
21- # large_weight_init = DistributionGenerator.constant(value=0.)
22- # weights = large_weight_init(weight_shape, key[2])
23-
24- # for i in range(n_sub_models):
25- # start_i = i * di
26- # end_i = (i + 1) * di + 2 * si
27- # start_j = i * dj
28- # end_j = (i + 1) * dj + 2 * sj
29-
30- # shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
31-
32- # ## FIXME: this line below might be wonky...
33- # weights.at[start_i: end_i, start_j: end_j].set( weight_init(shape_, key[2]) )
34- # # weights[start_i : end_i,
35- # # start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
36- # if si != 0:
37- # weights.at[:si,:].set(0.) ## FIXME: this setter line might be wonky...
38- # weights.at[-si:,:].set(0.) ## FIXME: this setter line might be wonky...
39- # if sj != 0:
40- # weights.at[:,:sj].set(0.) ## FIXME: this setter line might be wonky...
41- # weights.at[:, -sj:].set(0.) ## FIXME: this setter line might be wonky...
42-
43- # return weights
44-
45- def _create_multi_patch_synapses (key , shape , n_sub_models , sub_stride , weight_init ):
46- sub_shape = (shape [0 ] // n_sub_models , shape [1 ] // n_sub_models )
47- di , dj = sub_shape
48- si , sj = sub_stride
49-
50- weight_shape = ((n_sub_models * di ) + 2 * si , (n_sub_models * dj ) + 2 * sj )
51- # weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
52- weights = DistributionGenerator .constant (value = 0. )(weight_shape , key [2 ])
53-
54- for i in range (n_sub_models ):
10+
11+ def _create_multi_patch_synapses (key , shape , n_modules , module_stride = (0 , 0 ), initialization_type = DistributionGenerator .fan_in_gaussian ()):
12+ key , * subkey = random .split (key , n_modules + 10 )
13+
14+ module_shape = (shape [0 ] // n_modules , shape [1 ] // n_modules )
15+ di , dj = module_shape
16+ si , sj = module_stride
17+
18+ module_shape = di + (2 * si ), dj + (2 * sj )
19+
20+
21+ weight_shape = ((n_modules * di ) + 2 * si , (n_modules * dj ) + 2 * sj )
22+ weights = jnp .zeros (weight_shape )
23+ w_masks = jnp .zeros (weight_shape )
24+
25+ for i in range (n_modules ):
5526 start_i = i * di
5627 end_i = (i + 1 ) * di + 2 * si
5728 start_j = i * dj
5829 end_j = (i + 1 ) * dj + 2 * sj
5930
60- shape_ = (end_i - start_i , end_j - start_j ) # (di + 2 * si, dj + 2 * sj)
31+ shape_ = (end_i - start_i , end_j - start_j ) # (di + 2 * si, dj + 2 * sj)
6132
62- # weights[start_i : end_i,
63- # start_j : end_j] = initialize_params(key[2],
64- # init_kernel=weight_init,
65- # shape=shape_,
66- # use_numpy=True)
6733 weights = weights .at [start_i : end_i ,
68- start_j : end_j ].set (weight_init (shape_ , key [2 ]))
34+ start_j : end_j ].set (initialization_type (shape_ , subkey [i ]))
35+
36+ w_masks = w_masks .at [start_i : end_i ,
37+ start_j : end_j ].set (jnp .ones (shape_ ))
38+
6939 if si != 0 :
7040 weights = weights .at [:si ,:].set (0. )
7141 weights = weights .at [- si :,:].set (0. )
42+
43+ w_masks = w_masks .at [:si ,:].set (0. )
44+ w_masks = w_masks .at [- si :,:].set (0. )
45+
7246 if sj != 0 :
7347 weights = weights .at [:,:sj ].set (0. )
7448 weights = weights .at [:, - sj :].set (0. )
7549
76- return weights
50+ w_masks = weights .at [:,:sj ].set (0. )
51+ w_masks = weights .at [:, - sj :].set (0. )
7752
7853
54+ return weights , module_shape , w_masks
55+
7956class PatchedSynapse (JaxComponent ): ## base patched synaptic cable
8057 """
8158 A patched dense synaptic cables that creates multiple small dense synaptic cables; no form of synaptic evolution/adaptation
@@ -114,7 +91,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
11491 bias_init: a kernel to drive initialization of biases for this synaptic cable
11592 (Default: None, which turns off/disables biases)
11693
117- block_mask : weight mask matrix
94+ w_masks : weight mask matrix
11895
11996 pre_wght: pre-synaptic weighting factor (Default: 1.)
12097
@@ -127,8 +104,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
127104 this to < 1. will result in a sparser synaptic structure
128105 """
129106
130- def __init__ (
131- self , name , shape , n_sub_models = 1 , stride_shape = (0 ,0 ), block_mask = None , weight_init = None , bias_init = None ,
107+ def __init__ (self , name , shape , n_sub_models = 1 , stride_shape = (0 ,0 ), weight_init = None , bias_init = None ,
132108 resist_scale = 1. , p_conn = 1. , batch_size = 1 , ** kwargs
133109 ):
134110 super ().__init__ (name , ** kwargs )
@@ -144,60 +120,63 @@ def __init__(
144120 tmp_key , * subkeys = random .split (self .key .get (), 4 )
145121 if self .weight_init is None :
146122 info (self .name , "is using default weight initializer!" )
147- #self.weight_init = {"dist": "fan_in_gaussian"}
148123 self .weight_init = DistributionGenerator .fan_in_gaussian ()
149124
150- weights = _create_multi_patch_synapses (
151- key = subkeys , shape = shape , n_sub_models = self .n_sub_models , sub_stride = self .sub_stride ,
152- weight_init = self .weight_init
153- )
154-
155- self .block_mask = jnp .where (weights != 0 , 1 , 0 )
156- self .sub_shape = (shape [0 ]// n_sub_models , shape [1 ]// n_sub_models )
125+ weights , self .sub_shape , self .w_masks = _create_multi_patch_synapses (
126+ key = tmp_key , shape = shape , n_modules = self .n_sub_models , module_stride = self .sub_stride ,
127+ initialization_type = self .weight_init
128+ )
157129
158130 self .shape = weights .shape
159- self .sub_shape = self .sub_shape [0 ]+ (2 * self .sub_stride [0 ]), self .sub_shape [1 ]+ (2 * self .sub_stride [1 ])
160-
131+
161132 if 0. < p_conn < 1. : ## only non-zero and <1 probs allowed
162133 mask = random .bernoulli (subkeys [1 ], p = p_conn , shape = self .shape )
163134 weights = weights * mask ## sparsify matrix
164135
165136 ## Compartment setup
166137 preVals = jnp .zeros ((self .batch_size , self .shape [0 ]))
167138 postVals = jnp .zeros ((self .batch_size , self .shape [1 ]))
139+
168140 self .inputs = Compartment (preVals )
169141 self .outputs = Compartment (postVals )
170142 self .weights = Compartment (weights )
171143
144+ self .post_in = Compartment (postVals )
145+ self .pre_out = Compartment (preVals )
146+ self .weights_T = Compartment (weights .T )
147+
172148 ## Set up (optional) bias values
173149 if self .bias_init is None :
174150 info (self .name , "is using default bias value of zero (no bias "
175151 "kernel provided)!" )
176152 self .biases = Compartment (self .bias_init ((1 , self .shape [1 ]), subkeys [2 ]) if bias_init else 0.0 )
177- #elf.biases = Compartment(initialize_params(subkeys[2], bias_init, (1, self.shape[1])) if bias_init else 0.0)
178153
179154 @compilable
180155 def advance_state (self ):
181156 # Get the variables
182157 inputs = self .inputs .get ()
158+ post_in = self .post_in .get ()
183159 weights = self .weights .get ()
184160 biases = self .biases .get ()
185161
186162 outputs = (jnp .matmul (inputs , weights ) * self .Rscale ) + biases
163+ pre_out = jnp .matmul (post_in , weights .T )
187164
188165 # Update compartment
189166 self .outputs .set (outputs )
167+ self .pre_out .set (pre_out )
190168
191169 @compilable
192170 def reset (self ):
193171 preVals = jnp .zeros ((self .batch_size , self .shape [0 ]))
194172 postVals = jnp .zeros ((self .batch_size , self .shape [1 ]))
195- inputs = preVals
196- outputs = postVals
173+
197174 # BUG: the self.inputs here does not have the targeted field
198175 # NOTE: Quick workaround is to check if targeted is in the input or not
199- hasattr (self .inputs , "targeted" ) and not self .inputs .targeted and self .inputs .set (inputs )
200- self .outputs .set (outputs )
176+ hasattr (self .inputs , "targeted" ) and not self .inputs .targeted and self .inputs .set (preVals )
177+ self .outputs .set (postVals )
178+ self .post_in .set (postVals )
179+ self .pre_out .set (preVals )
201180
202181 @classmethod
203182 def help (cls ): ## component help function
@@ -208,13 +187,15 @@ def help(cls): ## component help function
208187 }
209188 compartment_props = {
210189 "inputs" :
211- {"inputs" : "Takes in external input signal values" },
190+ {"inputs" : "Takes in external input signal values" ,
191+ "post_in" : "Takes in external input signal values" },
212192 "states" :
213193 {"weights" : "Synapse efficacy/strength parameter values" ,
214194 "biases" : "Base-rate/bias parameter values" ,
215195 "key" : "JAX PRNG key" },
216196 "outputs" :
217- {"outputs" : "Output of synaptic transformation" },
197+ {"outputs" : "Output of synaptic transformation" ,
198+ "pre_out" : "Output of synaptic transformation" },
218199 }
219200 hyperparams = {
220201 "shape" : "Overall shape of synaptic weight value matrix; number inputs x number outputs" ,
@@ -224,7 +205,7 @@ def help(cls): ## component help function
224205 "weight_init" : "Initialization conditions for synaptic weight (W) values" ,
225206 "bias_init" : "Initialization conditions for bias/base-rate (b) values" ,
226207 "resist_scale" : "Resistance level scaling factor (Rscale); applied to output of transformation" ,
227- "block_mask " : "weight mask matrix" ,
208+ "w_masks " : "weight mask matrix" ,
228209 "p_conn" : "Probability of a connection existing (otherwise, it is masked to zero)"
229210 }
230211 info = {cls .__name__ : properties ,
@@ -241,3 +222,8 @@ def help(cls): ## component help function
241222 plt .imshow (Wab .weights .get (), cmap = 'gray' )
242223 plt .show ()
243224
225+
226+
227+
228+
229+
0 commit comments