@@ -26,6 +26,71 @@ def _log(msg: str) -> None:
2626 print (msg , file = sys .stderr , flush = True )
2727
2828
class _TimingProposer:
    """Wrap a proposer, timing each RPC and counting tensor payload bytes.

    Every proxied call is forwarded unchanged to ``inner``; its wall-clock
    duration (milliseconds, via ``time.perf_counter``) is appended to
    ``self.t[<rpc name>]``.  Payload sizes are accumulated in ``self.bytes``:
    the ``aux`` tensors sent by ``seed_context`` / ``extend_context`` and the
    ``(k, v)`` tensors returned by ``restore``.  ``report()`` renders the
    collected numbers as a single line.
    """

    # RPC names in the order they appear in the report.
    _RPCS = ("restore", "seed_context", "draft_block", "extend_context")

    def __init__(self, inner) -> None:
        self.inner = inner
        # Per-RPC list of call durations in milliseconds.
        self.t = {name: [] for name in self._RPCS}
        # Per-RPC running total of payload bytes (draft_block carries none).
        self.bytes = {"seed_context": 0, "extend_context": 0, "restore": 0}

    @staticmethod
    def _wbytes(aux) -> int:
        """Return the total ``nbytes`` of the ``.data`` of every item in *aux*.

        Each item only needs a ``.data`` attribute that ``numpy.asarray``
        accepts.  An empty iterable yields 0.
        """
        import numpy as np

        return sum(int(np.asarray(w.data).nbytes) for w in aux)

    def _timed(self, name, call, *args, **kw):
        """Invoke ``call(*args, **kw)`` and record its duration under *name*.

        The sample is recorded only when the call returns; if it raises, the
        exception propagates and no timing is stored.
        """
        import time

        start = time.perf_counter()
        result = call(*args, **kw)
        self.t[name].append((time.perf_counter() - start) * 1000)
        return result

    def restore(self, prompt_ids, **kw):
        """Time ``inner.restore`` and count the bytes of the restored tensors.

        The result's ``restored`` attribute is iterated as ``(_, k, v)``
        triples; the ``.data`` of each ``k`` and ``v`` is counted.
        """
        r = self._timed("restore", self.inner.restore, prompt_ids, **kw)
        self.bytes["restore"] += self._wbytes(
            tensor for _, k, v in r.restored for tensor in (k, v)
        )
        return r

    def seed_context(self, aux, positions):
        """Count the bytes of *aux*, then time ``inner.seed_context``."""
        # Bytes are counted before the call so they are not part of the timing.
        self.bytes["seed_context"] += self._wbytes(aux)
        return self._timed("seed_context", self.inner.seed_context, aux, positions)

    def draft_block(self, **kw):
        """Time ``inner.draft_block``; no payload bytes are tracked for it."""
        return self._timed("draft_block", self.inner.draft_block, **kw)

    def extend_context(self, aux, positions):
        """Count the bytes of *aux*, then time ``inner.extend_context``."""
        self.bytes["extend_context"] += self._wbytes(aux)
        return self._timed("extend_context", self.inner.extend_context, aux, positions)

    def close(self):
        """Forward ``close`` to the wrapped proposer (not timed)."""
        return self.inner.close()

    def report(self) -> str:
        """Return a one-line ``" | "``-joined summary of the recorded RPCs.

        RPCs with no samples are omitted, so the result is ``""`` when nothing
        was recorded.  ``p50`` is the true median: for an even number of
        samples it is the mean of the two middle values rather than the upper
        one.  The ``bytes=`` suffix (in MB, i.e. bytes / 1e6) is only shown
        when a non-zero byte total exists for that RPC.
        """
        import statistics

        parts = []
        for name in self._RPCS:
            samples = self.t[name]
            if not samples:
                continue
            line = (
                f"{name}: n={len(samples)} "
                f"mean={statistics.mean(samples):.2f}ms "
                f"p50={statistics.median(samples):.2f}ms"
            )
            total = self.bytes.get(name, 0)
            if total:
                line += f" bytes={total / 1e6:.2f}MB"
            parts.append(line)
        return " | ".join(parts)
92+
93+
2994def main () -> int :
3095 ap = argparse .ArgumentParser ()
3196 ap .add_argument ("--verifier-path" , required = True )
@@ -123,13 +188,16 @@ def main() -> int:
123188 proposer , stop = InProcessDFlashProposer (engine , session_id = "dist" ,
124189 sink = args .sink , window = args .window ), (lambda : None )
125190
191+ proposer = _TimingProposer (proposer )
126192 dec = DistributedFusedDecoder (proposer , verifier , block_size = args .block_size ,
127193 sink = args .sink , window = args .window )
128194 t0 = time .perf_counter ()
129195 res = dec .generate (prompt_ids , args .max_new_tokens )
130196 dist_s = time .perf_counter () - t0
197+ rtt_report = proposer .report ()
131198 proposer .close ()
132199 stop ()
200+ _log (f"[e2e] RTT/payload per RPC: { rtt_report } " )
133201
134202 n = len (res .output_token_ids )
135203 _log (f"[e2e] distributed: { n } tok in { dist_s :.2f} s ({ n / dist_s :.2f} tok/s) "
0 commit comments