|
| 1 | +"""Streaming chat REPL over the Kakeya gRPC runtime (v0.3). |
| 2 | +
|
| 3 | +A multi-turn chat client that uses the Python SDK to talk to a |
| 4 | +running ``RuntimeService``. Demonstrates the session-bound |
| 5 | +architecture's killer feature: the server keeps the running KV |
| 6 | +cache, so every turn after the first appends only the new user |
| 7 | +message — independent of conversation length. |
| 8 | +
|
| 9 | +Compare to ``scripts/chat.py`` (v0.2): that REPL re-prefilled the |
| 10 | +full conversation on every turn against an in-process |
| 11 | +``SpeculativeEngine``. This REPL holds one ``Session`` open across |
| 12 | +turns and the server keeps O(history) cache; per-turn prefill is |
| 13 | +O(new_user_message). |
| 14 | +
|
| 15 | +Usage:: |
| 16 | +
|
| 17 | + # 1. In one terminal, start the runtime |
| 18 | + PYTHONPATH=.:sdks/python python3 scripts/start_grpc_runtime_server.py \\ |
| 19 | + --backend cpu --verifier-id Qwen/Qwen3-0.6B \\ |
| 20 | + --bind 127.0.0.1:50051 |
| 21 | +
|
| 22 | + # 2. In another terminal, chat |
| 23 | + PYTHONPATH=.:sdks/python python3 scripts/chat_grpc.py |
| 24 | + # Or, with options: |
| 25 | + PYTHONPATH=.:sdks/python python3 scripts/chat_grpc.py \\ |
| 26 | + --address 127.0.0.1:50051 \\ |
| 27 | + --tokenizer-id Qwen/Qwen3-0.6B \\ |
| 28 | + --max-tokens 64 |
| 29 | +
|
| 30 | +REPL controls |
| 31 | +------------- |
| 32 | +
|
| 33 | + Type your message + Enter to send. |
| 34 | + Ctrl-D or empty line to exit. |
| 35 | + ``/reset`` on its own line: close current session, open new one |
| 36 | + (clear context). |
| 37 | + ``/info`` on its own line: print server-side session state |
| 38 | + (history length, KV bytes, idle time). |
| 39 | + ``/help``: this list. |
| 40 | +
|
| 41 | +Per the project's CLI-plumbing convention this script is exempt |
| 42 | +from the unit-test coverage gate. End-to-end behavior is exercised |
| 43 | +by the SDK integration tests at |
| 44 | +``tests/integration/test_sdk_real.py`` which drive the same SDK |
| 45 | +methods this REPL drives. |
| 46 | +""" |
| 47 | + |
| 48 | +from __future__ import annotations |
| 49 | + |
| 50 | +import argparse |
| 51 | +import sys |
| 52 | +from typing import List, Optional |
| 53 | + |
| 54 | + |
# Text printed to stderr when the user types ``/help`` in the REPL loop.
# ``.strip()`` drops the leading/trailing newlines that the triple-quoted
# layout would otherwise add.
_HELP = """
Commands:
 /help show this help
 /reset close current session, start a fresh one
 /info show server-side session state
 /exit quit (or Ctrl-D / empty line)
""".strip()
| 62 | + |
| 63 | + |
| 64 | +def _print_banner(address: str, tokenizer_id: str) -> None: |
| 65 | + print( |
| 66 | + f"Kakeya v0.3 chat — {address} ({tokenizer_id})\n" |
| 67 | + f"Session-bound runtime: server keeps history, you only send " |
| 68 | + f"new tokens per turn.\n" |
| 69 | + f"Type /help for commands; Ctrl-D or empty line to quit.\n", |
| 70 | + file=sys.stderr, flush=True, |
| 71 | + ) |
| 72 | + |
| 73 | + |
| 74 | +def _read_user_input(prompt: str = "you> ") -> Optional[str]: |
| 75 | + """Read a single user line from stdin. |
| 76 | +
|
| 77 | + Returns ``None`` on EOF (Ctrl-D) or empty input. Empty input is |
| 78 | + a terminate signal — the user can use ``/reset`` to clear context |
| 79 | + without exiting the REPL. |
| 80 | + """ |
| 81 | + try: |
| 82 | + line = input(prompt) |
| 83 | + except EOFError: |
| 84 | + return None |
| 85 | + if not line.strip(): |
| 86 | + return None |
| 87 | + return line |
| 88 | + |
| 89 | + |
| 90 | +def _generate_and_print( |
| 91 | + session, |
| 92 | + tokenizer, |
| 93 | + new_tokens: List[int], |
| 94 | + max_tokens: int, |
| 95 | +) -> int: |
| 96 | + """Drive one append + generate cycle. Streams tokens to stdout |
| 97 | + as they arrive, returns the count emitted. The generator's |
| 98 | + metadata (stop reason, durations) is read after iteration via |
| 99 | + ``session.last_*`` properties. |
| 100 | + """ |
| 101 | + session.append(new_tokens) |
| 102 | + |
| 103 | + print("kakeya> ", end="", flush=True) |
| 104 | + n = 0 |
| 105 | + accumulated = [] |
| 106 | + try: |
| 107 | + for token_id in session.generate(max_tokens=max_tokens): |
| 108 | + n += 1 |
| 109 | + accumulated.append(token_id) |
| 110 | + # Decode incrementally — tokenizer.decode on the running |
| 111 | + # buffer gives the right text including BPE merges that |
| 112 | + # span multiple tokens. We re-decode the full buffer |
| 113 | + # each time (Qwen3-family tokenizers re-decode in <1ms |
| 114 | + # for a 64-token buffer; per-token decoding loses some |
| 115 | + # whitespace correctness on the tokenizer level). |
| 116 | + text_so_far = tokenizer.decode( |
| 117 | + accumulated, skip_special_tokens=True, |
| 118 | + ) |
| 119 | + # Print only the suffix that's new since last frame. |
| 120 | + if hasattr(_generate_and_print, "_last_text"): |
| 121 | + last = _generate_and_print._last_text |
| 122 | + else: |
| 123 | + last = "" |
| 124 | + new_text = text_so_far[len(last):] |
| 125 | + print(new_text, end="", flush=True) |
| 126 | + _generate_and_print._last_text = text_so_far |
| 127 | + except KeyboardInterrupt: |
| 128 | + print("\n[interrupted]", file=sys.stderr) |
| 129 | + finally: |
| 130 | + # Reset the per-call decoder state so the next turn starts |
| 131 | + # fresh. |
| 132 | + if hasattr(_generate_and_print, "_last_text"): |
| 133 | + del _generate_and_print._last_text |
| 134 | + |
| 135 | + print() # final newline |
| 136 | + return n |
| 137 | + |
| 138 | + |
| 139 | +def _print_session_info(session) -> None: |
| 140 | + info = session.info() |
| 141 | + print( |
| 142 | + f" history_length = {info.history_length}\n" |
| 143 | + f" kv_live_bytes = {info.kv_live_bytes:,}\n" |
| 144 | + f" idle_seconds = {info.idle_seconds:.3f}\n" |
| 145 | + f" inv1_violations= {info.cache_invariant_inv1_violations}\n" |
| 146 | + f" inv2_violations= {info.cache_invariant_inv2_violations}", |
| 147 | + file=sys.stderr, flush=True, |
| 148 | + ) |
| 149 | + |
| 150 | + |
def main() -> int:
    """Run the chat REPL against a running gRPC ``RuntimeService``.

    Parses command-line options, loads the tokenizer, opens one server
    session (seeded with the system prompt unless it is empty) and then
    loops: read a line, handle slash commands, otherwise tokenize the
    line as a user turn and stream the reply.

    Returns:
        ``0`` on a normal quit (``/exit``, ``/quit``, empty line, Ctrl-D
        or Ctrl-C at the prompt); ``1`` when a session could not be
        created or re-created.

    The current session is closed on the way out on a best-effort basis.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains a usage example and a command
        # list; the default formatter would reflow them into one
        # paragraph in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument(
        "--address", default="127.0.0.1:50051",
        help="host:port of a running kakeya gRPC RuntimeService",
    )
    ap.add_argument(
        "--tokenizer-id", default="Qwen/Qwen3-0.6B",
        help="HF model id for the tokenizer. MUST match the verifier "
             "the server is running.",
    )
    ap.add_argument(
        "--max-tokens", type=int, default=64,
        help="max_tokens per turn",
    )
    ap.add_argument(
        "--system-prompt", default="You are a helpful assistant.",
        help="System prompt prepended on the first turn (Qwen3 chat "
             "template). Pass empty string to skip.",
    )
    args = ap.parse_args()

    # Lazy imports keep --help fast.
    from kakeya import Client
    from kakeya.errors import KakeyaError
    from transformers import AutoTokenizer

    print(f"[chat] loading tokenizer {args.tokenizer_id} ...",
          file=sys.stderr, flush=True)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
    eos = tokenizer.eos_token_id
    eos_ids: List[int] = [int(eos)] if eos is not None else []

    _print_banner(args.address, args.tokenizer_id)

    def _encode(role: str, content: str, *, add_generation_prompt: bool):
        """Tokenize a single chat message through the chat template."""
        return tokenizer.apply_chat_template(
            [{"role": role, "content": content}],
            add_generation_prompt=add_generation_prompt,
            tokenize=True,
            return_dict=False,
            enable_thinking=False,
        )

    def _close_quietly(s) -> None:
        """Close ``s`` if set, ignoring SDK errors (best-effort cleanup)."""
        if s is None:
            return
        try:
            s.close()
        except KakeyaError:
            pass

    def _make_session(client):
        """Open a session and seed it with the system prompt, if any."""
        s = client.create_session(eos_token_ids=eos_ids)
        # Seed with the system prompt on turn 0 (no generation yet).
        if args.system_prompt:
            seed_ids = _encode(
                "system", args.system_prompt, add_generation_prompt=False,
            )
            if seed_ids:
                try:
                    s.append(seed_ids)
                except KakeyaError:
                    # Don't leave a half-initialised session open on
                    # the server.
                    _close_quietly(s)
                    raise
        return s

    with Client(args.address) as client:
        # ``session`` is None whenever no live session is held, so the
        # ``finally`` below never re-closes one that was already closed.
        session = None
        try:
            session = _make_session(client)
            while True:
                try:
                    user_line = _read_user_input()
                except KeyboardInterrupt:
                    # Ctrl-C at the prompt quits like Ctrl-D instead of
                    # dumping a traceback.
                    print(file=sys.stderr)
                    user_line = None
                if user_line is None:
                    print("[bye]", file=sys.stderr)
                    break

                # Slash commands
                if user_line.startswith("/"):
                    cmd = user_line.strip().lower()
                    if cmd in ("/exit", "/quit"):
                        print("[bye]", file=sys.stderr)
                        break
                    if cmd == "/help":
                        print(_HELP, file=sys.stderr)
                        continue
                    if cmd == "/reset":
                        stale, session = session, None
                        _close_quietly(stale)
                        session = _make_session(client)
                        print("[session reset]", file=sys.stderr)
                        continue
                    if cmd == "/info":
                        try:
                            _print_session_info(session)
                        except KakeyaError as exc:
                            print(f"[info error: {exc}]", file=sys.stderr)
                        continue
                    print(f"[unknown command: {cmd}; try /help]",
                          file=sys.stderr)
                    continue

                # Tokenize the user message via the chat template — this
                # gives Qwen3 the role marker tokens, not raw text.
                new_tokens = _encode(
                    "user", user_line, add_generation_prompt=True,
                )

                try:
                    _generate_and_print(
                        session=session,
                        tokenizer=tokenizer,
                        new_tokens=new_tokens,
                        max_tokens=args.max_tokens,
                    )
                except KakeyaError as exc:
                    print(f"[runtime error: {exc}]", file=sys.stderr)
                    # Try to recover by resetting the session — the
                    # server may have evicted it.
                    stale, session = session, None
                    _close_quietly(stale)
                    session = _make_session(client)
                    print("[session re-created after error]",
                          file=sys.stderr)
        except KakeyaError as exc:
            # Every other SDK call above handles its own KakeyaError, so
            # anything reaching here came from _make_session (start-up,
            # /reset or post-error recovery). Without a session the REPL
            # cannot continue. Assumes the SDK reports such failures as
            # KakeyaError; other exception types still propagate.
            print(f"[fatal: could not open a session: {exc}]",
                  file=sys.stderr)
            return 1
        finally:
            _close_quietly(session)

    return 0
| 272 | + |
| 273 | + |
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())
0 commit comments