Skip to content

Commit a7cda55

Browse files
committed
transformerless_lm: text sampling bench — does FibGen produce usable output?
Loss numbers are an abstract proxy. The deployment-meaningful question is whether the +5-7% val loss penalty on FibGen and composed transformerless translates to barely-perceptible quality difference or to broken output text. sample_text.py: - trains dense_crt, fibgen_K32_cross, and composed_transformerless on TinyShakespeare with lazy-loading - generates 400 characters per arch from a fixed Shakespeare prompt - temperature sampling with top-k for readable output - prints all three side-by-side so quality is eyeball-comparable If the FibGen samples are coherent and stylistically Shakespeare-ish, the 90% throughput / 37x less memory result becomes a deployable substrate-native LLM. If they read as gibberish or noticeably worse, we know where the compression ceiling lies for end-user quality.
1 parent c6d352e commit a7cda55

1 file changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Sample text generation from trained models.
2+
3+
Loss numbers are abstract. Actual generated text is the deployment-meaningful
4+
quality signal: does a +5-7% val-loss penalty translate to barely-perceptible
5+
output or to broken text?
6+
7+
Trains dense_crt vs fibgen_K32_cross vs composed_transformerless on
8+
TinyShakespeare with lazy-loading, then generates a sample from a fixed
9+
prompt for each. Greedy decoding by default; temperature sampling
10+
optional. Output is human-readable so you can eyeball it.
11+
12+
If the FibGen output is coherent and stylistically Shakespeare-ish, the
13+
inference-economics result (90% throughput, 37x less memory) translates
14+
into a deployable model.
15+
"""
16+
17+
import argparse
18+
import sys
19+
import time
20+
from pathlib import Path
21+
22+
import torch
23+
import torch.nn.functional as F
24+
25+
sys.path.insert(0, str(Path(__file__).parent))
26+
from corpus import make_dataset
27+
from models import make_model
28+
from models_fibgen import FibGenLM, FibGenTransformerless
29+
from train_distractor_mix import build_distractor_stream
30+
from lazy_data import fib_positions_in_window, get_fib_strided_batch
31+
32+
33+
def train(name, model, train_split, val_split, args, fib_positions):
34+
torch.manual_seed(args.seed)
35+
gen = torch.Generator(); gen.manual_seed(args.seed + 1)
36+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
37+
t0 = time.time()
38+
eval_every = 300
39+
print(f"\n[train {name}] params={sum(p.numel() for p in model.parameters()):,}",
40+
flush=True)
41+
for step in range(args.steps):
42+
x, y = get_fib_strided_batch(train_split, args.batch_size, args.seq_len,
43+
fib_positions, gen)
44+
logits = model(x)
45+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
46+
optimizer.zero_grad(); loss.backward(); optimizer.step()
47+
if step % eval_every == 0 or step == args.steps - 1:
48+
print(f" step {step:5d} train={loss.item():.4f} "
49+
f"({time.time()-t0:.1f}s)", flush=True)
50+
return model
51+
52+
53+
@torch.no_grad()
54+
def generate_text(model, prompt_ids, n_new, seq_len, itos,
55+
temperature: float = 1.0, top_k: int = None):
56+
model.eval()
57+
out = prompt_ids.clone()
58+
for _ in range(n_new):
59+
ctx = out[:, -seq_len:]
60+
logits = model(ctx)[:, -1, :] / max(temperature, 1e-6)
61+
if top_k is not None:
62+
v, _ = logits.topk(top_k)
63+
logits[logits < v[..., -1:]] = float("-inf")
64+
if temperature <= 1e-3:
65+
next_id = logits.argmax(dim=-1, keepdim=True)
66+
else:
67+
probs = F.softmax(logits, dim=-1)
68+
next_id = torch.multinomial(probs, num_samples=1)
69+
out = torch.cat([out, next_id], dim=-1)
70+
return out
71+
72+
73+
def main():
74+
parser = argparse.ArgumentParser()
75+
parser.add_argument("--steps", type=int, default=1500)
76+
parser.add_argument("--batch-size", type=int, default=32)
77+
parser.add_argument("--seq-len", type=int, default=128)
78+
parser.add_argument("--d-model", type=int, default=128)
79+
parser.add_argument("--n-blocks", type=int, default=4)
80+
parser.add_argument("--lr", type=float, default=3e-4)
81+
parser.add_argument("--seed", type=int, default=42)
82+
parser.add_argument("--distractor-frac", type=float, default=0.20)
83+
parser.add_argument("--prompt", type=str,
84+
default="ROMEO:\nWhat light through")
85+
parser.add_argument("--n-new", type=int, default=400,
86+
help="Number of new characters to generate.")
87+
parser.add_argument("--temperature", type=float, default=0.8)
88+
parser.add_argument("--top-k", type=int, default=10)
89+
parser.add_argument("--out", type=str, default="results_samples.txt")
90+
args = parser.parse_args()
91+
92+
chars, stoi, itos, encoded = make_dataset(seq_len=args.seq_len,
93+
source="tinyshakespeare")
94+
vocab_size = len(chars)
95+
train_split, val_split = build_distractor_stream(
96+
encoded, args.distractor_frac, args.seq_len, args.seed,
97+
)
98+
fib_positions = fib_positions_in_window(args.seq_len)
99+
100+
# Build the three archs
101+
archs = {
102+
"dense_crt": lambda: make_model(
103+
"crt_only", vocab_size=vocab_size, seq_len=args.seq_len,
104+
d_model=args.d_model, n_blocks=args.n_blocks,
105+
),
106+
"fibgen_K32_cross": lambda: FibGenLM(
107+
vocab_size=vocab_size, d_model=args.d_model,
108+
n_blocks=args.n_blocks, seq_len=args.seq_len, K=32, mode="cross",
109+
),
110+
"composed_transformerless": lambda: FibGenTransformerless(
111+
vocab_size=vocab_size, d_model=args.d_model, n_blocks=args.n_blocks,
112+
seq_len=args.seq_len, K=32, mode="cross", n_specialists=5,
113+
),
114+
}
115+
116+
# Encode prompt (handle unknown chars by mapping to space)
117+
space_id = stoi.get(" ", 0)
118+
prompt_ids = torch.tensor(
119+
[[stoi.get(c, space_id) for c in args.prompt]], dtype=torch.long,
120+
)
121+
122+
samples = {}
123+
for name, make_fn in archs.items():
124+
model = make_fn()
125+
train(name, model, train_split, val_split, args, fib_positions)
126+
out_ids = generate_text(model, prompt_ids, args.n_new, args.seq_len,
127+
itos, temperature=args.temperature,
128+
top_k=args.top_k)
129+
text = "".join(itos[int(i)] for i in out_ids[0].tolist())
130+
samples[name] = text
131+
print(f"\n{'=' * 70}")
132+
print(f"SAMPLE from {name}:")
133+
print('=' * 70)
134+
print(text)
135+
print('=' * 70, flush=True)
136+
137+
# Write to file
138+
out_path = Path(__file__).parent / args.out
139+
with open(out_path, "w") as f:
140+
f.write(f"# Samples (steps={args.steps}, temperature={args.temperature}, "
141+
f"top_k={args.top_k})\n")
142+
f.write(f"# Prompt: {args.prompt!r}\n\n")
143+
for name, text in samples.items():
144+
f.write(f"\n{'=' * 70}\n{name}\n{'=' * 70}\n{text}\n")
145+
print(f"\nWrote {out_path}")
146+
147+
148+
if __name__ == "__main__":
149+
main()

0 commit comments

Comments
 (0)