Skip to content

Commit 927f75c

Browse files
L1.6: Array<->JIT bridging at the dispatch boundary
Before this commit, the JIT dispatch in omnimcode-cli/src/main.rs returned None for any Value::Array argument, falling through to tree-walk even for fns the codegen had successfully lowered. The harmonic libraries' hot paths all take arrays as input, so the JIT was inactive on the most performance-critical code despite being "registered." The fix: marshal Value::Array (int-only) into a length-prefixed Box<[i64]> with layout `[len, v0, v1, ..., vN]` — matching the stack-frame array layout the dual-band lowerer's NewArray ops already use. The JIT'd function's ArrayLen / ArrayIndex code reads from the marshalled buffer with the same access pattern, so no codegen changes were needed. The Box<[i64]>s are pinned for the duration of the call via `_pinned: Vec<Box<[i64]>>`. They drop after .call() returns. The JIT'd code is guaranteed not to retain the pointer beyond the call (verified by the lowerer's stack-local array discipline). Empirical: sum_array(arr_range(0, 1000)) over 1000 iterations: tree-walk: 803 ms JIT (this): 7 ms speedup: 115× Tests: 6 new in omnimcode-codegen/tests/jit_array_bridge.rs - sum (simple read pattern) - max (branchy compare-and-update) - mixed args (array + scalar interleaved) - empty array (length 0 doesn't crash) - large array (1000 elements, exceeds typical alloca size) - non-int array rejection (falls through to tree-walk) All pass. No regressions in the existing 41 codegen tests (now 48 total). Read-only contract: the bridge doesn't write back to the original HArray even if the JIT'd fn mutated the buffer. The common case (sum, score, count) is read-only; mutating-array fns return i64 today so output-side bridging is a future extension. Also: experiments/transformerless_lm/train_distractor_mix.py preparing the adversarial-mix scaling test for the CRT-PE + HBit stack (per the README's prescription that hybrid wins on distribution-shift data, not clean training). Runs 3 seeds at 1500 steps with 20% char-shuffled distractor injection. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 6cf6582 commit 927f75c

3 files changed

Lines changed: 517 additions & 3 deletions

File tree

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
"""Adversarial-mix scaling test for the CRT-PE + HBit-gate stack.
2+
3+
The README's transformerless-LM section explicitly predicts that the
4+
`hybrid` arch (CRT-PE + HBit-tension gate) loses to `crt_only` on
5+
clean training data because the gate has nothing useful to gate
6+
against. The architectural prescription:
7+
8+
"OR train with mixed-clean-and-distractor batches so the gate
9+
has something to gate against."
10+
11+
This file builds the distractor-mix corpus and re-runs the three
12+
architectures on it. If the README's prediction is correct, `hybrid`
13+
should now beat `crt_only` on validation loss against the on-distribution
14+
held-out set (because the gate learns to attend to real-text patterns
15+
and skip the distractor patterns during training).
16+
17+
CONSTRUCTION:
18+
- Take TinyShakespeare as the on-distribution corpus
19+
- Build distractors by char-shuffling random windows of the same
20+
corpus (same char distribution, no structural patterns)
21+
- Mix into the training stream at distractor_frac (default 20%)
22+
- Validate on PURE shakespeare (the actual task) so we measure
23+
"does the model learn shakespeare *despite* the noise?"
24+
25+
Hypothesis: `hybrid` wins this regime because the gate's down-
26+
weighting of off-manifold keys helps the model ignore the noise
27+
chunks. If `hybrid` ties or loses, the README's architectural
28+
hypothesis is falsified at this scale.
29+
"""
30+
31+
import argparse
32+
import sys
33+
import time
34+
import statistics
35+
from pathlib import Path
36+
37+
import torch
38+
import torch.nn.functional as F
39+
40+
sys.path.insert(0, str(Path(__file__).parent))
41+
from corpus import make_dataset
42+
from models import make_model
43+
44+
45+
def build_distractor_stream(
46+
encoded: torch.Tensor,
47+
distractor_frac: float,
48+
seq_len: int,
49+
seed: int,
50+
) -> tuple[torch.Tensor, torch.Tensor]:
51+
"""Build a training stream where `distractor_frac` of seq_len-sized
52+
chunks are char-shuffled versions of randomly-drawn windows from
53+
the same corpus. Same char distribution as the original (so the
54+
softmax baseline can't exploit a vocabulary shift); structural
55+
patterns destroyed.
56+
57+
Returns (train_stream, on_dist_val) where:
58+
train_stream is a 1-D tensor with mixed clean + distractor chunks
59+
on_dist_val is the unchanged tail of the input for held-out eval
60+
"""
61+
g = torch.Generator()
62+
g.manual_seed(seed)
63+
n = encoded.numel()
64+
n_train_total = int(n * 0.9)
65+
n_val = n - n_train_total
66+
val_split = encoded[n_train_total:] # PURE shakespeare; not touched
67+
68+
# Build the mixed training stream chunk by chunk.
69+
n_chunks = n_train_total // seq_len
70+
chunks = []
71+
for i in range(n_chunks):
72+
if torch.rand(1, generator=g).item() < distractor_frac:
73+
# Distractor: take a random window, shuffle its chars in-place.
74+
start = torch.randint(0, n_train_total - seq_len, (1,), generator=g).item()
75+
window = encoded[start:start + seq_len].clone()
76+
perm = torch.randperm(seq_len, generator=g)
77+
chunks.append(window[perm])
78+
else:
79+
# Clean: contiguous shakespeare slice.
80+
start = torch.randint(0, n_train_total - seq_len, (1,), generator=g).item()
81+
chunks.append(encoded[start:start + seq_len].clone())
82+
train_stream = torch.cat(chunks)
83+
print(f"Mixed-stream: {len(chunks)} chunks ({seq_len} chars each), "
84+
f"distractor_frac={distractor_frac:.2f}; val on {n_val:,} clean chars")
85+
return train_stream, val_split
86+
87+
88+
def get_batch_split(encoded_split, batch_size: int, seq_len: int, generator):
89+
n = encoded_split.numel()
90+
ix = torch.randint(0, n - seq_len - 1, (batch_size,), generator=generator)
91+
x = torch.stack([encoded_split[i:i + seq_len] for i in ix])
92+
y = torch.stack([encoded_split[i + 1:i + seq_len + 1] for i in ix])
93+
return x, y
94+
95+
96+
def evaluate(model, val_split, batch_size, seq_len, n_batches, generator):
97+
model.eval()
98+
losses = []
99+
with torch.no_grad():
100+
for _ in range(n_batches):
101+
x, y = get_batch_split(val_split, batch_size, seq_len, generator)
102+
logits = model(x)
103+
loss = F.cross_entropy(
104+
logits.reshape(-1, logits.size(-1)),
105+
y.reshape(-1),
106+
)
107+
losses.append(loss.item())
108+
model.train()
109+
return sum(losses) / len(losses)
110+
111+
112+
def train_one(arch, train_split, val_split, vocab_size, args, seed):
113+
torch.manual_seed(seed)
114+
gen = torch.Generator()
115+
gen.manual_seed(seed + 1)
116+
117+
model = make_model(
118+
arch,
119+
vocab_size=vocab_size,
120+
seq_len=args.seq_len,
121+
d_model=args.d_model,
122+
n_blocks=args.n_blocks,
123+
)
124+
n_params = sum(p.numel() for p in model.parameters())
125+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
126+
127+
print(f"\n[arch={arch}] params={n_params:,}", flush=True)
128+
t0 = time.time()
129+
val_history = []
130+
for step in range(args.steps):
131+
x, y = get_batch_split(train_split, args.batch_size, args.seq_len, gen)
132+
logits = model(x)
133+
loss = F.cross_entropy(
134+
logits.reshape(-1, logits.size(-1)),
135+
y.reshape(-1),
136+
)
137+
optimizer.zero_grad()
138+
loss.backward()
139+
optimizer.step()
140+
if step % args.eval_every == 0 or step == args.steps - 1:
141+
tl = loss.item()
142+
vl = evaluate(model, val_split, args.batch_size, args.seq_len, n_batches=16, generator=gen)
143+
val_history.append((step, vl))
144+
elapsed = time.time() - t0
145+
print(f" step {step:5d} train={tl:.4f} val={vl:.4f} ({elapsed:.1f}s)", flush=True)
146+
147+
last_few = val_history[-3:]
148+
final_val = sum(v for _, v in last_few) / len(last_few)
149+
return dict(
150+
arch=arch,
151+
n_params=n_params,
152+
val_history=val_history,
153+
final_val=final_val,
154+
time=time.time() - t0,
155+
)
156+
157+
158+
def main():
159+
parser = argparse.ArgumentParser()
160+
parser.add_argument("--steps", type=int, default=1500)
161+
parser.add_argument("--batch-size", type=int, default=32)
162+
parser.add_argument("--seq-len", type=int, default=128)
163+
parser.add_argument("--d-model", type=int, default=128)
164+
parser.add_argument("--n-blocks", type=int, default=4)
165+
parser.add_argument("--lr", type=float, default=3e-4)
166+
parser.add_argument("--eval-every", type=int, default=100)
167+
parser.add_argument("--seeds", type=str, default="42,7,123")
168+
parser.add_argument("--distractor-frac", type=float, default=0.20,
169+
help="Fraction of training chunks that are char-shuffled.")
170+
args = parser.parse_args()
171+
172+
seeds = [int(s) for s in args.seeds.split(",")]
173+
174+
chars, stoi, itos, encoded = make_dataset(seq_len=args.seq_len, source="tinyshakespeare")
175+
vocab_size = len(chars)
176+
177+
print(f"Corpus: TinyShakespeare ({encoded.numel():,} chars, vocab {vocab_size})")
178+
print(f"Adversarial-mix test: distractor_frac={args.distractor_frac:.2f}")
179+
print(f"Model: d_model={args.d_model}, n_blocks={args.n_blocks}, seq_len={args.seq_len}")
180+
print(f"Training: steps={args.steps}, batch={args.batch_size}, lr={args.lr}, seeds={seeds}", flush=True)
181+
182+
all_results = {arch: [] for arch in ["standard", "crt_only", "hybrid"]}
183+
for seed in seeds:
184+
print(f"\n=========== seed {seed} ===========")
185+
# Build the mixed stream FRESH per seed so seeds are honest.
186+
train_split, val_split = build_distractor_stream(
187+
encoded, args.distractor_frac, args.seq_len, seed,
188+
)
189+
for arch in ["standard", "crt_only", "hybrid"]:
190+
r = train_one(arch, train_split, val_split, vocab_size, args, seed)
191+
all_results[arch].append(r["final_val"])
192+
print(f" [seed {seed}] {arch}: final_val={r['final_val']:.4f}", flush=True)
193+
194+
print()
195+
print("=" * 70)
196+
print(f"{'arch':<12} {'mean_final_val':>16} {'std':>10} {'win_rate':>12}")
197+
print("-" * 70)
198+
base = all_results["standard"]
199+
for arch in ["standard", "crt_only", "hybrid"]:
200+
vals = all_results[arch]
201+
mean = sum(vals) / len(vals)
202+
std = statistics.stdev(vals) if len(vals) > 1 else 0.0
203+
if arch == "standard":
204+
wr = "—"
205+
else:
206+
wins = sum(1 for v, b in zip(vals, base) if v < b)
207+
wr = f"{wins}/{len(vals)}"
208+
print(f"{arch:<12} {mean:>16.4f} {std:>10.4f} {wr:>12}")
209+
print()
210+
base_mean = sum(base) / len(base)
211+
for arch in ["crt_only", "hybrid"]:
212+
vals = all_results[arch]
213+
mean = sum(vals) / len(vals)
214+
rel = (mean - base_mean) / base_mean * 100
215+
verdict = "BETTER" if mean < base_mean else "WORSE"
216+
print(f" {arch:<12} vs standard: {mean - base_mean:+.4f} ({rel:+.1f}%) — {verdict}")
217+
# Also compare hybrid vs crt_only directly — this is the key question.
218+
hyb_mean = sum(all_results["hybrid"]) / len(all_results["hybrid"])
219+
crt_mean = sum(all_results["crt_only"]) / len(all_results["crt_only"])
220+
rel = (hyb_mean - crt_mean) / crt_mean * 100
221+
crt_better = hyb_mean < crt_mean
222+
print(f" hybrid vs crt_only: {hyb_mean - crt_mean:+.4f} ({rel:+.1f}%) — "
223+
f"{'GATE EARNS KEEP' if crt_better else 'GATE STILL COSTS'}")
224+
225+
226+
if __name__ == "__main__":
227+
main()

omnimcode-cli/src/main.rs

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,16 +186,63 @@ fn maybe_register_jit(
186186
if args.len() != jf.arity {
187187
return None;
188188
}
189+
// L1.6: Array↔JIT bridging. Convert each Value::Array (whose
190+
// elements are all HInt) to a length-prefixed Box<[i64]> with
191+
// layout `[len, v0, v1, ..., vN]`. The JIT'd fn was lowered
192+
// assuming NewArray-style alloca layout (slot 0 = length,
193+
// slots 1..=N = elements) so the same access pattern works
194+
// for both internal and external arrays. We hand the raw
195+
// pointer to the JIT as the i64 arg.
196+
//
197+
// The Boxes are held in `_pinned` for the duration of the
198+
// call so the JIT'd code can dereference them safely. Drop
199+
// happens after .call() returns; the JIT'd fn must NOT
200+
// retain the pointer beyond the call (it doesn't — arrays
201+
// are stack-local in the lowered IR).
202+
//
203+
// Read-only contract: we don't write back to the original
204+
// HArray even if the JIT'd fn mutated the buffer. The
205+
// common case (sum, score, count) is read-only; mutating
206+
// array fns currently fall through to tree-walk on the
207+
// OUTPUT side (their return is i64, not the array).
189208
let mut int_args: Vec<i64> = Vec::with_capacity(args.len());
209+
let mut _pinned: Vec<Box<[i64]>> = Vec::new();
190210
for a in args {
191211
match a {
192212
Value::HInt(h) => int_args.push(h.value),
193213
Value::Bool(b) => int_args.push(if *b { 1 } else { 0 }),
194-
_ => return None, // non-int arg → fall through to tree-walk
214+
Value::Array(arr) => {
215+
let items = arr.items.borrow();
216+
// Only support int-typed arrays at the boundary.
217+
// Any non-int element → fall through to tree-walk.
218+
if !items.iter().all(|v| matches!(v, Value::HInt(_) | Value::Bool(_))) {
219+
return None;
220+
}
221+
// Layout: slot 0 = length, slots 1..=N = elements.
222+
let mut buf: Vec<i64> = Vec::with_capacity(items.len() + 1);
223+
buf.push(items.len() as i64);
224+
for v in items.iter() {
225+
buf.push(match v {
226+
Value::HInt(h) => h.value,
227+
Value::Bool(b) => if *b { 1 } else { 0 },
228+
_ => unreachable!(),
229+
});
230+
}
231+
let boxed = buf.into_boxed_slice();
232+
let ptr = boxed.as_ptr() as i64;
233+
_pinned.push(boxed);
234+
int_args.push(ptr);
235+
}
236+
_ => return None, // other non-int args → fall through to tree-walk
195237
}
196238
}
197-
jf.call(&int_args)
198-
.map(|r| Ok(Value::HInt(HInt::new(r))))
239+
let result = jf.call(&int_args)
240+
.map(|r| Ok(Value::HInt(HInt::new(r))));
241+
// _pinned drops here, freeing the marshalled buffers.
242+
// Safe because the JIT'd code didn't retain the pointers
243+
// (verified by the lowerer's stack-local array discipline).
244+
drop(_pinned);
245+
result
199246
},
200247
);
201248
interp.set_jit_dispatch(Some(dispatch));

0 commit comments

Comments
 (0)