Skip to content

Commit 966f4a6

Browse files
committed
transformerless_lm: lazy training applied to FibGen seed components
Two substrate-aligned implementations of "lazy loading at training": (1) LAZY_DROPOUT (lazy_tier_dropout=True): Bernoulli mask on each seed component with keep_prob = 1/sqrt(tier). Tier-1 components (smallest Fibonacci index, k_i=k_j=0) are always active; tier-k components active 1/sqrt(k) of steps. Eval rescales by keep_prob so train E[output] matches eval output. (2) TIER_LR_SCALE (apply_tier_lr_scale post-backward): Keep all components in the forward, but scale each component's gradient by 1/sqrt(tier) before optimizer.step(). Low-tier components learn fast; high-tier learn slowly. Deterministic, no train/eval mismatch. train_lazy_subsim.py runs a 3-arm bench at d=128: subsim_baseline - vanilla Subsim (the +5.7% gap to dense) subsim_lazy_dropout - variant (1) subsim_tier_lr - variant (2) Both lazy variants are designed to make the trained model PRUNABLE post-hoc: prune high-tier components first, since they carry less learned signal. This delivers the "35B in 8GB" framing through deployment-time component pruning. 100-step smoke: lazy_dropout converges but ~2x slower per val unit (harder optimization with stochastic masking). tier_lr_scale untested in the smoke; expected to be smoother because it's deterministic. Bench will reveal which variant best balances training stability with post-training pruneability.
1 parent bfbf691 commit 966f4a6

3 files changed

Lines changed: 246 additions & 14 deletions

File tree

experiments/transformerless_lm/models_fibgen.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,43 @@ class FibGenLinear(nn.Module):
8787

8888
def __init__(self, in_features: int, out_features: int, K: int = 16,
8989
mode: str = "separable",
90-
bias: bool = True, init_scale: float = 0.1):
90+
bias: bool = True, init_scale: float = 0.1,
91+
lazy_tier_dropout: bool = False):
9192
super().__init__()
9293
self.in_features = in_features
9394
self.out_features = out_features
9495
self.K = min(K, len(FIBONACCI))
9596
if mode not in ("separable", "cross"):
9697
raise ValueError(f"unknown mode: {mode}")
9798
self.mode = mode
99+
self.lazy_tier_dropout = lazy_tier_dropout
98100
n_components = self.K if mode == "separable" else self.K * self.K
99101
self.seed = nn.Parameter(
100102
torch.randn(n_components, 4) * (init_scale / max(1, math.sqrt(n_components)))
101103
)
104+
105+
# Fibonacci tier per seed component, used for lazy-tier dropout.
106+
# Lower tier = more important = active more often.
107+
if mode == "separable":
108+
# Component k → tier (k+1). F(tier) = Fibonacci number.
109+
tiers_int = [i + 1 for i in range(self.K)]
110+
else:
111+
# Cross-mode pair (k_i, k_j) → tier max(k_i, k_j) + 1.
112+
# Pair (0, 0) is tier 1 (most important, always active).
113+
# Pair (31, 31) is tier 32 (rarely active under 1/F(32) probability).
114+
tiers_int = [max(k_i, k_j) + 1
115+
for k_i in range(self.K) for k_j in range(self.K)]
116+
# Two substrate-aligned schemes available on this buffer:
117+
# (1) lazy_tier_dropout=True -> mask seed via Bernoulli(tier_keep_probs)
118+
# (2) gradient-scale via tier_lr_scale (applied by training loop)
119+
keep_probs = torch.tensor(
120+
[1.0 / math.sqrt(t) for t in tiers_int], dtype=torch.float,
121+
)
122+
self.register_buffer("tier_keep_probs", keep_probs)
123+
# tier-weighted learning rate: low-tier components get full LR, high-tier
124+
# get reduced LR proportional to 1/sqrt(tier). Apply by multiplying
125+
# seed.grad by this buffer BEFORE optimizer.step().
126+
self.register_buffer("tier_lr_scale", keep_probs.unsqueeze(-1))
102127
if bias:
103128
self.bias = nn.Parameter(torch.zeros(out_features))
104129
else:
@@ -151,6 +176,29 @@ def generate_W(self) -> torch.Tensor:
151176
return cached
152177
return self._compute_W()
153178

179+
def _maybe_lazy_seed(self) -> torch.Tensor:
180+
"""Returns the seed (optionally masked by Fibonacci-tier dropout).
181+
182+
Substrate-native lazy LOADING applied to the seed itself:
183+
- Tier 1 components are always active (full participation)
184+
- Tier-k components active with probability 1/sqrt(k)
185+
- Only active components contribute to this step's forward;
186+
only they receive gradient on backward.
187+
188+
Magnitude matching: at training the mask is Bernoulli; at eval
189+
we scale the seed by the per-component keep_prob so the
190+
EXPECTED forward output during training matches the deterministic
191+
forward at eval. This avoids the magnitude crash that pure-mask
192+
without scaling caused.
193+
"""
194+
if not self.lazy_tier_dropout:
195+
return self.seed
196+
if self.training:
197+
mask = torch.bernoulli(self.tier_keep_probs) # [n_components]
198+
return self.seed * mask.unsqueeze(-1)
199+
# eval: deterministic, scaled by keep_prob to match training E[seed]
200+
return self.seed * self.tier_keep_probs.unsqueeze(-1)
201+
154202
def _forward_compressed(self, x: torch.Tensor) -> torch.Tensor:
155203
"""Substrate-native forward: compute y = W·x WITHOUT materializing W.
156204
@@ -166,8 +214,9 @@ def _forward_compressed(self, x: torch.Tensor) -> torch.Tensor:
166214
K-dim projected x, then projected back.
167215
"""
168216
# x: [B, T, in_features]
217+
seed = self._maybe_lazy_seed()
169218
if self.mode == "separable":
170-
a, b, c, d = self.seed[:, 0], self.seed[:, 1], self.seed[:, 2], self.seed[:, 3]
219+
a, b, c, d = seed[:, 0], seed[:, 1], seed[:, 2], seed[:, 3]
171220
# Project x into Fibonacci-basis along input axis: [B, T, K]
172221
x_cos = x @ self.cos_j # [B, T, K]
173222
x_sin = x @ self.sin_j # [B, T, K]
@@ -185,8 +234,8 @@ def _forward_compressed(self, x: torch.Tensor) -> torch.Tensor:
185234
return y
186235
# cross mode: seed [K, K, 4] mixing matrix
187236
K = self.K
188-
seed = self.seed.view(K, K, 4)
189-
a, b, c, d = seed[..., 0], seed[..., 1], seed[..., 2], seed[..., 3]
237+
seed_cross = seed.view(K, K, 4)
238+
a, b, c, d = seed_cross[..., 0], seed_cross[..., 1], seed_cross[..., 2], seed_cross[..., 3]
190239
x_cos = x @ self.cos_j # [B, T, K]
191240
x_sin = x @ self.sin_j
192241
# K×K mixing in seed space:

experiments/transformerless_lm/models_subsim.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,16 @@ class SubstrateSimilarityAttention(nn.Module):
4949
"""
5050

5151
def __init__(self, d_model: int, K: int = 32, seq_len: int = 128,
52-
fibgen_K: int = 32, mode: str = "cross"):
52+
fibgen_K: int = 32, mode: str = "cross",
53+
lazy_tier_dropout: bool = False):
5354
super().__init__()
5455
self.d_model = d_model
5556
self.K = K
56-
self.W_sig = FibGenLinear(d_model, K, K=fibgen_K, mode=mode, bias=False)
57-
self.W_v = FibGenLinear(d_model, d_model, K=fibgen_K, mode=mode, bias=False)
58-
self.W_out = FibGenLinear(d_model, d_model, K=fibgen_K, mode=mode, bias=False)
57+
kw = dict(K=fibgen_K, mode=mode, bias=False,
58+
lazy_tier_dropout=lazy_tier_dropout)
59+
self.W_sig = FibGenLinear(d_model, K, **kw)
60+
self.W_v = FibGenLinear(d_model, d_model, **kw)
61+
self.W_out = FibGenLinear(d_model, d_model, **kw)
5962
# Standard causal mask; substrate-distance attention is dense in
6063
# principle. Could also use Fibonacci-offset mask for sparsity.
6164
mask = torch.tril(torch.ones(seq_len, seq_len))
@@ -81,13 +84,16 @@ class SubsimBlock(nn.Module):
8184
"""Substrate-similarity attention + FibGen FFN."""
8285

8386
def __init__(self, d_model: int, seq_len: int, K: int = 32,
84-
fibgen_K: int = 32, mode: str = "cross"):
87+
fibgen_K: int = 32, mode: str = "cross",
88+
lazy_tier_dropout: bool = False):
8589
super().__init__()
8690
self.attn = SubstrateSimilarityAttention(d_model, K=K, seq_len=seq_len,
87-
fibgen_K=fibgen_K, mode=mode)
91+
fibgen_K=fibgen_K, mode=mode,
92+
lazy_tier_dropout=lazy_tier_dropout)
8893
# FFN with FibGen weights (separate K for FFN if desired)
89-
self.w1 = FibGenLinear(d_model, 4 * d_model, K=fibgen_K, mode=mode)
90-
self.w2 = FibGenLinear(4 * d_model, d_model, K=fibgen_K, mode=mode)
94+
kw = dict(K=fibgen_K, mode=mode, lazy_tier_dropout=lazy_tier_dropout)
95+
self.w1 = FibGenLinear(d_model, 4 * d_model, **kw)
96+
self.w2 = FibGenLinear(4 * d_model, d_model, **kw)
9197
self.ln1 = nn.LayerNorm(d_model)
9298
self.ln2 = nn.LayerNorm(d_model)
9399

@@ -108,15 +114,16 @@ class SubsimLM(nn.Module):
108114

109115
def __init__(self, vocab_size: int, d_model: int, n_blocks: int,
110116
seq_len: int, K: int = 32, fibgen_K: int = 32,
111-
mode: str = "cross"):
117+
mode: str = "cross", lazy_tier_dropout: bool = False):
112118
super().__init__()
113119
self.seq_len = seq_len
114120
self.K = K
115121
self.embed = nn.Embedding(vocab_size, d_model)
116122
pe = self._crt_pe(seq_len, d_model)
117123
self.register_buffer("pe", pe)
118124
self.blocks = nn.ModuleList([
119-
SubsimBlock(d_model, seq_len, K=K, fibgen_K=fibgen_K, mode=mode)
125+
SubsimBlock(d_model, seq_len, K=K, fibgen_K=fibgen_K, mode=mode,
126+
lazy_tier_dropout=lazy_tier_dropout)
120127
for _ in range(n_blocks)
121128
])
122129
self.ln_f = nn.LayerNorm(d_model)
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Lazy training applied to FibGen seed components.
2+
3+
Two substrate-aligned variants tested:
4+
5+
(1) LAZY_DROPOUT: Bernoulli mask on each FibGen seed component.
6+
keep_prob = 1/sqrt(tier) so low-tier (small Fibonacci index)
7+
components active near-always, high-tier components active
8+
stochastically. Eval rescales by keep_prob to match expected
9+
training magnitudes. This is "lazy loading at the seed level":
10+
each step uses only a substrate-defined subset of components.
11+
12+
(2) TIER_LR_SCALE: keep all components active in the forward, but
13+
scale each component's GRADIENT by 1/sqrt(tier) before
14+
optimizer.step(). Low-tier components learn fast (full LR),
15+
high-tier learn slowly. Over training, low-tier components
16+
accumulate more signal. Deterministic, no train/eval mismatch.
17+
18+
Both share the substrate intent ("fold to respected tier") but
19+
differ in implementation. We also include the pure-baseline Subsim
20+
for direct comparison.
21+
22+
The deployment payoff (orthogonal to which training scheme wins):
23+
post-training, prune high-tier components and measure perplexity
24+
loss. The lazy-trained model should prune more gracefully because
25+
high-tier components were either inactive (variant 1) or had small
26+
learned magnitudes (variant 2).
27+
"""
28+
29+
import argparse
30+
import json
31+
import sys
32+
import time
33+
from pathlib import Path
34+
35+
import torch
36+
import torch.nn.functional as F
37+
38+
sys.path.insert(0, str(Path(__file__).parent))
39+
from corpus import make_dataset
40+
from models import make_model
41+
from models_subsim import SubsimLM
42+
from models_fibgen import FibGenLinear
43+
from train_distractor_mix import build_distractor_stream
44+
from lazy_data import fib_positions_in_window, get_fib_strided_batch
45+
46+
47+
def evaluate(model, val_split, batch_size, window, fib_positions, generator,
48+
n_batches=16):
49+
model.eval()
50+
losses = []
51+
with torch.no_grad():
52+
for _ in range(n_batches):
53+
x, y = get_fib_strided_batch(val_split, batch_size, window,
54+
fib_positions, generator)
55+
logits = model(x)
56+
losses.append(F.cross_entropy(
57+
logits.reshape(-1, logits.size(-1)), y.reshape(-1)).item())
58+
model.train()
59+
return sum(losses) / len(losses)
60+
61+
62+
def apply_tier_lr_scale(model: torch.nn.Module):
63+
"""For each FibGenLinear, multiply seed.grad by tier_lr_scale.
64+
Tier-1 components get full grad; tier-k get grad * 1/sqrt(k)."""
65+
for m in model.modules():
66+
if isinstance(m, FibGenLinear) and m.seed.grad is not None:
67+
m.seed.grad.mul_(m.tier_lr_scale)
68+
69+
70+
def train_one(name, model, train_split, val_split, args, fib_positions,
71+
apply_lr_scale: bool = False):
72+
torch.manual_seed(args.seed)
73+
gen = torch.Generator(); gen.manual_seed(args.seed + 1)
74+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
75+
n_params = sum(p.numel() for p in model.parameters())
76+
print(f"\n[train {name}] params={n_params:,} "
77+
f"apply_lr_scale={apply_lr_scale}", flush=True)
78+
t0 = time.time()
79+
best_val = float("inf")
80+
best_step = -1
81+
eval_every = 200
82+
val_hist = []
83+
for step in range(args.steps):
84+
x, y = get_fib_strided_batch(train_split, args.batch_size, args.seq_len,
85+
fib_positions, gen)
86+
logits = model(x)
87+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
88+
optimizer.zero_grad(); loss.backward()
89+
if apply_lr_scale:
90+
apply_tier_lr_scale(model)
91+
optimizer.step()
92+
if step % eval_every == 0 or step == args.steps - 1:
93+
vl = evaluate(model, val_split, args.batch_size, args.seq_len,
94+
fib_positions, gen)
95+
val_hist.append((step, vl, time.time() - t0))
96+
marker = ""
97+
if vl < best_val:
98+
best_val = vl
99+
best_step = step
100+
marker = " ← BEST"
101+
print(f" step {step:5d} val={vl:.4f} ({time.time()-t0:.1f}s){marker}",
102+
flush=True)
103+
return {"name": name, "n_params": n_params, "best_val": best_val,
104+
"best_step": best_step, "wall_time": time.time() - t0,
105+
"val_history": val_hist}
106+
107+
108+
def main():
109+
parser = argparse.ArgumentParser()
110+
parser.add_argument("--steps", type=int, default=2500)
111+
parser.add_argument("--batch-size", type=int, default=32)
112+
parser.add_argument("--seq-len", type=int, default=128)
113+
parser.add_argument("--d-model", type=int, default=128)
114+
parser.add_argument("--n-blocks", type=int, default=4)
115+
parser.add_argument("--lr", type=float, default=3e-4)
116+
parser.add_argument("--seed", type=int, default=42)
117+
parser.add_argument("--distractor-frac", type=float, default=0.20)
118+
parser.add_argument("--out", type=str, default="results_lazy_subsim.json")
119+
args = parser.parse_args()
120+
121+
chars, stoi, itos, encoded = make_dataset(seq_len=args.seq_len,
122+
source="tinyshakespeare")
123+
vocab_size = len(chars)
124+
train_split, val_split = build_distractor_stream(
125+
encoded, args.distractor_frac, args.seq_len, args.seed,
126+
)
127+
fib_positions = fib_positions_in_window(args.seq_len)
128+
129+
results = {}
130+
131+
# 1. Baseline Subsim (no lazy)
132+
m = SubsimLM(vocab_size=vocab_size, d_model=args.d_model,
133+
n_blocks=args.n_blocks, seq_len=args.seq_len,
134+
K=32, fibgen_K=32, mode="cross",
135+
lazy_tier_dropout=False)
136+
results["subsim_baseline"] = train_one(
137+
"subsim_baseline", m, train_split, val_split, args, fib_positions,
138+
)
139+
140+
# 2. Subsim + lazy seed dropout
141+
m = SubsimLM(vocab_size=vocab_size, d_model=args.d_model,
142+
n_blocks=args.n_blocks, seq_len=args.seq_len,
143+
K=32, fibgen_K=32, mode="cross",
144+
lazy_tier_dropout=True)
145+
results["subsim_lazy_dropout"] = train_one(
146+
"subsim_lazy_dropout", m, train_split, val_split, args, fib_positions,
147+
)
148+
149+
# 3. Subsim + tier-weighted gradient scaling
150+
m = SubsimLM(vocab_size=vocab_size, d_model=args.d_model,
151+
n_blocks=args.n_blocks, seq_len=args.seq_len,
152+
K=32, fibgen_K=32, mode="cross",
153+
lazy_tier_dropout=False)
154+
results["subsim_tier_lr"] = train_one(
155+
"subsim_tier_lr", m, train_split, val_split, args, fib_positions,
156+
apply_lr_scale=True,
157+
)
158+
159+
# Summary
160+
print()
161+
print("=" * 84)
162+
print(f"{'config':<24} {'params':>10} {'best_val':>10} {'best_step':>10} "
163+
f"{'wall':>10}")
164+
print("-" * 84)
165+
for name, r in results.items():
166+
print(f"{name:<24} {r['n_params']:>10,} {r['best_val']:>10.4f} "
167+
f"{r['best_step']:>10} {r['wall_time']:>9.1f}s")
168+
169+
out_path = Path(__file__).parent / args.out
170+
with open(out_path, "w") as f:
171+
json.dump(results, f, indent=2, default=str)
172+
print(f"\nWrote {out_path}")
173+
174+
175+
if __name__ == "__main__":
176+
main()

0 commit comments

Comments
 (0)