1+ from jax import random , numpy as jnp , jit
2+ from ngclearn .components .jaxComponent import JaxComponent
3+ from ngclearn .utils .matrix_utils import decompose_to_mps
4+ from ngcsimlib .logger import info
5+
6+ from ngclearn import compilable
7+ from ngclearn import Compartment
8+
class MPSSynapse(JaxComponent):
    """
    A Matrix Product State (MPS) compressed synaptic cable.

    This component represents a synaptic weight matrix decomposed into a
    contracted chain of low-rank tensor cores (also known as a Tensor Train).
    This architecture drastically reduces parameter counts for high-dimensional
    layers -- from O(N*M) to O(N*K + M*K) -- while maintaining high expressive
    power and biological plausibility through local error-driven updates.

    | References:
    | Stoudenmire, E. Miles, and David J. Schwab. "Supervised learning with
    | quantum-inspired tensor networks." Advances in neural information
    | processing systems 29 (2016).
    |
    | Novikov, Alexander, et al. "Tensorizing neural networks." Advances in
    | neural information processing systems 28 (2015).
    |
    | Nuijten, W. W. L., et al. "A Message Passing Realization of Expected
    | Free Energy Minimization." arXiv preprint arXiv:2501.03154 (2025).
    |
    | Wilson, P. "Performing Active Inference with Explainable Tensor
    | Networks." (2024).
    |
    | Fields, Chris, et al. "Control flow in active inference systems."
    | arXiv preprint arXiv:2303.01514 (2023).

    | --- Synapse Compartments: ---
    | inputs - external input signal values (shape: batch_size x in_dim)
    | outputs - transformed signal values (shape: batch_size x out_dim)
    | pre - pre-synaptic latent state values for learning (shape: batch_size x in_dim)
    | post - post-synaptic error signal values for learning (shape: batch_size x out_dim)
    | core1 - first MPS tensor core (shape: 1 x in_dim x bond_dim)
    | core2 - second MPS tensor core (shape: bond_dim x out_dim x 1)
    | key - JAX PRNG key used for stochasticity

    Args:
        name: the string name of this component

        shape: tuple specifying the shape of the latent synaptic weight matrix
            (number of inputs, number of outputs)

        bond_dim: the internal rank or "bond dimension" of the MPS compression.
            Higher values increase expressive power at the cost of more
            parameters. (Default: 16)

        batch_size: the number of samples in a concurrent batch (Default: 1)
    """

    def __init__(self, name, shape, bond_dim=16, batch_size=1, **kwargs):
        super().__init__(name, **kwargs)

        self.batch_size = batch_size
        self.shape = shape
        self.bond_dim = bond_dim

        # Initialize synaptic cores using a small (std ~ 0.05) normal
        # distribution. NOTE(review): tmp_key is discarded rather than written
        # back to self.key -- this follows the pattern seen elsewhere in
        # ngclearn components, but confirm the key compartment need not be
        # refreshed here.
        tmp_key, *subkeys = random.split(self.key.get(), 3)

        # Core 1: maps the input dimension to the internal bond dimension;
        # shape (1, in_dim, bond_dim) -- the leading singleton is the open
        # left boundary of the tensor train.
        c1 = random.normal(subkeys[0], (1, shape[0], bond_dim)) * 0.05
        self.core1 = Compartment(c1)

        # Core 2: maps the internal bond dimension to the output dimension;
        # shape (bond_dim, out_dim, 1) -- trailing singleton closes the chain.
        c2 = random.normal(subkeys[1], (bond_dim, shape[1], 1)) * 0.05
        self.core2 = Compartment(c2)

        # Initialize port/compartment values to zeros of the batch shape.
        preVals = jnp.zeros((self.batch_size, shape[0]))
        postVals = jnp.zeros((self.batch_size, shape[1]))

        self.inputs = Compartment(preVals)
        self.outputs = Compartment(postVals)
        self.pre = Compartment(preVals)
        self.post = Compartment(postVals)

    @compilable
    def advance_state(self):
        """
        Performs the forward synaptic transformation using MPS contraction.

        The full transformation is equivalent to:
        outputs = inputs @ (Core1 * Core2), but computed via iterative
        contraction to maintain memory efficiency:
        1. z = inputs contracted with Core1 -> (Batch x Bond_Dim)
        2. outputs = z contracted with Core2 -> (Batch x Out_Dim)
        """
        x = self.inputs.get()
        c1 = self.core1.get()
        c2 = self.core2.get()

        # Contraction 1: (Batch, In) x (1, In, Bond) -> (Batch, Bond);
        # the singleton 'm' axis is summed out.
        z = jnp.einsum('bi,mik->bk', x, c1)

        # Contraction 2: (Batch, Bond) x (Bond, Out, 1) -> (Batch, Out);
        # the singleton 'o' axis is summed out.
        out = jnp.einsum('bk,kno->bn', z, c2)

        self.outputs.set(out)

    @compilable
    def project_backward(self, error_signal):
        """
        Projects an error signal backwards through the compressed synaptic
        cable.

        This allows the passing of messages/gradients through the hierarchy
        without ever reconstructing the full dense weight matrix, ensuring
        O(N) complexity relative to the input/output dimensions.

        Args:
            error_signal: error values of shape (batch_size, out_dim)

        Returns:
            back-projected error of shape (batch_size, in_dim)
        """
        c1 = self.core1.get()
        c2 = self.core2.get()
        # 1. Project the error through Core 2 down to the bond space.
        e_back = jnp.einsum('bo,kno->bk', error_signal, c2)
        # 2. Project from the bond space through Core 1 to the input space.
        return jnp.einsum('bk,mik->bi', e_back, c1)

    @compilable
    def evolve(self, eta=0.01):
        """
        Updates the MPS tensor cores using local error-driven (Hebbian)
        gradients.

        This utilizes the 'pre' and 'post' compartments to update core1 and
        core2. Because the weights are factorized, the update to each core
        depends on the activity and errors projected through the other core,
        maintaining global consistency through local message passing.

        Args:
            eta: learning rate scaling the core updates (Default: 0.01)
        """
        x = self.pre.get()      # (Batch, In)
        err = self.post.get()   # (Batch, Out)
        c1 = self.core1.get()   # (1, In, K)
        c2 = self.core2.get()   # (K, Out, 1)

        # 1. Compute the latent bond activity needed for the Core 2 update.
        z = jnp.einsum('bi,mik->bk', x, c1)

        # 2. Core 2 gradient: outer product of bond activity and output
        #    error, summed over the batch; expand to (K, Out, 1).
        dc2 = jnp.einsum('bk,bn->kn', z, err)
        dc2 = jnp.expand_dims(dc2, axis=2)

        # 3. Core 1 gradient: outer product of the input activity with the
        #    error back-projected through Core 2; expand to (1, In, K).
        err_back = jnp.einsum('bn,kno->bk', err, c2)
        dc1 = jnp.einsum('bi,bk->ik', x, err_back)
        dc1 = jnp.expand_dims(dc1, axis=0)

        # Apply additive (gradient-ascent style) updates to the cores.
        self.core1.set(c1 + eta * dc1)
        self.core2.set(c2 + eta * dc2)

    @compilable
    def reset(self):
        """
        Resets input, output, and activity compartments to zero.

        Compartments flagged as externally `targeted` are left untouched so
        wired-in values are not clobbered.
        """
        if not self.inputs.targeted:
            self.inputs.set(jnp.zeros((self.batch_size, self.shape[0])))

        self.outputs.set(jnp.zeros((self.batch_size, self.shape[1])))

        if not self.pre.targeted:
            self.pre.set(jnp.zeros((self.batch_size, self.shape[0])))

        if not self.post.targeted:
            self.post.set(jnp.zeros((self.batch_size, self.shape[1])))

    @property
    def weights(self):
        """
        Reconstructs the full dense matrix from the MPS cores for analysis.

        Note: This is computationally expensive for high-dimensional layers.
        """
        return Compartment(
            jnp.einsum('mik,kno->in', self.core1.get(), self.core2.get())
        )

    @weights.setter
    def weights(self, W):
        """
        Sets the synaptic cores by decomposing a provided dense matrix W
        using Singular Value Decomposition (SVD).
        """
        c1, c2 = decompose_to_mps(W, bond_dim=self.bond_dim)
        self.core1.set(c1)
        self.core2.set(c2)

    @classmethod
    def help(cls):
        """
        Returns an info dictionary describing the component.
        """
        properties = {
            "synapse_type": "MPSSynapse - performs a compressed synaptic "
                            "transformation of inputs to produce output signals via "
                            "Matrix Product State (MPS) core contractions."
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre": "Pre-synaptic latent state values for learning",
                 "post": "Post-synaptic error signal values for learning"},
            "states":
                {"core1": "First MPS tensor core (1, in_dim, bond_dim)",
                 "core2": "Second MPS tensor core (bond_dim, out_dim, 1)",
                 "key": "JAX PRNG key"},
            "outputs":
                {"outputs": "Output of compressed synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of latent weight matrix (in_dim, out_dim)",
            "bond_dim": "The compression rank/bond-dimension of the MPS chain",
            "batch_size": "Batch size dimension of this component"
        }
        # Return the dict directly; a local named `info` would shadow the
        # `info` logger imported at module level from ngcsimlib.logger.
        return {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [inputs @ Core1] @ Core2",
                "hyperparameters": hyperparams}
220+
if __name__ == '__main__':
    # Smoke test: build a small MPS synapse inside a simulation context
    # and print its representation.
    from ngcsimlib.context import Context
    with Context("MPS_Test") as ctx:
        synapse = MPSSynapse("Wab", (10, 5), bond_dim=4)
        print(synapse)