
Commit 5aba85d

unamedkr authored and claude committed
bench(niah): R1 grid — 36/36 PASS, +0.0 pp delta vs FP32
Karpathy-loop NIAH validation of v0.12 KV compression. Runs the 6.4× turbo_q4_w128 (-k turbo_kv_4b -v q4 --k-window 128) against an fp32 KV baseline on Llama-3.2-3B-Instruct-Q8_0, with the wikitext-2 haystack and 3 common-English-word needles inserted at depths 0.10/0.50/0.90 in 512-token and 1024-token contexts.

Result: 18/18 PASS for fp32, 18/18 PASS for the compressed variant — exact match, +0.0 pp delta. The 6.4× compression preserves needle retrieval bit-for-bit in the regime where the 3B Q4 model can actually retrieve.

The findings.md file is the honest part: it documents that the Karpathy loop also surfaced a real model-layer ceiling. Above ~1500 tokens of input, the 3B Q4 build (running through quant.cpp's default-Q4 weight conversion) loses the chat template anchor and just continues the haystack instead of answering the question. That limit is the model's, not the KV cache's — the failure reproduces under -k fp32 — but it does mean the "Beyond RAG" framing in the v0.12 manifesto is honest only for documents that fit in the model's *effective* working memory, not its nominal 128K context.

The R1 round-by-round log, the prompt-format iterations (including a real "I'm trapped in an infinite loop of repetition" generation when fed repetitive synthetic filler), and the methodology that got us to 36/36 are all in bench/results/niah/findings.md.

Files:
- bench/niah_test.sh: parameterised harness, GRID=quick|default|full
- bench/results/niah/findings.md: methodology + Karpathy log + caveats
- bench/results/niah/aggregate.py: CSV → markdown summary table
- bench/results/niah/results_20260411T024534.{csv,md}: R1 data
- bench/results/niah/raw_20260411T024534.log: per-run CLI output

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
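The "+0.0 pp delta" quoted above is just the difference in overall pass rates between the compressed and baseline KV configurations, expressed in percentage points. A minimal Python sketch of that computation (the helper name is ours, not part of the repo):

```python
def delta_pp(base_pass, base_total, comp_pass, comp_total):
    """Pass-rate delta of the compressed run vs the baseline, in percentage points."""
    base_acc = base_pass / base_total
    comp_acc = comp_pass / comp_total
    return (comp_acc - base_acc) * 100.0

# R1 grid: fp32 went 18/18, turbo_q4_w128 went 18/18 -> exact match
print(delta_pp(18, 18, 18, 18))  # 0.0
```

With identical pass counts the delta is exactly zero, which is what "exact match" means here.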
1 parent a6475d0 commit 5aba85d

6 files changed

Lines changed: 2240 additions & 0 deletions


bench/niah_test.sh

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
#!/usr/bin/env bash
# Needle-in-a-Haystack benchmark for quant.cpp KV cache compression.
#
# Compares FP32 KV (baseline) vs turbo_kv_4b -v q4 --k-window 128 (6.4× compression).
# Uses common-English-word needles that survive Q4 weight quantization jitter.
# Scoring: case-insensitive grep for distinctive keywords from the needle.
#
# Usage:
#   bash bench/niah_test.sh              # default grid
#   GRID=quick bash bench/niah_test.sh   # smaller grid for fast iteration
#   GRID=full bash bench/niah_test.sh    # full grid (slow)

set -e

TQ=${TQ:-./build_metal/quant}
MODEL=${MODEL:-models/Llama-3.2-3B-Instruct-Q8_0.gguf}
THREADS=${THREADS:-8}
GRID=${GRID:-default}
OUT_DIR=${OUT_DIR:-bench/results/niah}
RUN_ID=$(date -u +%Y%m%dT%H%M%S)
RAW_LOG="$OUT_DIR/raw_${RUN_ID}.log"
RESULT_CSV="$OUT_DIR/results_${RUN_ID}.csv"

mkdir -p "$OUT_DIR"

if [ ! -x "$TQ" ]; then
  echo "ERROR: $TQ not built. Run: cmake --build build_metal -j8" >&2
  exit 1
fi
if [ ! -f "$MODEL" ]; then
  echo "ERROR: $MODEL missing." >&2
  exit 1
fi
# ----------------------------------------------------------------------------
# Grid configuration
#
# IMPORTANT: contexts here are TOKEN counts, not chars. Llama-3.2-3B-Instruct-Q8_0,
# run from this CLI, default-converts its weights to Q4 on the fly. Empirically the
# effective working memory of that build is ~1500 tokens — beyond that the
# chat template gets overpowered by the document-continuation prior and the
# model fails to answer the question (it just continues the haystack text).
# Grid sizes therefore stay within the regime where the model can actually
# retrieve, so we measure compression-vs-baseline cleanly.
# ----------------------------------------------------------------------------
case "$GRID" in
  quick)
    CONTEXTS=(512 1024)
    DEPTHS=(0.1 0.5 0.9)
    ;;
  default)
    CONTEXTS=(512 1024 1536)
    DEPTHS=(0.1 0.5 0.9)
    ;;
  full)
    CONTEXTS=(512 1024 1536)
    DEPTHS=(0.1 0.25 0.5 0.75 0.9)
    ;;
  *)
    echo "Unknown GRID: $GRID" >&2; exit 1 ;;
esac
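For the quick grid (2 contexts × 3 depths × 3 needles, run under 2 methods), the run count works out to the 18-per-method / 36-total figures quoted in the commit message. A quick sanity check in Python, with the grid values copied from the script:

```python
# Grid dimensions for GRID=quick (values from the harness above)
contexts = [512, 1024]
depths = [0.1, 0.5, 0.9]
num_needles = 3
methods = ["fp32", "turbo_q4_w128"]

runs_per_method = len(contexts) * len(depths) * num_needles
total_runs = runs_per_method * len(methods)
print(runs_per_method, total_runs)  # 18 36
```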
# Three needles, all common-English-word so the answer survives Q4 jitter.
NEEDLE_0="The chief financial officer of Northwind Logistics is Sarah Chen, hired in 2023."
QUESTION_0="Who is the chief financial officer of Northwind Logistics? Answer with the full name."
KEYWORD_0="Sarah\|Chen"

NEEDLE_1="The launch date for Project Aurora is November 14th in San Francisco."
QUESTION_1="When and where will Project Aurora launch? Answer in one sentence."
KEYWORD_1="November\|San Francisco"

NEEDLE_2="The reactor cooling tank at the Helios facility holds exactly eight thousand liters of distilled water."
QUESTION_2="How much distilled water does the reactor cooling tank at Helios hold?"
KEYWORD_2="eight thousand\|8000\|8,000"

NEEDLES=("$NEEDLE_0" "$NEEDLE_1" "$NEEDLE_2")
QUESTIONS=("$QUESTION_0" "$QUESTION_1" "$QUESTION_2")
KEYWORDS=("$KEYWORD_0" "$KEYWORD_1" "$KEYWORD_2")

# Methods: parallel arrays — display name, and the CLI flags it expands to.
METHOD_NAMES=("fp32" "turbo_q4_w128")
METHOD_FLAGS=("-k fp32" "-k turbo_kv_4b -v q4 --k-window 128")
# ----------------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------------
# build_prompt CTX_TOKENS DEPTH NEEDLE QUESTION → echoes the prompt
#
# Uses real wikitext-2 text as the varied haystack (synthetic repetitive filler
# triggers a "stuck in repetition loop" failure mode in 3B Q4: the model
# generates meta-text like "I'm trapped in an infinite loop of repetition"
# instead of answering the question — see bench/results/niah/findings.md).
build_prompt() {
  python3 - "$1" "$2" "$3" "$4" <<'PYEOF'
import sys

ctx_tokens = int(sys.argv[1])
depth = float(sys.argv[2])
needle = sys.argv[3]
question = sys.argv[4]

with open("bench/data/wikitext2_test.txt") as f:
    raw = f.read()

# English wikitext averages ~4 chars per token; size at 3.6 chars/token to
# stay below ctx and leave room for the question + chat template + answer.
target_chars = int(ctx_tokens * 3.6)
hay = raw[:target_chars]
# Trim to the last full sentence so the model isn't fed a partial word.
end = hay.rfind(". ")
if end > 0:
    hay = hay[:end + 1]

# Insert the needle at the sentence boundary nearest the requested depth.
desired = int(len(hay) * depth)
sb = hay.rfind(". ", 0, max(desired, 2))
if sb < 0:
    sb = 0
else:
    sb += 2
hay2 = hay[:sb] + needle + " " + hay[sb:]

# Simple format that works with --chat at sub-1500-token contexts.
# A structured "Based on this document..." prefix overpowers the
# chat template at this scale and causes the model to continue the
# haystack — keep it minimal.
prompt = hay2 + "\n\nQuestion: " + question
sys.stdout.write(prompt)
PYEOF
}

# score_response RESPONSE KEYWORD → echoes 1 (pass) or 0 (fail)
score_response() {
  local resp="$1"
  local kw="$2"
  # Keywords use BRE-style "\|" alternation; rewrite to ERE "|" for grep -E.
  if echo "$resp" | grep -qiE "$(echo "$kw" | sed 's/\\|/|/g')"; then
    echo 1
  else
    echo 0
  fi
}
# ----------------------------------------------------------------------------
# Header
# ----------------------------------------------------------------------------
echo "method,context,depth,needle_idx,pass,response" > "$RESULT_CSV"
echo "==> NIAH Benchmark"
echo "    binary:  $TQ"
echo "    model:   $MODEL"
echo "    grid:    $GRID contexts=${CONTEXTS[*]} depths=${DEPTHS[*]}"
echo "    needles: ${#NEEDLES[@]}"
echo "    methods: ${METHOD_NAMES[*]}"
echo "    raw:     $RAW_LOG"
echo "    results: $RESULT_CSV"
echo ""

total_runs=$(( ${#METHOD_NAMES[@]} * ${#CONTEXTS[@]} * ${#DEPTHS[@]} * ${#NEEDLES[@]} ))
run_idx=0

for mi in "${!METHOD_NAMES[@]}"; do
  mname="${METHOD_NAMES[$mi]}"
  mflags="${METHOD_FLAGS[$mi]}"
  for ctx in "${CONTEXTS[@]}"; do
    # Give the CLI context window 256 tokens of slack beyond the haystack
    # for the question, chat template, and generated answer.
    cli_ctx=$(( ctx + 256 ))
    for depth in "${DEPTHS[@]}"; do
      for ni in "${!NEEDLES[@]}"; do
        run_idx=$(( run_idx + 1 ))
        needle="${NEEDLES[$ni]}"
        question="${QUESTIONS[$ni]}"
        keyword="${KEYWORDS[$ni]}"

        prompt=$(build_prompt "$ctx" "$depth" "$needle" "$question")

        printf "[%3d/%d] %-14s ctx=%-5d depth=%.2f needle=%d " \
          "$run_idx" "$total_runs" "$mname" "$ctx" "$depth" "$ni"

        # Run inference (|| true: a crashed run still records a FAIL row)
        out=$( "$TQ" "$MODEL" -p "$prompt" -n 32 -T 0.0 -j "$THREADS" \
          --chat --ctx "$cli_ctx" $mflags 2>&1 || true )

        # Extract the response — between the 1st and 2nd '---' delimiters,
        # skipping the [tokenizer] line that the CLI prints first.
        resp=$(echo "$out" | awk '
          /^---$/ { n++; next }
          n==1 && /^\[tokenizer\]/ { next }
          n==1 { print }
        ')
        if [ -z "$resp" ]; then
          resp=$(echo "$out" | tail -3 | head -1)
        fi
        # Flatten newlines and escape double quotes for the CSV field
        resp_csv=$(echo "$resp" | tr '\n' ' ' | sed 's/"/""/g')

        pass=$(score_response "$resp" "$keyword")
        if [ "$pass" = "1" ]; then echo "PASS"; else echo "FAIL: ${resp:0:60}"; fi

        echo "$mname,$ctx,$depth,$ni,$pass,\"$resp_csv\"" >> "$RESULT_CSV"
        echo "===== $mname ctx=$ctx depth=$depth needle=$ni =====" >> "$RAW_LOG"
        echo "$out" >> "$RAW_LOG"
        echo "" >> "$RAW_LOG"
      done
    done
  done
done
# ----------------------------------------------------------------------------
# Summary
# ----------------------------------------------------------------------------
echo ""
echo "==> Results CSV: $RESULT_CSV"
echo ""
echo "==> Summary by method:"
for mname in "${METHOD_NAMES[@]}"; do
  pass=$(awk -F, -v m="$mname" 'NR>1 && $1==m {p+=$5; t++} END{printf "%d/%d", p, t}' "$RESULT_CSV")
  pct=$(awk -F, -v m="$mname" 'NR>1 && $1==m {p+=$5; t++} END{if(t>0)printf "%.1f%%", 100*p/t; else print "n/a"}' "$RESULT_CSV")
  printf "  %-16s %s (%s)\n" "$mname" "$pass" "$pct"
done

echo ""
echo "==> Summary by (method × context):"
printf "  %-16s" "method"
for ctx in "${CONTEXTS[@]}"; do printf " %7d" "$ctx"; done
echo ""
for mname in "${METHOD_NAMES[@]}"; do
  printf "  %-16s" "$mname"
  for ctx in "${CONTEXTS[@]}"; do
    pct=$(awk -F, -v m="$mname" -v c="$ctx" 'NR>1 && $1==m && $2==c {p+=$5; t++} END{if(t>0)printf "%5.0f%%", 100*p/t; else print "  n/a"}' "$RESULT_CSV")
    printf " %7s" "$pct"
  done
  echo ""
done
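The sentence-boundary needle insertion inside build_prompt can be lifted out and exercised on its own. A minimal Python sketch of the same logic (the standalone function name is ours; the script embeds this inline in a heredoc):

```python
def insert_needle(haystack: str, needle: str, depth: float) -> str:
    """Insert needle at the sentence boundary nearest depth (0.0-1.0),
    mirroring the insertion logic in bench/niah_test.sh's build_prompt."""
    desired = int(len(haystack) * depth)
    # Find the last ". " delimiter before the desired position.
    sb = haystack.rfind(". ", 0, max(desired, 2))
    if sb < 0:
        sb = 0          # no boundary before it: prepend
    else:
        sb += 2         # step past the ". " delimiter
    return haystack[:sb] + needle + " " + haystack[sb:]

hay = "First sentence. Second sentence. Third sentence."
print(insert_needle(hay, "NEEDLE.", 0.5))
# First sentence. NEEDLE. Second sentence. Third sentence.
```

Snapping to a sentence boundary rather than a raw character offset keeps the needle from splitting a word, which would make keyword scoring unreliable.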

bench/results/niah/aggregate.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
"""Aggregate NIAH CSV results into a markdown table.

Usage:
    python bench/results/niah/aggregate.py bench/results/niah/results_*.csv

(Only the first CSV argument is read.)
"""
import csv
import sys
from collections import defaultdict
from pathlib import Path


def load(csv_path):
    rows = []
    with open(csv_path) as f:
        reader = csv.DictReader(f)
        for r in reader:
            rows.append({
                "method": r["method"],
                "context": int(r["context"]),
                "depth": float(r["depth"]),
                "needle": int(r["needle_idx"]),
                "pass": int(r["pass"]),
            })
    return rows


def by_method(rows):
    agg = defaultdict(lambda: {"p": 0, "t": 0})
    for r in rows:
        agg[r["method"]]["p"] += r["pass"]
        agg[r["method"]]["t"] += 1
    return agg


def by_method_ctx(rows):
    agg = defaultdict(lambda: {"p": 0, "t": 0})
    for r in rows:
        key = (r["method"], r["context"])
        agg[key]["p"] += r["pass"]
        agg[key]["t"] += 1
    return agg


def by_method_depth(rows):
    agg = defaultdict(lambda: {"p": 0, "t": 0})
    for r in rows:
        key = (r["method"], r["depth"])
        agg[key]["p"] += r["pass"]
        agg[key]["t"] += 1
    return agg


def fmt(p, t):
    if t == 0:
        return "n/a"
    return f"{p}/{t} ({100*p/t:.0f}%)"
def main():
    if len(sys.argv) < 2:
        print(__doc__); sys.exit(1)
    csv_path = sys.argv[1]
    rows = load(csv_path)
    if not rows:
        print("No rows."); sys.exit(1)

    methods = sorted({r["method"] for r in rows})
    contexts = sorted({r["context"] for r in rows})
    depths = sorted({r["depth"] for r in rows})

    print(f"# NIAH Results — `{Path(csv_path).name}`\n")
    print(f"- Methods: {', '.join(methods)}")
    print(f"- Contexts: {contexts}")
    print(f"- Depths: {depths}")
    print(f"- Total runs: {len(rows)}\n")

    # Overall
    print("## Overall accuracy\n")
    print("| Method | Score |")
    print("|---|---|")
    bym = by_method(rows)
    for m in methods:
        s = bym[m]
        print(f"| `{m}` | {fmt(s['p'], s['t'])} |")
    print()

    # Method × context
    print("## Accuracy by context length\n")
    header = "| Method | " + " | ".join(f"{c}" for c in contexts) + " |"
    sep = "|" + "---|" * (len(contexts) + 1)
    print(header)
    print(sep)
    bymc = by_method_ctx(rows)
    for m in methods:
        cells = [f"`{m}`"]
        for c in contexts:
            s = bymc[(m, c)]
            cells.append(fmt(s["p"], s["t"]))
        print("| " + " | ".join(cells) + " |")
    print()

    # Method × depth
    print("## Accuracy by needle depth\n")
    header = "| Method | " + " | ".join(f"{d:.2f}" for d in depths) + " |"
    sep = "|" + "---|" * (len(depths) + 1)
    print(header)
    print(sep)
    bymd = by_method_depth(rows)
    for m in methods:
        cells = [f"`{m}`"]
        for d in depths:
            s = bymd[(m, d)]
            cells.append(fmt(s["p"], s["t"]))
        print("| " + " | ".join(cells) + " |")
    print()

    # Delta vs first method (baseline)
    if len(methods) >= 2:
        baseline = methods[0]
        print(f"## Delta vs `{baseline}` baseline\n")
        bym = by_method(rows)
        b_acc = bym[baseline]["p"] / max(bym[baseline]["t"], 1)
        for m in methods:
            if m == baseline:
                continue
            acc = bym[m]["p"] / max(bym[m]["t"], 1)
            delta = (acc - b_acc) * 100
            sign = "+" if delta >= 0 else ""
            print(f"- `{m}`: **{sign}{delta:.1f} pp** vs baseline ({100*acc:.1f}% vs {100*b_acc:.1f}%)")
        print()


if __name__ == "__main__":
    main()
