|
| 1 | +"""Force the CPU KV-cache offload->restore path and check correctness. |
| 2 | +
|
| 3 | +GSM8K can't exercise the CPU cache (one shared hot prefix, sub-page tails). |
| 4 | +This driver builds N distinct, page-aligned, long prompts that overflow the |
| 5 | +GPU KV budget so their KV is offloaded to CPU, then re-requests them so they |
| 6 | +are restored from CPU. With greedy decoding the round-2 (CPU-restored) output |
| 7 | +MUST be token-identical to round-1 (freshly computed). For the MTP build it |
| 8 | +also tracks accept-rate (mtp_avg_token_per_step) which would degrade if the |
| 9 | +draft full-attn slots were not persisted/restored correctly. |
| 10 | +""" |
| 11 | +import argparse |
| 12 | +import sys |
| 13 | +import requests |
| 14 | +from concurrent.futures import ThreadPoolExecutor |
| 15 | + |
| 16 | + |
def make_prompts(n, words_per_prompt):
    """Build ``n`` distinct, deterministic prompts with long filler text.

    Args:
        n: Number of prompts to generate; prompt ``i`` is numbered ``i``.
        words_per_prompt: Number of ``item{i}-{j}`` filler words embedded in
            each prompt.

    Returns:
        A list of ``n`` prompt strings, in index order.
    """

    def build(idx):
        # The prompt index is baked into every filler word, so no two prompts
        # share their filler; the intent is that each prompt forms its own
        # prefix-cache branch and is long enough to cover several KV pages.
        words = " ".join(f"item{idx}-{w}" for w in range(words_per_prompt))
        return (
            f"You are given list number {idx}. The list is: {words}. "
            f"Question: briefly summarize what list number {idx} contains. Answer:"
        )

    return [build(idx) for idx in range(n)]
| 28 | + |
| 29 | + |
def gen(url, prompt, max_tokens):
    """Send one greedy-decoding generation request and return the generated text.

    Args:
        url: Full URL of the server's generate endpoint.
        prompt: Prompt text, sent as the ``inputs`` field of the request.
        max_tokens: Value sent as ``max_new_tokens``.

    Returns:
        The first element of the ``generated_text`` field of the JSON reply.

    Raises:
        AssertionError: If the server answers with an HTTP status other
            than 200; the message carries the status code and response body.
        requests.RequestException: If the request itself fails (connection
            error, or no reply within the 120 second timeout).
    """
    data = {
        "inputs": prompt,
        "parameters": {
            # Greedy settings: the caller compares outputs across rounds, so
            # sampling must not introduce run-to-run differences.
            "temperature": 0.0,
            "max_new_tokens": max_tokens,
            "stop_sequences": None,
            "repetition_penalty": 1.0,
            "top_p": 1.0,
            "top_k": 1,
        },
    }
    r = requests.post(url, json=data, timeout=120)
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let an error response fall
    # through to the JSON parsing below and fail with a misleading error.
    if r.status_code != 200:
        raise AssertionError(f"{r.status_code}: {r.text}")
    return r.json()["generated_text"][0]
| 45 | + |
| 46 | + |
def run_round(url, prompts, max_tokens, parallel):
    """Issue one generation request per prompt using a thread pool.

    Args:
        url: Generate endpoint passed through to ``gen``.
        prompts: Sequence of prompt strings.
        max_tokens: Per-request generation limit passed through to ``gen``.
        parallel: Maximum number of worker threads.

    Returns:
        A list of ``gen`` results in the same order as ``prompts``.
    """
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        # Keep the futures in submission order so that results line up with
        # the prompts regardless of which request finishes first.
        pending = [pool.submit(gen, url, text, max_tokens) for text in prompts]
        results = [job.result() for job in pending]
    return results
| 55 | + |
| 56 | + |
def main():
    """Run the two-round consistency check against a generation server.

    Parses command-line options, sends the same set of prompts twice, and
    compares the two sets of outputs position by position. Prints a summary;
    on any mismatch it prints up to five differing pairs and exits with
    status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=8088)
    parser.add_argument("--num-prompts", type=int, default=24)
    parser.add_argument("--words-per-prompt", type=int, default=400)
    parser.add_argument("--max-tokens", type=int, default=32)
    parser.add_argument("--parallel", type=int, default=8)
    opts = parser.parse_args()

    endpoint = f"{opts.host}:{opts.port}/generate"
    prompts = make_prompts(opts.num_prompts, opts.words_per_prompt)
    total = len(prompts)

    print(f"Round 1 (cold compute): {total} distinct prompts", flush=True)
    first = run_round(endpoint, prompts, opts.max_tokens, opts.parallel)
    print("Round 2 (CPU-restored):", flush=True)
    second = run_round(endpoint, prompts, opts.max_tokens, opts.parallel)

    # Indices of prompts whose second-round output differs from the first.
    differing = [
        idx for idx, (before, after) in enumerate(zip(first, second))
        if before != after
    ]
    print("\n=== RESULT ===")
    print(f"prompts: {total} identical: {total - len(differing)} mismatches: {len(differing)}")
    if not differing:
        print("PASS: round-2 (CPU-restored) output is token-identical to round-1.")
        return

    # Show only the first few mismatches to keep the failure report readable.
    for idx in differing[:5]:
        print(f" [#{idx}] R1={first[idx]!r}\n R2={second[idx]!r}")
    sys.exit(1)


if __name__ == "__main__":
    main()
0 commit comments