Skip to content

Commit a0cbbf4

Browse files
committed
transformerless_lm: sample_text — use BEST-val checkpoint, not final
Per the user's observation: substrate-aligned models snap to discrete Fibonacci-tier attractor configurations during training, so the loss curve jumps between attractor states rather than monotonically descending. Standard val-at-final-step measurement understates the quality of the best attractor the model visited. sample_text.py now tracks the best-val checkpoint across all evaluation points (every 200 steps), saves the state_dict at that point, and loads it before generating. The text sample comes from the BEST attractor configuration, not whatever attractor the model happens to be sitting at when training stops. This is the right way to measure substrate-model quality: the substrate's discrete state space means optimization explores multiple stable configurations, and the deployment-relevant one is the lowest-val of those, not the temporally-last one.
1 parent a7cda55 commit a0cbbf4

1 file changed

Lines changed: 48 additions & 7 deletions

File tree

experiments/transformerless_lm/sample_text.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,58 @@
3030
from lazy_data import fib_positions_in_window, get_fib_strided_batch
3131

3232

33+
def evaluate(model, val_split, batch_size, window, fib_positions, generator,
34+
n_batches=16):
35+
model.eval()
36+
losses = []
37+
with torch.no_grad():
38+
for _ in range(n_batches):
39+
x, y = get_fib_strided_batch(val_split, batch_size, window,
40+
fib_positions, generator)
41+
logits = model(x)
42+
losses.append(F.cross_entropy(
43+
logits.reshape(-1, logits.size(-1)), y.reshape(-1)).item())
44+
model.train()
45+
return sum(losses) / len(losses)
46+
47+
3348
def train(name, model, train_split, val_split, args, fib_positions):
49+
"""Train and return BEST-VAL checkpoint. Substrate models jump between
50+
Fibonacci-attractor configurations during training, so the best val
51+
is rarely at the final step — sample from the best attractor."""
3452
torch.manual_seed(args.seed)
3553
gen = torch.Generator(); gen.manual_seed(args.seed + 1)
3654
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
3755
t0 = time.time()
38-
eval_every = 300
56+
eval_every = 200
3957
print(f"\n[train {name}] params={sum(p.numel() for p in model.parameters()):,}",
4058
flush=True)
59+
best_val = float("inf")
60+
best_state = None
61+
best_step = -1
4162
for step in range(args.steps):
4263
x, y = get_fib_strided_batch(train_split, args.batch_size, args.seq_len,
4364
fib_positions, gen)
4465
logits = model(x)
4566
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
4667
optimizer.zero_grad(); loss.backward(); optimizer.step()
4768
if step % eval_every == 0 or step == args.steps - 1:
48-
print(f" step {step:5d} train={loss.item():.4f} "
49-
f"({time.time()-t0:.1f}s)", flush=True)
50-
return model
69+
vl = evaluate(model, val_split, args.batch_size, args.seq_len,
70+
fib_positions, gen)
71+
marker = ""
72+
if vl < best_val:
73+
best_val = vl
74+
best_state = {k: v.clone() for k, v in model.state_dict().items()}
75+
best_step = step
76+
marker = " ← BEST"
77+
print(f" step {step:5d} train={loss.item():.4f} val={vl:.4f}"
78+
f" ({time.time()-t0:.1f}s){marker}", flush=True)
79+
# Load best
80+
if best_state is not None:
81+
model.load_state_dict(best_state)
82+
print(f" → using best checkpoint from step {best_step} val={best_val:.4f}",
83+
flush=True)
84+
return model, best_val, best_step
5185

5286

5387
@torch.no_grad()
@@ -120,16 +154,20 @@ def main():
120154
)
121155

122156
samples = {}
157+
meta = {}
123158
for name, make_fn in archs.items():
124159
model = make_fn()
125-
train(name, model, train_split, val_split, args, fib_positions)
160+
model, best_val, best_step = train(name, model, train_split, val_split,
161+
args, fib_positions)
162+
meta[name] = {"best_val": best_val, "best_step": best_step,
163+
"n_params": sum(p.numel() for p in model.parameters())}
126164
out_ids = generate_text(model, prompt_ids, args.n_new, args.seq_len,
127165
itos, temperature=args.temperature,
128166
top_k=args.top_k)
129167
text = "".join(itos[int(i)] for i in out_ids[0].tolist())
130168
samples[name] = text
131169
print(f"\n{'=' * 70}")
132-
print(f"SAMPLE from {name}:")
170+
print(f"SAMPLE from {name} (best_val={best_val:.4f} @ step {best_step})")
133171
print('=' * 70)
134172
print(text)
135173
print('=' * 70, flush=True)
@@ -141,7 +179,10 @@ def main():
141179
f"top_k={args.top_k})\n")
142180
f.write(f"# Prompt: {args.prompt!r}\n\n")
143181
for name, text in samples.items():
144-
f.write(f"\n{'=' * 70}\n{name}\n{'=' * 70}\n{text}\n")
182+
m = meta[name]
183+
f.write(f"\n{'=' * 70}\n{name} best_val={m['best_val']:.4f} @ step "
184+
f"{m['best_step']} params={m['n_params']:,}\n"
185+
f"{'=' * 70}\n{text}\n")
145186
print(f"\nWrote {out_path}")
146187

147188

0 commit comments

Comments
 (0)