Skip to content

Commit 47e1652

Browse files
perf(distributed): add bench (throughput / bounded-KV / gRPC RTT) + mlx-distributed-spec-decode-bench preset
scripts/bench_distributed_spec_decode.py measures the three axes the distributed spec-decode path is judged on; run_distributed_bench.sh starts a local proposer and benches against it; bridge preset runs it on-device. Used to produce the GPU-host / Mac / cross-host comparison. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 372c30f commit 47e1652

4 files changed

Lines changed: 177 additions & 0 deletions

File tree

inference_engine/bridge/manifest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,28 @@ def _harness_preset(
121121
params={"max_new_tokens": ("int:max_new_tokens", "48")},
122122
validate_reports=False,
123123
),
124+
Preset(
125+
name="mlx-distributed-spec-decode-bench",
126+
description="ADR 0009 distributed spec-decode perf bench, on-device: "
127+
"token throughput (local greedy vs distributed), bounded-KV "
128+
"footprint (constant in context length), and gRPC RTT "
129+
"(localhost ProposeBlock round-trip). Starts a local "
130+
"ProposerService and benches against it.",
131+
command_templates=(
132+
(
133+
"bash", "scripts/run_distributed_bench.sh",
134+
"--label", "Mac-localhost",
135+
"--max-new-tokens", "{max_new_tokens}",
136+
"--rtt-samples", "{rtt_samples}",
137+
),
138+
),
139+
timeout_minutes=45,
140+
params={
141+
"max_new_tokens": ("int:max_new_tokens", "48"),
142+
"rtt_samples": ("int:rtt_samples", "300"),
143+
},
144+
validate_reports=False,
145+
),
124146
Preset(
125147
name="mlx-env-probe",
126148
description="Probe Metal/MLX + mlx.distributed availability.",
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Benchmark the ADR 0009 distributed spec-decode path on three axes:
2+
3+
1. token throughput — local greedy baseline vs distributed spec-decode (tok/s);
4+
2. bounded-KV footprint — the sink+window verifier's resident K/V bytes, which
5+
stay CONSTANT in context length, vs the equivalent full-attention K/V;
6+
3. gRPC RTT — per-block ProposeBlock round-trip latency to the remote proposer
7+
(localhost vs cross-host shows the network cost of remote drafts).
8+
9+
Point ``--peer`` at a running ProposerService (see
10+
scripts/demo_distributed_spec_decode.py --role proposer-node, or
11+
scripts/run_distributed_bench.sh which starts one locally).
12+
13+
CLI plumbing around tested library code; exempt from unit-test coverage by the
14+
same convention as start_grpc_runtime_server.py / demo_distributed_spec_decode.py.
15+
"""
16+
from __future__ import annotations
17+
import argparse, time, statistics, json
18+
import torch
19+
20+
from inference_engine.distributed.proposer_service import RemoteProposer
21+
from inference_engine.distributed.capability import NGRAM_MODEL_ID
22+
from inference_engine.distributed.spec_decode import DistributedSpeculativeDecoder
23+
from kv_cache_proposer.verifier import SinkWindowVerifier, VerifierConfig
24+
25+
PROMPT = ("List the numbers from 1 to 30, separated by commas, then repeat "
26+
"the same list again:")
27+
28+
def pctl(xs, p):
29+
xs = sorted(xs); k = (len(xs) - 1) * p / 100.0; f = int(k)
30+
return xs[f] if f + 1 >= len(xs) else xs[f] + (xs[f + 1] - xs[f]) * (k - f)
31+
32+
def greedy(verifier, prompt_ids, n):
33+
verifier.reset(); verifier.prefill(prompt_ids)
34+
out = [int(torch.argmax(verifier.next_token_logits).item())]
35+
while len(out) < n:
36+
verifier.append_token(out[-1])
37+
out.append(int(torch.argmax(verifier.next_token_logits).item()))
38+
return out
39+
40+
def main():
41+
ap = argparse.ArgumentParser()
42+
ap.add_argument("--peer", required=True)
43+
ap.add_argument("--label", default="run")
44+
ap.add_argument("--verifier-id", default="Qwen/Qwen3-0.6B")
45+
ap.add_argument("--rtt-samples", type=int, default=300)
46+
ap.add_argument("--max-new-tokens", type=int, default=48)
47+
ap.add_argument("--block-size", type=int, default=4)
48+
ap.add_argument("--long-tokens", type=int, default=1024)
49+
args = ap.parse_args()
50+
51+
print(f"\n================ {args.label} (peer={args.peer}) ================")
52+
53+
# ---------- 1. gRPC RTT (single ProposeBlock round-trip) ----------
54+
rp = RemoteProposer(args.peer, model_id=NGRAM_MODEL_ID)
55+
ctx = [1, 2, 3, 4, 5, 6, 7, 8] * 16 # 128-token repetitive context
56+
for _ in range(15): # warm up channel
57+
rp.propose_block(ctx, args.block_size, 1)
58+
lat = []
59+
for _ in range(args.rtt_samples):
60+
t = time.perf_counter(); rp.propose_block(ctx, args.block_size, 1)
61+
lat.append((time.perf_counter() - t) * 1000.0)
62+
rp.close()
63+
print(f"[RTT] ProposeBlock n={len(lat)} "
64+
f"mean={statistics.mean(lat):.3f}ms p50={pctl(lat,50):.3f}ms "
65+
f"p90={pctl(lat,90):.3f}ms p99={pctl(lat,99):.3f}ms "
66+
f"min={min(lat):.3f}ms max={max(lat):.3f}ms")
67+
68+
# ---------- 2. token throughput (baseline vs distributed) ----------
69+
verifier = SinkWindowVerifier(VerifierConfig(
70+
model_id=args.verifier_id, dtype=torch.bfloat16, device="cpu",
71+
sink_size=4, window_size=64))
72+
prompt = verifier.tokenizer.apply_chat_template(
73+
[{"role": "user", "content": PROMPT}],
74+
add_generation_prompt=True, tokenize=True, return_dict=False)
75+
76+
t = time.perf_counter(); greedy(verifier, prompt, args.max_new_tokens)
77+
bt = time.perf_counter() - t
78+
verifier.reset()
79+
dec = DistributedSpeculativeDecoder(
80+
RemoteProposer(args.peer, model_id=NGRAM_MODEL_ID), verifier,
81+
block_size=args.block_size, num_diffusion_steps=1)
82+
t = time.perf_counter()
83+
res = dec.generate(prompt, max_new_tokens=args.max_new_tokens)
84+
dt = time.perf_counter() - t
85+
n = len(res.output_token_ids)
86+
dec.proposer.close()
87+
print(f"[THRUPUT] baseline(local greedy)={args.max_new_tokens/bt:6.2f} tok/s "
88+
f"({bt:.2f}s) distributed={n/dt:6.2f} tok/s ({dt:.2f}s) "
89+
f"acceptance={res.acceptance_rate:.3f} ({res.total_accepted}/{res.total_proposed})")
90+
91+
# ---------- 3. bounded-KV footprint (constant vs context length) ----------
92+
bpt = verifier._bytes_per_kv_token
93+
bound = verifier.config.sink_size + verifier.config.window_size
94+
long_prompt = (prompt * (args.long_tokens // len(prompt) + 1))[: args.long_tokens]
95+
verifier.reset(); verifier.prefill(long_prompt)
96+
nxt = int(torch.argmax(verifier.next_token_logits).item())
97+
for _ in range(args.max_new_tokens): # decode further; cache must stay bounded
98+
verifier.append_token(nxt); nxt = int(torch.argmax(verifier.next_token_logits).item())
99+
total_ctx = len(long_prompt) + args.max_new_tokens
100+
live = verifier.cache_logical_size * bpt
101+
unbounded = total_ctx * bpt
102+
print(f"[BOUNDED-KV] ctx={total_ctx} tok sink+window={bound} "
103+
f"cache_logical_size={verifier.cache_logical_size} slots "
104+
f"bytes/kv-token={bpt}")
105+
print(f"[BOUNDED-KV] resident KV={live/1e6:.3f} MB (peak={verifier.stats.peak_kv_bytes/1e6:.3f} MB) "
106+
f"vs full-attention KV={unbounded/1e6:.3f} MB "
107+
f"=> {unbounded/live:.1f}x smaller, CONSTANT in context length")
108+
109+
if __name__ == "__main__":
110+
main()

scripts/run_distributed_bench.sh

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env bash
2+
# Start a local ProposerService and benchmark the distributed spec-decode path
3+
# (token throughput, bounded-KV footprint, gRPC RTT) against it. On-device perf
4+
# validation for the Mac bridge; runnable anywhere with the deps.
5+
set -euo pipefail
6+
repo_root="$(cd "$(dirname "$0")/.." && pwd)"
7+
cd "$repo_root"
8+
9+
VERIFIER_ID="Qwen/Qwen3-0.6B"
10+
MAXNEW="48"; RTT="300"; LABEL="local"
11+
while [[ $# -gt 0 ]]; do
12+
case "$1" in
13+
--verifier-id) shift; VERIFIER_ID="${1:?}" ;;
14+
--max-new-tokens) shift; MAXNEW="${1:?}" ;;
15+
--rtt-samples) shift; RTT="${1:?}" ;;
16+
--label) shift; LABEL="${1:?}" ;;
17+
*) echo "[dist-bench] ignoring arg: $1" >&2 ;;
18+
esac
19+
shift
20+
done
21+
22+
_can() { [ -n "${1:-}" ] && "$1" -c 'import grpc, torch, transformers' >/dev/null 2>&1; }
23+
PYBIN=""
24+
for c in "${KAKEYA_MAC_PYTHON:-}" "$repo_root/.venv-mac/bin/python3.13" \
25+
"$repo_root/.venv-mac/bin/python" "$(command -v python3 2>/dev/null || true)"; do
26+
if _can "$c"; then PYBIN="$c"; break; fi
27+
done
28+
[[ -z "$PYBIN" ]] && { echo "[dist-bench] no Python with grpc+torch+transformers" >&2; exit 2; }
29+
echo "[dist-bench] python=$PYBIN label=$LABEL" >&2
30+
31+
export PYTHONPATH="$repo_root:$repo_root/sdks/python"
32+
export HF_HUB_DISABLE_PROGRESS_BARS=1
33+
34+
for p in $(pgrep -f demo_distributed_spec_decode 2>/dev/null || true); do kill "$p" 2>/dev/null || true; done
35+
sleep 1
36+
"$PYBIN" scripts/demo_distributed_spec_decode.py \
37+
--role proposer-node --bind 127.0.0.1:50061 --node-id bench-proposer \
38+
> /tmp/kakeya_bench_proposer.log 2>&1 &
39+
PP=$!
40+
trap 'kill "$PP" 2>/dev/null || true' EXIT
41+
sleep 6
42+
"$PYBIN" scripts/bench_distributed_spec_decode.py \
43+
--peer 127.0.0.1:50061 --label "$LABEL" \
44+
--verifier-id "$VERIFIER_ID" --max-new-tokens "$MAXNEW" --rtt-samples "$RTT"

tests/inference_engine/bridge/test_manifest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_allowlist_contains_exactly_the_documented_presets():
7979
"mlx-batched-manual-sdpa",
8080
"mlx-batched-multitenant",
8181
"mlx-batched-pad-decode",
82+
"mlx-distributed-spec-decode-bench",
8283
"mlx-distributed-spec-decode-demo",
8384
"mlx-env-probe",
8485
"mlx-kakeya-chat-smoke",

0 commit comments

Comments
 (0)