Skip to content

Commit 7edc1a0

Browse files
feat(e2e): per-RPC RTT + payload-byte instrumentation for distributed DFlash E2E
_TimingProposer wraps the proposer to report mean/p50 RTT for restore/seed/draft/ extend + WireTensor payload bytes (DraftBlock O(1) vs ExtendContext O(block aux)). Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent c360aba commit 7edc1a0

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

scripts/research/k3_distributed_dflash_e2e_mac.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,71 @@ def _log(msg: str) -> None:
2626
print(msg, file=sys.stderr, flush=True)
2727

2828

29+
class _TimingProposer:
30+
"""Wraps a proposer, timing each RPC and counting WireTensor payload bytes,
31+
so the run reports a real per-block RTT + bandwidth breakdown."""
32+
33+
def __init__(self, inner) -> None:
34+
self.inner = inner
35+
self.t = {"restore": [], "seed_context": [], "draft_block": [], "extend_context": []}
36+
self.bytes = {"seed_context": 0, "extend_context": 0, "restore": 0}
37+
38+
@staticmethod
39+
def _wbytes(aux) -> int:
40+
import numpy as np
41+
return int(sum(np.asarray(w.data).nbytes for w in aux))
42+
43+
def restore(self, prompt_ids, **kw):
44+
import time as _t
45+
t0 = _t.perf_counter()
46+
r = self.inner.restore(prompt_ids, **kw)
47+
self.t["restore"].append((_t.perf_counter() - t0) * 1000)
48+
self.bytes["restore"] += int(sum(
49+
__import__("numpy").asarray(k.data).nbytes + __import__("numpy").asarray(v.data).nbytes
50+
for (_, k, v) in r.restored))
51+
return r
52+
53+
def seed_context(self, aux, positions):
54+
import time as _t
55+
self.bytes["seed_context"] += self._wbytes(aux)
56+
t0 = _t.perf_counter()
57+
r = self.inner.seed_context(aux, positions)
58+
self.t["seed_context"].append((_t.perf_counter() - t0) * 1000)
59+
return r
60+
61+
def draft_block(self, **kw):
62+
import time as _t
63+
t0 = _t.perf_counter()
64+
r = self.inner.draft_block(**kw)
65+
self.t["draft_block"].append((_t.perf_counter() - t0) * 1000)
66+
return r
67+
68+
def extend_context(self, aux, positions):
69+
import time as _t
70+
self.bytes["extend_context"] += self._wbytes(aux)
71+
t0 = _t.perf_counter()
72+
r = self.inner.extend_context(aux, positions)
73+
self.t["extend_context"].append((_t.perf_counter() - t0) * 1000)
74+
return r
75+
76+
def close(self):
77+
return self.inner.close()
78+
79+
def report(self) -> str:
80+
import statistics
81+
out = []
82+
for name in ("restore", "seed_context", "draft_block", "extend_context"):
83+
v = self.t[name]
84+
if not v:
85+
continue
86+
mean = statistics.mean(v)
87+
p50 = sorted(v)[len(v) // 2]
88+
b = self.bytes.get(name, 0)
89+
out.append(f"{name}: n={len(v)} mean={mean:.2f}ms p50={p50:.2f}ms"
90+
+ (f" bytes={b/1e6:.2f}MB" if b else ""))
91+
return " | ".join(out)
92+
93+
2994
def main() -> int:
3095
ap = argparse.ArgumentParser()
3196
ap.add_argument("--verifier-path", required=True)
@@ -123,13 +188,16 @@ def main() -> int:
123188
proposer, stop = InProcessDFlashProposer(engine, session_id="dist",
124189
sink=args.sink, window=args.window), (lambda: None)
125190

191+
proposer = _TimingProposer(proposer)
126192
dec = DistributedFusedDecoder(proposer, verifier, block_size=args.block_size,
127193
sink=args.sink, window=args.window)
128194
t0 = time.perf_counter()
129195
res = dec.generate(prompt_ids, args.max_new_tokens)
130196
dist_s = time.perf_counter() - t0
197+
rtt_report = proposer.report()
131198
proposer.close()
132199
stop()
200+
_log(f"[e2e] RTT/payload per RPC: {rtt_report}")
133201

134202
n = len(res.output_token_ids)
135203
_log(f"[e2e] distributed: {n} tok in {dist_s:.2f}s ({n/dist_s:.2f} tok/s) "

0 commit comments

Comments
 (0)