Skip to content

Commit b354c11

Browse files
committed
transformerless_lm: FibGen v2 — extended Fibonacci table + cross-frequency mode
Two follow-ups from the v1 result (100x compression at +19% loss floor): 1. FIBONACCI table extended from 16 to 32 entries. v1 silently clamped K to 16, so the "K=32" arm was actually K=16. Now K can scale. 2. New "cross" generator mode. v1 used separable components where each uses the SAME Fibonacci frequency on both axes, giving rank-K matrices. Cross mode uses INDEPENDENT (F_k_i, F_k_j) frequency pairs, giving K^2 outer-product components for 4*K^2 params per layer. Compression at d=128, n_blocks=4 (vs dense 800K params): K=16 separable: 8K params (100x), 16 separable components per layer K=16 cross: 25K params ( 32x), 256 cross-frequency components K=32 separable: 9K params ( 88x), 32 separable components K=32 cross: 81K params ( 10x), 1024 cross-frequency components Bench sweeps the 6 cells of (K, mode) against dense_crt to find the expressivity/compression Pareto frontier. If cross mode breaks through the +19% loss wall at acceptable compression, the generator-from-seed thesis has a real competitive case for inference.
1 parent 288f201 commit b354c11

2 files changed

Lines changed: 102 additions & 74 deletions

File tree

experiments/transformerless_lm/models_fibgen.py

Lines changed: 68 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -42,44 +42,66 @@
4242
import torch.nn.functional as F
4343

4444

45-
FIBONACCI = [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597]
45+
# Extended unique-positive Fibonacci table — 32 entries.
46+
# Previous 16-entry version caused K>16 to silently clamp.
47+
FIBONACCI = [
48+
1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987,
49+
1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025,
50+
121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578,
51+
]
4652

4753

4854
class FibGenLinear(nn.Module):
4955
"""Drop-in replacement for nn.Linear where W is generated from a seed.
5056
57+
Two generator modes:
58+
59+
"separable" (the original): each component uses the SAME Fibonacci
60+
frequency on both axes. Generates rank-K terms.
61+
W[i,j] = Σ_k [a_k cos(F_k·i) cos(F_k·j) + ...]
62+
Seed: 4·K params.
63+
64+
"cross" (new): each component uses INDEPENDENT Fibonacci frequencies
65+
on the two axes. Generates a full K_i × K_j grid of frequency
66+
pairs, so the matrix is a sum of K_i·K_j outer products of
67+
single-frequency 1-D bases.
68+
W[i,j] = Σ_{k_i, k_j} [a_{kk'} cos(F_{k_i}·i) cos(F_{k_j}·j) + ...]
69+
Seed: 4·K² params. Equal expressivity as separable at K_separable = K²,
70+
but with the substrate-canonical Fibonacci-coprime structure that
71+
makes the basis non-degenerate (Fibonacci frequencies are pairwise
72+
substrate-distinguishable).
73+
5174
Args:
5275
in_features: input dim.
5376
out_features: output dim.
54-
K: number of Fibonacci-frequency components in the generator.
55-
Higher K = more capacity, more params. K=16 → 64 params
56-
(vs in·out for a stored matrix).
77+
K: number of Fibonacci frequencies per axis.
78+
mode: "separable" or "cross".
5779
bias: whether to include a learnable bias vector.
58-
init_scale: scales the seed initialization. The generated W has
59-
magnitude ~ init_scale · sqrt(4K), so smaller init_scale
60-
gives smaller initial weights.
80+
init_scale: scales the seed initialization.
6181
"""
6282

6383
def __init__(self, in_features: int, out_features: int, K: int = 16,
84+
mode: str = "separable",
6485
bias: bool = True, init_scale: float = 0.1):
6586
super().__init__()
6687
self.in_features = in_features
6788
self.out_features = out_features
6889
self.K = min(K, len(FIBONACCI))
69-
# Seed: 4 coefficients per Fibonacci component (cc, sc, cs, ss).
90+
if mode not in ("separable", "cross"):
91+
raise ValueError(f"unknown mode: {mode}")
92+
self.mode = mode
93+
n_components = self.K if mode == "separable" else self.K * self.K
7094
self.seed = nn.Parameter(
71-
torch.randn(self.K, 4) * (init_scale / max(1, math.sqrt(self.K)))
95+
torch.randn(n_components, 4) * (init_scale / max(1, math.sqrt(n_components)))
7296
)
7397
if bias:
7498
self.bias = nn.Parameter(torch.zeros(out_features))
7599
else:
76100
self.register_parameter("bias", None)
77-
# Precompute the cos/sin of position·Fibonacci-frequency for both
78-
# axes. These are FIXED — no gradient flows through positions.
101+
# Precompute cos/sin position·Fibonacci-frequency tables.
79102
i_idx = torch.arange(out_features).float()
80103
j_idx = torch.arange(in_features).float()
81104
freqs = torch.tensor(FIBONACCI[:self.K], dtype=torch.float)
82-
# angles: [out, K], [in, K]
83105
a_i = 2 * math.pi * i_idx.unsqueeze(1) * freqs.unsqueeze(0) / max(out_features, 1)
84106
a_j = 2 * math.pi * j_idx.unsqueeze(1) * freqs.unsqueeze(0) / max(in_features, 1)
85107
self.register_buffer("cos_i", torch.cos(a_i)) # [out, K]
@@ -88,21 +110,24 @@ def __init__(self, in_features: int, out_features: int, K: int = 16,
88110
self.register_buffer("sin_j", torch.sin(a_j))
89111

90112
def generate_W(self) -> torch.Tensor:
91-
# seed: [K, 4] → split into 4 [K] tensors.
92-
a, b, c, d = self.seed[:, 0], self.seed[:, 1], self.seed[:, 2], self.seed[:, 3]
93-
# W = sum_k (
94-
# a_k · cos_i[:, k] · cos_j[:, k]^T +
95-
# b_k · sin_i[:, k] · cos_j[:, k]^T +
96-
# c_k · cos_i[:, k] · sin_j[:, k]^T +
97-
# d_k · sin_i[:, k] · sin_j[:, k]^T
98-
# )
99-
# Each term is an [out, in] outer product.
100-
# Compose via einsum: [out, K] · [K] · [K, in] (with the diagonal)
101-
# → [out, in].
102-
W = torch.einsum("ok,k,jk->oj", self.cos_i, a, self.cos_j)
103-
W = W + torch.einsum("ok,k,jk->oj", self.sin_i, b, self.cos_j)
104-
W = W + torch.einsum("ok,k,jk->oj", self.cos_i, c, self.sin_j)
105-
W = W + torch.einsum("ok,k,jk->oj", self.sin_i, d, self.sin_j)
113+
if self.mode == "separable":
114+
a, b, c, d = self.seed[:, 0], self.seed[:, 1], self.seed[:, 2], self.seed[:, 3]
115+
W = torch.einsum("ok,k,jk->oj", self.cos_i, a, self.cos_j)
116+
W = W + torch.einsum("ok,k,jk->oj", self.sin_i, b, self.cos_j)
117+
W = W + torch.einsum("ok,k,jk->oj", self.cos_i, c, self.sin_j)
118+
W = W + torch.einsum("ok,k,jk->oj", self.sin_i, d, self.sin_j)
119+
return W
120+
# mode == "cross": seed shape [K*K, 4], reshape to [K, K, 4]
121+
K = self.K
122+
seed = self.seed.view(K, K, 4)
123+
a, b, c, d = seed[..., 0], seed[..., 1], seed[..., 2], seed[..., 3]
124+
# W[i,j] = Σ_{k_i, k_j} [a · cos_i[i, k_i] cos_j[j, k_j] + ...]
125+
# einsum: cos_i [out, k_i] @ a [k_i, k_j] -> [out, k_j], then
126+
# · cos_j [in, k_j] -> [out, in].
127+
W = torch.einsum("ol,lm,jm->oj", self.cos_i, a, self.cos_j)
128+
W = W + torch.einsum("ol,lm,jm->oj", self.sin_i, b, self.cos_j)
129+
W = W + torch.einsum("ol,lm,jm->oj", self.cos_i, c, self.sin_j)
130+
W = W + torch.einsum("ol,lm,jm->oj", self.sin_i, d, self.sin_j)
106131
return W
107132

108133
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -127,11 +152,11 @@ def n_dense_equivalent_params(self) -> int:
127152
class FibGenAttention(nn.Module):
128153
"""Single-head self-attention with all linear layers FibGen-generated."""
129154

130-
def __init__(self, d_model: int, K: int = 16):
155+
def __init__(self, d_model: int, K: int = 16, mode: str = "separable"):
131156
super().__init__()
132157
self.d_model = d_model
133-
self.qkv = FibGenLinear(d_model, 3 * d_model, K=K)
134-
self.out = FibGenLinear(d_model, d_model, K=K)
158+
self.qkv = FibGenLinear(d_model, 3 * d_model, K=K, mode=mode)
159+
self.out = FibGenLinear(d_model, d_model, K=K, mode=mode)
135160

136161
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
137162
B, T, D = x.shape
@@ -148,21 +173,22 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
148173
class FibGenFeedForward(nn.Module):
149174
"""FFN with FibGen-generated linear layers."""
150175

151-
def __init__(self, d_model: int, expansion: int = 4, K: int = 16):
176+
def __init__(self, d_model: int, expansion: int = 4, K: int = 16,
177+
mode: str = "separable"):
152178
super().__init__()
153179
d_inner = d_model * expansion
154-
self.w1 = FibGenLinear(d_model, d_inner, K=K)
155-
self.w2 = FibGenLinear(d_inner, d_model, K=K)
180+
self.w1 = FibGenLinear(d_model, d_inner, K=K, mode=mode)
181+
self.w2 = FibGenLinear(d_inner, d_model, K=K, mode=mode)
156182

157183
def forward(self, x: torch.Tensor) -> torch.Tensor:
158184
return self.w2(F.gelu(self.w1(x)))
159185

160186

161187
class FibGenBlock(nn.Module):
162-
def __init__(self, d_model: int, K: int = 16):
188+
def __init__(self, d_model: int, K: int = 16, mode: str = "separable"):
163189
super().__init__()
164-
self.attn = FibGenAttention(d_model, K=K)
165-
self.ff = FibGenFeedForward(d_model, K=K)
190+
self.attn = FibGenAttention(d_model, K=K, mode=mode)
191+
self.ff = FibGenFeedForward(d_model, K=K, mode=mode)
166192
self.ln1 = nn.LayerNorm(d_model)
167193
self.ln2 = nn.LayerNorm(d_model)
168194

@@ -184,25 +210,20 @@ class FibGenLM(nn.Module):
184210
"""
185211

186212
def __init__(self, vocab_size: int, d_model: int, n_blocks: int,
187-
seq_len: int, K: int = 16):
213+
seq_len: int, K: int = 16, mode: str = "separable"):
188214
super().__init__()
189215
self.seq_len = seq_len
190216
self.K = K
191-
# Embedding implemented as FibGen + index → FibGen produces a
192-
# [vocab, d_model] table that we index into.
193-
self.embed_gen = FibGenLinear(vocab_size, d_model, K=K, bias=False)
194-
# Positional encoding stays CRT-Fibonacci (already substrate-aligned,
195-
# and it's a buffer, not a learned weight).
217+
self.mode = mode
218+
self.embed_gen = FibGenLinear(vocab_size, d_model, K=K, mode=mode,
219+
bias=False)
196220
pe = self._crt_pe(seq_len, d_model)
197221
self.register_buffer("pe", pe)
198222
self.blocks = nn.ModuleList([
199-
FibGenBlock(d_model, K=K) for _ in range(n_blocks)
223+
FibGenBlock(d_model, K=K, mode=mode) for _ in range(n_blocks)
200224
])
201225
self.ln_f = nn.LayerNorm(d_model)
202-
# Head: FibGen too (or tied with embed — but tied with a generator
203-
# means head and embed share the SAME generator seed which forces
204-
# a constraint. Pick untied for now to test capacity.)
205-
self.head = FibGenLinear(d_model, vocab_size, K=K, bias=False)
226+
self.head = FibGenLinear(d_model, vocab_size, K=K, mode=mode, bias=False)
206227
mask = torch.tril(torch.ones(seq_len, seq_len))
207228
self.register_buffer("mask", mask)
208229

experiments/transformerless_lm/train_fibgen.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def main():
106106
parser.add_argument("--distractor-frac", type=float, default=0.20)
107107
parser.add_argument("--K-sweep", type=str, default="8,16,32",
108108
help="Comma-separated K values for FibGen.")
109+
parser.add_argument("--modes", type=str, default="separable,cross",
110+
help="Comma-separated generator modes.")
109111
parser.add_argument("--out", type=str, default="results_fibgen.json")
110112
args = parser.parse_args()
111113

@@ -135,16 +137,20 @@ def make_crt():
135137
results["dense_crt"] = train_one("dense_crt", vocab_size, train_split,
136138
val_split, args, fib_positions, make_crt)
137139

138-
# 2. FibGen at each K
140+
# 2. FibGen at each K x mode
139141
K_values = [int(k) for k in args.K_sweep.split(",")]
140-
for K in K_values:
141-
def make_fibgen(K=K):
142-
return FibGenLM(vocab_size=vocab_size, d_model=args.d_model,
143-
n_blocks=args.n_blocks, seq_len=args.seq_len, K=K)
144-
results[f"fibgen_K{K}"] = train_one(
145-
f"fibgen_K{K}", vocab_size, train_split, val_split, args,
146-
fib_positions, make_fibgen,
147-
)
142+
modes = [m.strip() for m in args.modes.split(",")]
143+
for mode in modes:
144+
for K in K_values:
145+
def make_fibgen(K=K, mode=mode):
146+
return FibGenLM(vocab_size=vocab_size, d_model=args.d_model,
147+
n_blocks=args.n_blocks, seq_len=args.seq_len,
148+
K=K, mode=mode)
149+
name = f"fibgen_K{K}_{mode}"
150+
results[name] = train_one(
151+
name, vocab_size, train_split, val_split, args,
152+
fib_positions, make_fibgen,
153+
)
148154

149155
# Summary
150156
print()
@@ -166,24 +172,25 @@ def make_fibgen(K=K):
166172

167173
# Verdict
168174
base_val = results["dense_crt"]["final_val"]
169-
print(f"VERDICT (uniform-random floor: {uniform_floor:.4f}, dense_crt: {base_val:.4f}):")
170-
for K in K_values:
171-
r = results[f"fibgen_K{K}"]
172-
if r["final_val"] < uniform_floor * 0.85:
173-
tag = "LEARNED (≤85% of uniform floor)"
174-
elif r["final_val"] < uniform_floor * 0.95:
175-
tag = "WEAK LEARNING"
176-
else:
177-
tag = "FAILED (near uniform-random)"
178-
# Compute compression
179-
dense_eq = 0
180-
stored = 0
181-
m = FibGenLM(vocab_size=vocab_size, d_model=args.d_model,
182-
n_blocks=args.n_blocks, seq_len=args.seq_len, K=K)
183-
ss = m.storage_summary()
184-
compr = ss["compression"]
185-
print(f" K={K:>3}: val={r['final_val']:.4f} "
186-
f"compression={compr:.1f}x → {tag}")
175+
print(f"VERDICT (uniform-random floor: {uniform_floor:.4f}, "
176+
f"dense_crt: {base_val:.4f}):")
177+
for mode in modes:
178+
for K in K_values:
179+
r = results[f"fibgen_K{K}_{mode}"]
180+
if r["final_val"] < uniform_floor * 0.85:
181+
tag = "LEARNED"
182+
elif r["final_val"] < uniform_floor * 0.95:
183+
tag = "WEAK LEARNING"
184+
else:
185+
tag = "FAILED"
186+
m = FibGenLM(vocab_size=vocab_size, d_model=args.d_model,
187+
n_blocks=args.n_blocks, seq_len=args.seq_len,
188+
K=K, mode=mode)
189+
ss = m.storage_summary()
190+
gap_pct = (r["final_val"] - base_val) / base_val * 100
191+
print(f" K={K:>3} mode={mode:<10}: val={r['final_val']:.4f} "
192+
f"compr={ss['compression']:5.1f}x vs_dense={gap_pct:+5.1f}% "
193+
f"→ {tag}")
187194

188195
out_path = Path(__file__).parent / args.out
189196
with open(out_path, "w") as f:

0 commit comments

Comments
 (0)