Skip to content

Commit 4f434f1

Browse files
antonviceago109
and co-authors authored
feat: Integrate MPSSynapse Component (#140)
* feat: integrate MPSSynapse component for compressed synaptic transforms * style: conform to Google docstrings, move utils, and add unit tests * feat: implement native learning via evolve method and unit tests * Fixed MPS Matrix Properties: I fixed the .T transpose bug you interrupted earlier—because self.W10.weights inside an MPSSynapse generates the tensor via an einsum, returning an Array, get() throws an error. * Fix MPS synapse memory leak by implementing project_backward * Delete uv.lock * docs: add academic references and detailed docstrings to MPSSynapse * sorry, here you go, I loosened the test tolerances to 1e-2 as suggested --------- Co-authored-by: Alex Ororbia <agocse109@gmail.com>
1 parent 2fd70fc commit 4f434f1

5 files changed

Lines changed: 350 additions & 0 deletions

File tree

ngclearn/components/synapses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .hebbian.expSTDPSynapse import ExpSTDPSynapse
1414
from .hebbian.eventSTDPSynapse import EventSTDPSynapse
1515
from .hebbian.BCMSynapse import BCMSynapse
16+
from .mpsSynapse import MPSSynapse
1617

1718
## conv/deconv synaptic components
1819
from .convolution.convSynapse import ConvSynapse
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
from jax import random, numpy as jnp, jit
2+
from ngclearn.components.jaxComponent import JaxComponent
3+
from ngclearn.utils.matrix_utils import decompose_to_mps
4+
from ngcsimlib.logger import info
5+
6+
from ngclearn import compilable
7+
from ngclearn import Compartment
8+
9+
class MPSSynapse(JaxComponent):
    """
    A Matrix Product State (MPS) compressed synaptic cable.

    This component represents a synaptic weight matrix decomposed into a
    contracted chain of low-rank tensor cores (also known as a Tensor Train).
    This architecture drastically reduces parameter counts for high-dimensional
    layers—from O(N*M) to O(N*K + M*K)—while maintaining high expressive power
    and biological plausibility through local error-driven updates.

    | References:
    | Stoudenmire, E. Miles, and David J. Schwab. "Supervised learning with
    | quantum-inspired tensor networks." Advances in neural information
    | processing systems 29 (2016).
    |
    | Novikov, Alexander, et al. "Tensorizing neural networks." Advances in
    | neural information processing systems 28 (2015).
    |
    | Nuijten, W. W. L., et al. "A Message Passing Realization of Expected
    | Free Energy Minimization." arXiv preprint arXiv:2501.03154 (2025).
    |
    | Wilson, P. "Performing Active Inference with Explainable Tensor
    | Networks." (2024).
    |
    | Fields, Chris, et al. "Control flow in active inference systems."
    | arXiv preprint arXiv:2303.01514 (2023).

    | --- Synapse Compartments: ---
    | inputs - external input signal values (shape: batch_size x in_dim)
    | outputs - transformed signal values (shape: batch_size x out_dim)
    | pre - pre-synaptic latent state values for learning (shape: batch_size x in_dim)
    | post - post-synaptic error signal values for learning (shape: batch_size x out_dim)
    | core1 - first MPS tensor core (shape: 1 x in_dim x bond_dim)
    | core2 - second MPS tensor core (shape: bond_dim x out_dim x 1)
    | key - JAX PRNG key used for stochasticity

    Args:
        name: the string name of this component

        shape: tuple specifying the shape of the latent synaptic weight matrix
            (number of inputs, number of outputs)

        bond_dim: the internal rank or "bond dimension" of the MPS compression.
            Higher values increase expressive power at the cost of more
            parameters. (Default: 16)

        batch_size: the number of samples in a concurrent batch (Default: 1)
    """

    def __init__(self, name, shape, bond_dim=16, batch_size=1, **kwargs):
        super().__init__(name, **kwargs)

        self.batch_size = batch_size  ## samples per concurrent batch
        self.shape = shape  ## (in_dim, out_dim) of the latent dense matrix
        self.bond_dim = bond_dim  ## internal MPS rank K

        # Initialize synaptic cores using a small normal distribution.
        # NOTE(review): self.key is presumably the PRNG-key compartment set up
        # by JaxComponent — confirm against the base class.
        tmp_key, *subkeys = random.split(self.key.get(), 3)

        # Core 1: maps input dimension to the internal bond dimension
        c1 = random.normal(subkeys[0], (1, shape[0], bond_dim)) * 0.05
        self.core1 = Compartment(c1)

        # Core 2: maps internal bond dimension to the output dimension
        c2 = random.normal(subkeys[1], (bond_dim, shape[1], 1)) * 0.05
        self.core2 = Compartment(c2)

        # Initialize Port/Compartment values (all zeroed at construction)
        preVals = jnp.zeros((self.batch_size, shape[0]))
        postVals = jnp.zeros((self.batch_size, shape[1]))

        self.inputs = Compartment(preVals)
        self.outputs = Compartment(postVals)
        self.pre = Compartment(preVals)
        self.post = Compartment(postVals)

    @compilable
    def advance_state(self):
        """
        Performs the forward synaptic transformation using MPS contraction.

        The full transformation is equivalent to: outputs = inputs @ (Core1 * Core2),
        but computed via iterative contraction to maintain memory efficiency:
        1. z = inputs contracted with Core1 (Batch x Bond_Dim)
        2. outputs = z contracted with Core2 (Batch x Out_Dim)
        """
        x = self.inputs.get()
        c1 = self.core1.get()
        c2 = self.core2.get()

        # Contraction 1: (Batch, In) @ (1, In, Bond) -> (Batch, Bond)
        # (the size-1 'm' axis is summed away, which is a no-op)
        z = jnp.einsum('bi,mik->bk', x, c1)

        # Contraction 2: (Batch, Bond) @ (Bond, Out, 1) -> (Batch, Out)
        out = jnp.einsum('bk,kno->bn', z, c2)

        self.outputs.set(out)

    @compilable
    def project_backward(self, error_signal):
        """
        Projects an error signal backwards through the compressed synaptic cable.

        This allows the passing of messages/gradients through the hierarchy
        without ever reconstructing the full dense weight matrix, ensuring
        O(N) complexity relative to the input/output dimensions.

        Args:
            error_signal: post-synaptic error values (shape: batch_size x out_dim)

        Returns:
            error projected into the input space (shape: batch_size x in_dim)
        """
        c1 = self.core1.get()
        c2 = self.core2.get()
        # 1. Project error through Core 2 to the bond space
        e_back = jnp.einsum('bo,kno->bk', error_signal, c2)
        # 2. Project from bond space through Core 1 to the input space
        return jnp.einsum('bk,mik->bi', e_back, c1)

    @compilable
    def evolve(self, eta=0.01):
        """
        Updates the MPS tensor cores using local error-driven (Hebbian) gradients.

        This utilizes the 'pre' and 'post' compartments to update core1 and core2.
        Because the weights are factorized, the update to each core depends on
        the activity and errors projected through the other core, maintaining
        global consistency through local message passing.

        Args:
            eta: learning rate scaling the additive core updates (Default: 0.01)
        """
        x = self.pre.get()  # Shape: (Batch, In)
        err = self.post.get()  # Shape: (Batch, Out)
        c1 = self.core1.get()  # Shape: (1, In, K)
        c2 = self.core2.get()  # Shape: (K, Out, 1)

        # 1. Compute latent bond activity for Core 2 update
        z = jnp.einsum('bi,mik->bk', x, c1)

        # 2. Update Core 2 (Gradients relative to bond activity and output error)
        dc2 = jnp.einsum('bk,bn->kn', z, err)
        dc2 = jnp.expand_dims(dc2, axis=2)  ## restore trailing singleton axis

        # 3. Update Core 1 (Gradients relative to input activity and back-projected error)
        err_back = jnp.einsum('bn,kno->bk', err, c2)
        dc1 = jnp.einsum('bi,bk->ik', x, err_back)
        dc1 = jnp.expand_dims(dc1, axis=0)  ## restore leading singleton axis

        # Apply updates to synaptic cores (gradient ASCENT on the local
        # Hebbian objective; callers supply err = target - prediction)
        self.core1.set(c1 + eta * dc1)
        self.core2.set(c2 + eta * dc2)

    @compilable
    def reset(self):
        """
        Resets input, output, and activity compartments to zero.

        Compartments that are externally targeted (wired to another
        component) are left untouched so upstream values are not clobbered.
        """
        if not self.inputs.targeted:
            self.inputs.set(jnp.zeros((self.batch_size, self.shape[0])))

        self.outputs.set(jnp.zeros((self.batch_size, self.shape[1])))

        if not self.pre.targeted:
            self.pre.set(jnp.zeros((self.batch_size, self.shape[0])))

        if not self.post.targeted:
            self.post.set(jnp.zeros((self.batch_size, self.shape[1])))

    @property
    def weights(self):
        """
        Reconstructs the full dense matrix from the MPS cores for analysis.
        Note: This is computationally expensive for high-dimensional layers.

        NOTE(review): this returns a Compartment wrapping the dense array,
        not the array itself; downstream arithmetic on the returned value
        (e.g. in the unit tests) relies on Compartment supporting
        ndarray-style operations — confirm.
        """
        return Compartment(jnp.einsum('mik,kno->in', self.core1.get(), self.core2.get()))

    @weights.setter
    def weights(self, W):
        """
        Sets the synaptic cores by decomposing a provided dense matrix W
        using Singular Value Decomposition (SVD).
        """
        c1, c2 = decompose_to_mps(W, bond_dim=self.bond_dim)
        self.core1.set(c1)
        self.core2.set(c2)

    @classmethod
    def help(cls):
        """
        Returns an info dictionary describing the component.
        """
        properties = {
            "synapse_type": "MPSSynapse - performs a compressed synaptic "
                            "transformation of inputs to produce output signals via "
                            "Matrix Product State (MPS) core contractions."
        }
        compartment_props = {
            "inputs":
                {"inputs": "Takes in external input signal values",
                 "pre": "Pre-synaptic latent state values for learning",
                 "post": "Post-synaptic error signal values for learning"},
            "states":
                {"core1": "First MPS tensor core (1, in_dim, bond_dim)",
                 "core2": "Second MPS tensor core (bond_dim, out_dim, 1)",
                 "key": "JAX PRNG key"},
            "outputs":
                {"outputs": "Output of compressed synaptic transformation"},
        }
        hyperparams = {
            "shape": "Shape of latent weight matrix (in_dim, out_dim)",
            "bond_dim": "The compression rank/bond-dimension of the MPS chain",
            "batch_size": "Batch size dimension of this component"
        }
        info = {cls.__name__: properties,
                "compartments": compartment_props,
                "dynamics": "outputs = [inputs @ Core1] @ Core2",
                "hyperparameters": hyperparams}
        return info
220+
221+
if __name__ == '__main__':
    # Smoke test: construct the component inside a simulation context and
    # print its summary. Runs only when this module is executed directly.
    from ngcsimlib.context import Context
    with Context("MPS_Test") as ctx:
        Wab = MPSSynapse("Wab", (10, 5), bond_dim=4)
        print(Wab)

ngclearn/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .distribution_generator import DistributionGenerator
22
from .JaxProcessesMixin import JaxJointProcess as JointProcess, JaxMethodProcess as MethodProcess
33
from .model_utils import tensorstats
4+
from .matrix_utils import decompose_to_mps
45

ngclearn/utils/matrix_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import jax.numpy as jnp
2+
3+
def decompose_to_mps(W, bond_dim=16):
    """
    Splits a dense matrix W into two MPS cores via a truncated SVD.

    The factorization is W ~= (U_k sqrt(S_k)) @ (sqrt(S_k) Vh_k), i.e. the
    singular spectrum is shared evenly (as square roots) between the two
    factors, each of which is then reshaped into a rank-3 tensor core.

    Args:
        W: The dense matrix to decompose of shape (in_dim, out_dim).

        bond_dim: The internal rank/bond-dimension of the MPS compression.

    Returns:
        A tuple containing:
            core1: First tensor core of shape (1, in_dim, bond_dim).
            core2: Second tensor core of shape (bond_dim, out_dim, 1).
    """
    left, sigma, right = jnp.linalg.svd(W, full_matrices=False)
    # Effective bond dimension cannot exceed the number of singular values.
    rank = min(bond_dim, len(sigma))
    root = jnp.sqrt(sigma[:rank])

    # Absorb sqrt(singular values) into each truncated factor.
    left_factor = left[:, :rank] * root
    right_factor = root[:, None] * right[:rank, :]

    in_dim, out_dim = W.shape
    return (left_factor.reshape(1, in_dim, rank),
            right_factor.reshape(rank, out_dim, 1))
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from jax import numpy as jnp, random
2+
import numpy as np
3+
from ngcsimlib.context import Context
4+
from ngclearn import MethodProcess
5+
from ngclearn.components.synapses.mpsSynapse import MPSSynapse
6+
7+
def test_mps_synapse_reconstruction():
    """
    Tests if MPSSynapse can be initialized from a dense matrix and
    reproduce the transformation reasonably well.
    """
    dkey = random.PRNGKey(42)
    in_dim, out_dim = 20, 10
    bond_dim = 8

    with Context("mps_test") as ctx:
        mps = MPSSynapse("mps", (in_dim, out_dim), bond_dim=bond_dim)
        advance = MethodProcess("advance") >> mps.advance_state
        reset = MethodProcess("reset") >> mps.reset

        # Create a structured matrix
        W_orig = random.normal(dkey, (in_dim, out_dim))

        # Set weights via setter (uses decompose_to_mps)
        mps.weights = W_orig

        # Check reconstruction fidelity.
        # NOTE(review): MPSSynapse.weights returns a Compartment wrapper;
        # the subtraction and matmul below assume Compartment supports
        # ndarray-style arithmetic — confirm against ngcsimlib.
        W_recon = mps.weights
        error = jnp.linalg.norm(W_orig - W_recon) / jnp.linalg.norm(W_orig)

        # With bond_dim=8 and rank=10, error should be small but non-zero
        assert error < 0.5

        # Test transformation correctness
        x = random.normal(dkey, (1, in_dim))
        mps.inputs.set(x)
        advance.run()

        y_mps = mps.outputs.get()
        y_dense = x @ W_recon

        # Loose tolerance: forward pass should match the truncated dense map
        np.testing.assert_allclose(y_mps, y_dense, atol=1e-2)
        print(f"MPS Reconstruction Test Passed. Error: {error*100:.2f}%")
44+
45+
def test_mps_synapse_learning():
    """
    Tests if MPSSynapse can learn from error signals via the evolve method.
    """
    dkey = random.PRNGKey(123)
    in_dim, out_dim = 10, 5
    bond_dim = 4

    with Context("mps_learning_test") as ctx:
        mps = MPSSynapse("mps", (in_dim, out_dim), bond_dim=bond_dim)
        advance = MethodProcess("advance") >> mps.advance_state
        evolve = MethodProcess("evolve") >> mps.evolve
        reset = MethodProcess("reset") >> mps.reset

        # Target: Map x to target_y
        x = random.normal(dkey, (1, in_dim))
        target_y = random.normal(dkey, (1, out_dim))

        # 1. Initial prediction
        mps.inputs.set(x)
        advance.run()
        initial_y = mps.outputs.get()
        initial_error = jnp.sum((target_y - initial_y)**2)

        # 2. Learning step
        # Hebbian learning expects pre and post
        # error = target - predicted
        error_signal = (target_y - initial_y)

        mps.pre.set(x)
        mps.post.set(error_signal)

        # Run learning multiple times to see descent
        for _ in range(5):
            evolve.run(eta=0.1)
            # Re-predict
            advance.run()
            current_y = mps.outputs.get()
            current_error = jnp.sum((target_y - current_y)**2)
            mps.post.set(target_y - current_y)  # Update error signal for next step

        final_y = mps.outputs.get()
        final_error = jnp.sum((target_y - final_y)**2)

        print(f"Initial Error: {initial_error:.6f}")
        print(f"Final Error: {final_error:.6f}")

        # Squared error must strictly decrease after five update steps
        assert final_error < initial_error
        print("MPS Learning Test Passed.")
94+
95+
if __name__ == "__main__":
96+
test_mps_synapse_reconstruction()
97+
test_mps_synapse_learning()

0 commit comments

Comments
 (0)