Skip to content

Commit 506a7e3

Browse files
committed
transformerless_lm: Fibonacci State Model — substrate-canonical recurrence
models_fsm.FSMLM replaces quadratic attention entirely with a 2-tap Fibonacci recurrence: h_t = A * h_{t-1} + B * h_{t-2} + C * x_t where A, B, C are FibGen-compressed linear maps. Per-block compute: O(T * d^2) sequential, vs attention's O(T^2 * d). At long T, FSM wins decisively; at small T attention's parallel matmul is faster. Smoke at T=11 (Fibonacci-lazy-loaded data) and d=128: dense_crt: val=3.33, wall=3.37s FSMLM: val=3.55, wall=7.38s (2.2x slower) Lazy data and FSM are at cross-purposes: one collapses T to win, the other needs T big to win. To validate FSM's asymptotic claim we need to bench at T=512+ WITHOUT lazy data, where attention is quadratic-expensive enough that FSM's linear cost wins. Keeps every prior substrate win: - CRT-Fibonacci positional encoding - FibGen-compressed weights (100x storage reduction) - Substrate operator at the attention layer (now: 2-tap Fibonacci recurrence instead of L1-distance or dot-product attention)
1 parent 0e85afb commit 506a7e3

1 file changed

Lines changed: 154 additions & 0 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Fibonacci State Model (FSM) — substrate-canonical recurrence.
2+
3+
Throws out quadratic attention entirely. Each block updates a hidden
4+
state via a 2-tap Fibonacci recurrence:
5+
6+
h_t = A · h_{t-1} + B · h_{t-2} + C · x_t
7+
8+
where A, B, C are FibGen-compressed linear layers. The recurrence is
9+
literally Fibonacci-shaped (each step depends on the two previous,
10+
mirroring F(n) = F(n-1) + F(n-2)), so the operator is substrate-
11+
canonical at the deepest level — not decorated, but defined.
12+
13+
Compute per layer: O(T · d²) (sequential). Compared to attention's
14+
O(T² · d), FSM wins at LONG sequence lengths where T² dominates.
15+
At small T the sequential Python loop adds overhead.
16+
17+
Keeps every validated substrate win:
18+
- CRT-Fibonacci positional encoding
19+
- FibGen-compressed weights (100x storage compression at d=128,
20+
growing with d²/K²)
21+
- Lazy-strided data loading (consumed by training pipeline)
22+
- Substrate operator at attention layer (now: recurrence, not
23+
dot-product or L1)
24+
25+
To speed up the Python sequential loop, weights are precomputed once
26+
per forward via FibGen's cache_weight() pattern so each timestep does
27+
a plain matmul without seed regeneration overhead.
28+
"""
29+
30+
import math
31+
import sys
32+
from pathlib import Path
33+
34+
import torch
35+
import torch.nn as nn
36+
import torch.nn.functional as F
37+
38+
sys.path.insert(0, str(Path(__file__).parent))
39+
from models_fibgen import FibGenLinear
40+
41+
42+
class FibStateRecurrence(nn.Module):
43+
"""Fibonacci 2-tap state recurrence: h_t = A·h_{t-1} + B·h_{t-2} + C·x_t.
44+
45+
A, B, C are FibGen-compressed linear maps. To minimize Python-loop
46+
overhead, we pre-generate the dense W tensors at forward-time and
47+
do raw matmul inside the loop.
48+
"""
49+
50+
def __init__(self, d_model: int, K: int = 32, mode: str = "cross"):
51+
super().__init__()
52+
self.d_model = d_model
53+
kw = dict(K=K, mode=mode, bias=False)
54+
self.A = FibGenLinear(d_model, d_model, **kw)
55+
self.B = FibGenLinear(d_model, d_model, **kw)
56+
self.C = FibGenLinear(d_model, d_model, **kw)
57+
58+
def forward(self, x: torch.Tensor) -> torch.Tensor:
59+
B, T, D = x.shape
60+
# Pre-generate dense weight tensors ONCE per forward (cheap relative
61+
# to T sequential applications). All matmuls inside the loop are
62+
# then plain Tensor @ Tensor.
63+
W_A = self.A._compute_W() # [D, D]
64+
W_B = self.B._compute_W()
65+
# C·x can be computed in parallel for all timesteps (no recurrence).
66+
cx = self.C(x) # [B, T, D]
67+
# Sequential recurrence.
68+
h_prev1 = torch.zeros(B, D, device=x.device, dtype=x.dtype)
69+
h_prev2 = torch.zeros(B, D, device=x.device, dtype=x.dtype)
70+
outputs = []
71+
for t in range(T):
72+
h_t = h_prev1 @ W_A.t() + h_prev2 @ W_B.t() + cx[:, t]
73+
outputs.append(h_t)
74+
h_prev2 = h_prev1
75+
h_prev1 = h_t
76+
return torch.stack(outputs, dim=1) # [B, T, D]
77+
78+
79+
class FSMBlock(nn.Module):
80+
"""FibStateRecurrence + FibGen FFN, with pre-norm residuals."""
81+
82+
def __init__(self, d_model: int, K: int = 32, mode: str = "cross"):
83+
super().__init__()
84+
self.recurrence = FibStateRecurrence(d_model, K=K, mode=mode)
85+
self.w1 = FibGenLinear(d_model, 4 * d_model, K=K, mode=mode)
86+
self.w2 = FibGenLinear(4 * d_model, d_model, K=K, mode=mode)
87+
self.ln1 = nn.LayerNorm(d_model)
88+
self.ln2 = nn.LayerNorm(d_model)
89+
90+
def forward(self, x):
91+
x = x + self.recurrence(self.ln1(x))
92+
x = x + self.w2(F.gelu(self.w1(self.ln2(x))))
93+
return x
94+
95+
96+
class FSMLM(nn.Module):
97+
"""Char-level LM with substrate-canonical Fibonacci-recurrence layers.
98+
99+
Components:
100+
- Standard learned embedding (could be FibGen at scale)
101+
- CRT-Fibonacci positional encoding
102+
- Stack of FSM blocks (recurrence + FibGen FFN)
103+
- LM head tied to embedding
104+
"""
105+
106+
def __init__(self, vocab_size: int, d_model: int, n_blocks: int,
107+
seq_len: int, K: int = 32, mode: str = "cross"):
108+
super().__init__()
109+
self.seq_len = seq_len
110+
self.K = K
111+
self.embed = nn.Embedding(vocab_size, d_model)
112+
pe = self._crt_pe(seq_len, d_model)
113+
self.register_buffer("pe", pe)
114+
self.blocks = nn.ModuleList([
115+
FSMBlock(d_model, K=K, mode=mode) for _ in range(n_blocks)
116+
])
117+
self.ln_f = nn.LayerNorm(d_model)
118+
self.head = nn.Linear(d_model, vocab_size, bias=False)
119+
self.head.weight = self.embed.weight
120+
121+
@staticmethod
122+
def _crt_pe(seq_len: int, d_model: int) -> torch.Tensor:
123+
pe = torch.zeros(seq_len, d_model)
124+
pos = torch.arange(0, seq_len, dtype=torch.float)
125+
moduli = [5, 8, 13, 21, 34, 55, 89, 144]
126+
n_pairs = d_model // 2
127+
for i in range(n_pairs):
128+
m = moduli[i % len(moduli)]
129+
angle = 2 * math.pi * (pos % m) / m
130+
pe[:, 2 * i] = torch.sin(angle)
131+
pe[:, 2 * i + 1] = torch.cos(angle)
132+
return pe
133+
134+
def forward(self, token_ids):
135+
B, T = token_ids.shape
136+
h = self.embed(token_ids) + self.pe[:T]
137+
for block in self.blocks:
138+
h = block(h)
139+
h = self.ln_f(h)
140+
return self.head(h)
141+
142+
def storage_summary(self):
143+
stored = 0
144+
dense_eq = 0
145+
for m in self.modules():
146+
if isinstance(m, FibGenLinear):
147+
stored += m.n_stored_params
148+
dense_eq += m.n_dense_equivalent_params
149+
for n, p in self.named_parameters():
150+
if not any(s in n for s in (".A.", ".B.", ".C.", ".w1.", ".w2.")):
151+
stored += p.numel()
152+
dense_eq += p.numel()
153+
return {"stored": stored, "dense_equivalent": dense_eq,
154+
"compression": dense_eq / max(stored, 1)}

0 commit comments

Comments
 (0)