diff --git a/src/kimi_cli/config.py b/src/kimi_cli/config.py index 78019555a..6046aebe5 100644 --- a/src/kimi_cli/config.py +++ b/src/kimi_cli/config.py @@ -197,6 +197,10 @@ class Config(BaseModel): default="dark", description="Terminal color theme. Use 'light' for light terminal backgrounds.", ) + show_tps_meter: bool = Field( + default=False, + description="Show tokens-per-second (TPS) meter in the status bar", + ) models: dict[str, LLMModel] = Field(default_factory=dict, description="List of LLM models") providers: dict[str, LLMProvider] = Field( default_factory=dict, description="List of LLM providers" diff --git a/src/kimi_cli/soul/__init__.py b/src/kimi_cli/soul/__init__.py index 1ca049a8a..8ac837c74 100644 --- a/src/kimi_cli/soul/__init__.py +++ b/src/kimi_cli/soul/__init__.py @@ -96,6 +96,8 @@ class StatusSnapshot: """The maximum number of tokens the context can hold.""" mcp_status: MCPStatusSnapshot | None = None """The current MCP startup snapshot, if MCP is configured.""" + tps: float = 0.0 + """Current tokens-per-second rate during streaming. 0 when not streaming.""" @runtime_checkable diff --git a/src/kimi_cli/soul/kimisoul.py b/src/kimi_cli/soul/kimisoul.py index dfe99eac4..2a44bffe5 100644 --- a/src/kimi_cli/soul/kimisoul.py +++ b/src/kimi_cli/soul/kimisoul.py @@ -1,7 +1,9 @@ from __future__ import annotations import asyncio +import time import uuid +from collections import deque from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partial @@ -17,6 +19,7 @@ APIStatusError, APITimeoutError, RetryableChatProvider, + StreamedMessagePart, ) from kosong.message import Message from tenacity import RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential_jitter @@ -79,6 +82,7 @@ StepBegin, StepInterrupted, TextPart, + ThinkPart, ToolResult, TurnBegin, TurnEnd, @@ -148,6 +152,10 @@ def __init__( self._steer_queue: asyncio.Queue[str | list[ContentPart]] = asyncio.Queue() self._plan_mode: bool = self._runtime.session.state.plan_mode self._plan_session_id: str | None = self._runtime.session.state.plan_session_id + # TPS tracking for streaming tokens + self._streaming_token_timestamps: deque[tuple[float, float]] = deque() + self._streaming_token_count: float = 0.0 + self._tps_window_seconds: float = 3.0 # Pre-warm slug cache so the persisted slug survives process restarts if self._plan_session_id is not None and self._runtime.session.state.plan_slug is not None: from kimi_cli.tools.plan.heroes import seed_slug_cache @@ -380,6 +388,7 @@ def status(self) -> StatusSnapshot: context_tokens=token_count, max_context_tokens=max_size, mcp_status=self._mcp_status_snapshot(), + tps=self._calculate_tps(), ) @property @@ -428,6 +437,62 @@ def steer(self, content: str | list[ContentPart]) -> None: """Queue a steer message for injection into the current turn.""" self._steer_queue.put_nowait(content) + def _track_streaming_tokens(self, token_count: float) -> None: + """Track tokens received during streaming for TPS calculation.""" + now = time.monotonic() + self._streaming_token_count += token_count + self._streaming_token_timestamps.append((now, self._streaming_token_count)) + # Prune old entries outside the rolling window + cutoff = now - self._tps_window_seconds + while self._streaming_token_timestamps and self._streaming_token_timestamps[0][0] < cutoff: + self._streaming_token_timestamps.popleft() + + def _reset_streaming_tps(self) -> None: + """Reset TPS tracking when streaming ends or a new step begins.""" + self._streaming_token_timestamps.clear() + self._streaming_token_count = 0.0 + + def _calculate_tps(self) -> float: + """Calculate current tokens-per-second over the rolling window.""" + if len(self._streaming_token_timestamps) < 2: + return 0.0 + first_time, first_tokens = self._streaming_token_timestamps[0] + last_time, last_tokens = self._streaming_token_timestamps[-1] + duration = last_time - first_time + if duration <= 0: + return 0.0 + tokens = last_tokens - first_tokens + return tokens / duration + + @staticmethod + def _estimate_tokens_for_tps(text: str) -> float: + """Estimate token count for TPS calculation. + + Uses simple heuristics for mixed CJK/Latin text: + - CJK characters: ~1.5 tokens each + - Other characters: ~1 token per 4 characters + """ + cjk_count = 0 + other_count = 0 + for ch in text: + cp = ord(ch) + if ( + 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs + or 0x3400 <= cp <= 0x4DBF # CJK Extension A + or 0xF900 <= cp <= 0xFAFF # CJK Compatibility + or 0x3000 <= cp <= 0x303F # CJK Symbols + or 0xFF00 <= cp <= 0xFFEF # Fullwidth Forms + or 0x3040 <= cp <= 0x309F # Hiragana + or 0x30A0 <= cp <= 0x30FF # Katakana + or 0xAC00 <= cp <= 0xD7AF # Hangul Syllables + or 0x1100 <= cp <= 0x11FF # Hangul Jamo + or 0x3130 <= cp <= 0x318F # Hangul Compatibility Jamo + ): + cjk_count += 1 + else: + other_count += 1 + return cjk_count * 1.5 + other_count / 4 + async def _consume_pending_steers(self) -> bool: """Drain the steer queue and inject as follow-up user messages. @@ -691,6 +756,8 @@ async def _agent_loop(self) -> TurnOutcome: raise MaxStepsReached(self._loop_control.max_steps_per_turn) wire_send(StepBegin(n=step_no)) + # Reset TPS tracking at the start of each step + self._reset_streaming_tps() back_to_the_future: BackToTheFuture | None = None step_outcome: StepOutcome | None = None try: @@ -806,6 +873,18 @@ async def _append_notification(view: NotificationView) -> None: # Normalize: merge adjacent user messages for clean API input effective_history = normalize_history(self._context.history) + # Create a wrapped callback to track streaming tokens for TPS calculation + def _track_and_wire_send(part: StreamedMessagePart) -> None: + """Track tokens from streaming content and send to wire.""" + match part: + case TextPart(text=text) | ThinkPart(think=text): + if text: + # Estimate tokens for TPS calculation + self._track_streaming_tokens(self._estimate_tokens_for_tps(text)) + case _: + pass # Other parts don't contain tokens to track + wire_send(part) + async def _run_step_once() -> StepResult: # run an LLM step (may be interrupted) return await kosong.step( @@ -813,7 +892,7 @@ async def _run_step_once() -> StepResult: self._agent.system_prompt, self._agent.toolset, effective_history, - on_message_part=wire_send, + on_message_part=_track_and_wire_send, on_tool_result=wire_send, ) @@ -843,6 +922,7 @@ async def _kosong_step_with_retry() -> StepResult: status_update.context_usage = snap.context_usage status_update.context_tokens = snap.context_tokens status_update.max_context_tokens = snap.max_context_tokens + status_update.tps = snap.tps wire_send(status_update) # wait for all tool results (may be interrupted) diff --git a/src/kimi_cli/ui/shell/__init__.py b/src/kimi_cli/ui/shell/__init__.py index 6c72a907b..18db3352e 100644 --- a/src/kimi_cli/ui/shell/__init__.py +++ b/src/kimi_cli/ui/shell/__init__.py @@ -298,8 +298,10 @@ async def run(self, command: str | None = None) -> bool: # Initialize theme from config if isinstance(self.soul, KimiSoul): from kimi_cli.ui.theme import set_active_theme + from kimi_cli.ui.tps_meter import set_show_tps_meter set_active_theme(self.soul.runtime.config.theme) + set_show_tps_meter(self.soul.runtime.config.show_tps_meter) if command is not None: # run single command and exit @@ -984,9 +986,11 @@ def _activate_prompt_approval_modal(self) -> None: current_request, on_response=self._handle_prompt_approval_response, buffer_text_provider=( - lambda: self._prompt_session._session.default_buffer.text # pyright: ignore[reportPrivateUsage] - if self._prompt_session is not None - else "" + lambda: ( + self._prompt_session._session.default_buffer.text # pyright: ignore[reportPrivateUsage] + if self._prompt_session is not None + else "" + ) ), text_expander=self._prompt_session._get_placeholder_manager().serialize_for_history, # pyright: ignore[reportPrivateUsage] ) diff --git a/src/kimi_cli/ui/shell/prompt.py b/src/kimi_cli/ui/shell/prompt.py index 0b7870ccc..94b7a1a5d 100644 --- a/src/kimi_cli/ui/shell/prompt.py +++ b/src/kimi_cli/ui/shell/prompt.py @@ -63,6 +63,7 @@ sanitize_surrogates, ) from kimi_cli.ui.theme import get_prompt_style, get_toolbar_colors +from kimi_cli.ui.tps_meter import get_show_tps_meter from kimi_cli.utils.clipboard import ( grab_media_from_clipboard, is_clipboard_available, @@ -2152,9 +2153,12 @@ def _get_one_rotating_tip(self) -> str | None: def _render_right_span(status: StatusSnapshot) -> str: current_toast = _current_toast("right") if current_toast is None: - return format_context_status( + context_str = format_context_status( status.context_usage, status.context_tokens, status.max_context_tokens, ) + if get_show_tps_meter() and status.tps > 0: + return f"{context_str} · {status.tps:.1f} tok/s" + return context_str return current_toast.message diff --git a/src/kimi_cli/ui/shell/slash.py b/src/kimi_cli/ui/shell/slash.py index 671304e04..21687f2dc 100644 --- a/src/kimi_cli/ui/shell/slash.py +++ b/src/kimi_cli/ui/shell/slash.py @@ -645,6 +645,57 @@ def theme(app: Shell, args: str): raise Reload(session_id=soul.runtime.session.id) +@registry.command +@shell_mode_registry.command +def tps(app: Shell, args: str): + """Toggle TPS (tokens-per-second) meter display in status bar""" + from kimi_cli.ui.tps_meter import get_show_tps_meter, set_show_tps_meter + + soul = ensure_kimi_soul(app) + if soul is None: + return + + current = get_show_tps_meter() + arg = args.strip().lower() + + if not arg: + status = "on" if current else "off" + console.print(f"TPS meter: [bold]{status}[/bold]") + console.print("[grey50]Usage: /tps on | /tps off[/grey50]") + return + + if arg not in ("on", "off"): + console.print(f"[red]Invalid argument: {arg}. Use 'on' or 'off'.[/red]") + return + + new_value = arg == "on" + + if new_value == current: + console.print(f"[yellow]TPS meter is already {arg}.[/yellow]") + return + + config_file = soul.runtime.config.source_file + if config_file is None: + console.print( + "[yellow]TPS toggle requires a config file; " + "restart without --config to persist this setting.[/yellow]" + ) + return + + # Persist to disk first — only update in-memory state after success + try: + config_for_save = load_config(config_file) + config_for_save.show_tps_meter = new_value + save_config(config_for_save, config_file) + except (ConfigError, OSError) as exc: + console.print(f"[red]Failed to save config: {exc}[/red]") + return + + # Update in-memory state immediately (no reload needed for TPS) + set_show_tps_meter(new_value) + console.print(f"[green]TPS meter {'enabled' if new_value else 'disabled'}.[/green]") + + @registry.command def web(app: Shell, args: str): """Open Kimi Code Web UI in browser""" diff --git a/src/kimi_cli/ui/shell/visualize.py b/src/kimi_cli/ui/shell/visualize.py index f63fea1f5..7f9a6d009 100644 --- a/src/kimi_cli/ui/shell/visualize.py +++ b/src/kimi_cli/ui/shell/visualize.py @@ -50,6 +50,7 @@ prompt_other_input, show_question_body_in_pager, ) +from kimi_cli.ui.tps_meter import get_show_tps_meter from kimi_cli.utils.aioqueue import Queue, QueueShutDown from kimi_cli.utils.logging import logger from kimi_cli.utils.rich.columns import BulletColumns @@ -634,6 +635,7 @@ def __init__(self, initial: StatusUpdate) -> None: self._context_usage: float = 0.0 self._context_tokens: int = 0 self._max_context_tokens: int = 0 + self._tps: float = 0.0 self.update(initial) def render(self) -> RenderableType: @@ -646,12 +648,22 @@ def update(self, status: StatusUpdate) -> None: self._context_tokens = status.context_tokens if status.max_context_tokens is not None: self._max_context_tokens = status.max_context_tokens - if status.context_usage is not None: - self.text.plain = format_context_status( - self._context_usage, - self._context_tokens, - self._max_context_tokens, - ) + if status.tps is not None: + self._tps = status.tps + # Only refresh if context_usage or tps is provided (fields that affect display) + if status.context_usage is not None or status.tps is not None: + self._refresh_text() + + def _refresh_text(self) -> None: + context_str = format_context_status( + self._context_usage, + self._context_tokens, + self._max_context_tokens, + ) + if get_show_tps_meter() and self._tps > 0: + self.text.plain = f"{context_str} · {self._tps:.1f} tok/s" + else: + self.text.plain = context_str @asynccontextmanager diff --git a/src/kimi_cli/ui/tps_meter.py b/src/kimi_cli/ui/tps_meter.py new file mode 100644 index 000000000..2b9520c42 --- /dev/null +++ b/src/kimi_cli/ui/tps_meter.py @@ -0,0 +1,27 @@ +"""TPS meter display preference - mirrors the theme pattern. + +This module provides a global state for the TPS meter display setting, +similar to how theme.py manages the active color theme. +""" + +# Module-level private state +_show_tps_meter: bool = False + + +def set_show_tps_meter(enabled: bool) -> None: + """Set whether the TPS meter should be displayed in the status bar. + + Args: + enabled: True to show the TPS meter, False to hide it. + """ + global _show_tps_meter + _show_tps_meter = enabled + + +def get_show_tps_meter() -> bool: + """Get whether the TPS meter should be displayed. + + Returns: + True if the TPS meter should be shown, False otherwise. + """ + return _show_tps_meter diff --git a/src/kimi_cli/wire/types.py b/src/kimi_cli/wire/types.py index 33eb5098d..0f53c7d64 100644 --- a/src/kimi_cli/wire/types.py +++ b/src/kimi_cli/wire/types.py @@ -176,6 +176,8 @@ class StatusUpdate(BaseModel): """Whether plan mode (read-only) is active. None means no change.""" mcp_status: MCPStatusSnapshot | None = None """The current MCP startup snapshot. None means no change.""" + tps: float | None = None + """Current tokens-per-second rate during streaming. None when not streaming.""" class Notification(BaseModel): diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 4f522a0be..104c03f7b 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -26,8 +26,7 @@ def test_default_config_dump(): "default_yolo": False, "default_plan_mode": False, "default_editor": "", - "theme": "dark", - "models": {}, + "theme": "dark", "show_tps_meter": False, "models": {}, "providers": {}, "loop_control": { "max_steps_per_turn": 100, diff --git a/tests/core/test_kimisoul_tps.py b/tests/core/test_kimisoul_tps.py new file mode 100644 index 000000000..75763a136 --- /dev/null +++ b/tests/core/test_kimisoul_tps.py @@ -0,0 +1,116 @@ +"""Tests for TPS (tokens-per-second) meter functionality in KimiSoul.""" + +import time +from collections import deque +from pathlib import Path + +import pytest +from kosong.tooling.empty import EmptyToolset + +from kimi_cli.soul.agent import Agent, Runtime +from kimi_cli.soul.context import Context +from kimi_cli.soul.kimisoul import KimiSoul + + +class TestKimiSoulTPSTracking: + def test_estimate_tokens_for_tps(self): + """Token estimation uses correct heuristics for CJK and ASCII text.""" + # Empty string + assert KimiSoul._estimate_tokens_for_tps("") == 0.0 + # ASCII: 40 chars / 4 = 10 tokens + assert KimiSoul._estimate_tokens_for_tps("abcd" * 10) == pytest.approx(10.0, abs=0.1) + # CJK: 40 chars * 1.5 = 60 tokens + assert KimiSoul._estimate_tokens_for_tps("中文测试" * 10) == pytest.approx(60.0, abs=0.1) + # Mixed: 20 CJK (30) + 20 ASCII (5) = 35 + assert KimiSoul._estimate_tokens_for_tps("中文测试" * 5 + "abcd" * 5) == pytest.approx( + 35.0, abs=0.1 + ) + + def test_calculate_tps(self, runtime: Runtime, tmp_path: Path): + """TPS calculation handles edge cases and normal flow.""" + soul = self._make_soul(runtime, tmp_path) + now = time.monotonic() + + # Empty: need at least 2 timestamps + assert soul._calculate_tps() == 0.0 + + # Single timestamp: need delta + soul._streaming_token_timestamps = deque([(now, 100.0)]) + assert soul._calculate_tps() == 0.0 + + # Zero duration: same timestamp + soul._streaming_token_timestamps = deque([(now, 0.0), (now, 100.0)]) + assert soul._calculate_tps() == 0.0 + + # Normal: 300 tokens over 3 seconds = 100 tps + soul._streaming_token_timestamps = deque( + [ + (now + t, count) + for t, count in [(0.0, 0.0), (1.0, 100.0), (2.0, 200.0), (3.0, 300.0)] + ] + ) + assert soul._calculate_tps() == pytest.approx(100.0, rel=0.01) + + def test_track_streaming_tokens_and_pruning(self, runtime: Runtime, tmp_path: Path): + """Tracking accumulates tokens and prunes old entries.""" + soul = self._make_soul(runtime, tmp_path) + now = time.monotonic() + + # Add entries: one old (outside 3s window), one recent + soul._streaming_token_timestamps.append((now - 4.0, 0.0)) + soul._streaming_token_timestamps.append((now - 1.0, 100.0)) + soul._streaming_token_count = 100.0 + + # Track new tokens - should trigger pruning of old entry + soul._track_streaming_tokens(50.0) + + # Should have 2 entries (recent + new), old pruned + assert len(soul._streaming_token_timestamps) == 2 + assert soul._streaming_token_timestamps[0][0] > now - 3.5 + assert soul._streaming_token_count == 150.0 + + def test_reset_streaming_tps(self, runtime: Runtime, tmp_path: Path): + """Reset clears timestamps and token count.""" + soul = self._make_soul(runtime, tmp_path) + soul._streaming_token_timestamps = deque( + [ + (time.monotonic(), 100.0), + (time.monotonic(), 200.0), + ] + ) + soul._streaming_token_count = 200.0 + + soul._reset_streaming_tps() + + assert len(soul._streaming_token_timestamps) == 0 + assert soul._streaming_token_count == 0.0 + + def test_status_tps(self, runtime: Runtime, tmp_path: Path): + """Status snapshot includes TPS when streaming, zero otherwise.""" + soul = self._make_soul(runtime, tmp_path) + + # Not streaming: TPS should be 0.0 + status = soul.status + assert status.tps == 0.0 + + # Streaming: TPS calculated from timestamps + now = time.monotonic() + soul._streaming_token_timestamps = deque( + [ + (now - 2.0, 0.0), + (now - 1.0, 50.0), + (now, 100.0), + ] + ) + status = soul.status + assert status.tps == pytest.approx(50.0, rel=0.1) # 100 tokens / 2 seconds + + @staticmethod + def _make_soul(runtime: Runtime, tmp_path: Path) -> KimiSoul: + agent = Agent( + name="TPS Test Agent", + system_prompt="Test prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + return KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) diff --git a/tests/core/test_wire_message.py b/tests/core/test_wire_message.py index b7b7674ac..0193766e1 100644 --- a/tests/core/test_wire_message.py +++ b/tests/core/test_wire_message.py @@ -162,8 +162,7 @@ async def test_wire_message_serde(): "tools": [], } ], - }, - }, + }, "tps": None}, } ) _test_serde(msg) diff --git a/tests/ui_and_conv/test_tps_display.py b/tests/ui_and_conv/test_tps_display.py new file mode 100644 index 000000000..30e28e24d --- /dev/null +++ b/tests/ui_and_conv/test_tps_display.py @@ -0,0 +1,91 @@ +"""Tests for TPS meter display conditionals.""" + +from pathlib import Path +from types import SimpleNamespace + +import pytest +from kosong.tooling.empty import EmptyToolset + +from kimi_cli.soul.agent import Agent, Runtime +from kimi_cli.soul.context import Context +from kimi_cli.soul.kimisoul import KimiSoul +from kimi_cli.ui.shell.prompt import CustomPromptSession +from kimi_cli.ui.shell.visualize import _StatusBlock +from kimi_cli.ui.tps_meter import set_show_tps_meter +from kimi_cli.wire.types import StatusUpdate + + +@pytest.fixture(autouse=True) +def _reset_tps_meter(): + set_show_tps_meter(False) + yield + set_show_tps_meter(False) + + +def _make_status_snapshot(tps: float = 0.0) -> SimpleNamespace: + return SimpleNamespace( + context_usage=0.5, + context_tokens=5000, + max_context_tokens=10000, + tps=tps, + ) + + +def test_render_right_span_shows_tps_when_enabled(): + """_render_right_span includes TPS when enabled and TPS > 0.""" + set_show_tps_meter(True) + status = _make_status_snapshot(tps=12.3) + + result = CustomPromptSession._render_right_span(status) + + assert "12.3" in result or "tok/s" in result + + +def test_render_right_span_hides_tps_when_not_shown(): + """_render_right_span hides TPS when disabled or TPS is 0.""" + # When disabled (even with TPS > 0) + set_show_tps_meter(False) + status = _make_status_snapshot(tps=12.3) + result = CustomPromptSession._render_right_span(status) + assert "tok/s" not in result + + # When enabled but TPS is 0 + set_show_tps_meter(True) + status = _make_status_snapshot(tps=0.0) + result = CustomPromptSession._render_right_span(status) + assert "tok/s" not in result + + +def test_status_block_shows_tps_when_enabled(): + """_StatusBlock includes TPS when enabled and TPS > 0.""" + set_show_tps_meter(True) + status_update = StatusUpdate( + context_usage=0.5, + context_tokens=5000, + max_context_tokens=10000, + tps=15.5, + ) + + block = _StatusBlock(status_update) + + assert "15.5" in block.text.plain or "tok/s" in block.text.plain + + +def test_status_block_hides_tps_when_not_shown(): + """_StatusBlock hides TPS when disabled or TPS is 0.""" + # When disabled (even with TPS > 0) + set_show_tps_meter(False) + status_update = StatusUpdate( + context_usage=0.5, + context_tokens=5000, + max_context_tokens=10000, + tps=15.5, + ) + block = _StatusBlock(status_update) + assert "tok/s" not in block.text.plain + + # When enabled but TPS is 0 + set_show_tps_meter(True) + status_update = StatusUpdate(context_usage=0.5, tps=0.0) + block = _StatusBlock(status_update) + assert "tok/s" not in block.text.plain diff --git a/tests/ui_and_conv/test_tps_meter.py b/tests/ui_and_conv/test_tps_meter.py new file mode 100644 index 000000000..81bd18b24 --- /dev/null +++ b/tests/ui_and_conv/test_tps_meter.py @@ -0,0 +1,21 @@ +"""Tests for TPS meter UI state module.""" + +import pytest + +from kimi_cli.ui.tps_meter import get_show_tps_meter, set_show_tps_meter + + +@pytest.fixture(autouse=True) +def _reset_tps_meter(): + set_show_tps_meter(False) + yield + set_show_tps_meter(False) + + +class TestTpsMeterState: + def test_get_and_set_show_tps_meter(self): + assert get_show_tps_meter() is False + set_show_tps_meter(True) + assert get_show_tps_meter() is True + set_show_tps_meter(False) + assert get_show_tps_meter() is False diff --git a/tests/ui_and_conv/test_tps_slash.py b/tests/ui_and_conv/test_tps_slash.py new file mode 100644 index 000000000..ed80342b2 --- /dev/null +++ b/tests/ui_and_conv/test_tps_slash.py @@ -0,0 +1,187 @@ +"""Tests for /tps slash command.""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from typing import cast +from unittest.mock import Mock + +import pytest +from kosong.tooling.empty import EmptyToolset + +from kimi_cli.config import get_default_config +from kimi_cli.exception import ConfigError +from kimi_cli.soul.agent import Agent, Runtime +from kimi_cli.soul.context import Context +from kimi_cli.soul.kimisoul import KimiSoul +from kimi_cli.ui.shell import Shell +from kimi_cli.ui.shell import slash as shell_slash +from kimi_cli.ui.tps_meter import get_show_tps_meter, set_show_tps_meter + + +@pytest.fixture(autouse=True) +def _reset_tps_meter(): + set_show_tps_meter(False) + yield + set_show_tps_meter(False) + + +def _make_shell_app(runtime: Runtime, tmp_path: Path) -> SimpleNamespace: + agent = Agent( + name="Test Agent", + system_prompt="Test system prompt.", + toolset=EmptyToolset(), + runtime=runtime, + ) + soul = KimiSoul(agent, context=Context(file_backend=tmp_path / "history.jsonl")) + return SimpleNamespace(soul=soul) + + +def test_tps_command_registered_in_both_registries(): + """/tps should be available in both agent and shell registries.""" + from kimi_cli.ui.shell.slash import registry, shell_mode_registry + + agent_cmds = {c.name for c in registry.list_commands()} + shell_cmds = {c.name for c in shell_mode_registry.list_commands()} + assert "tps" in agent_cmds + assert "tps" in shell_cmds + + +def test_tps_no_args_shows_current(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps with no args should show current status.""" + app = _make_shell_app(runtime, tmp_path) + print_mock = Mock() + monkeypatch.setattr(shell_slash.console, "print", print_mock) + + set_show_tps_meter(False) + shell_slash.tps(cast(Shell, app), "") + + assert print_mock.call_count == 2 + assert "off" in str(print_mock.call_args_list[0].args[0]).lower() + + +def test_tps_invalid_arg(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps with invalid arg should show error.""" + app = _make_shell_app(runtime, tmp_path) + print_mock = Mock() + monkeypatch.setattr(shell_slash.console, "print", print_mock) + + shell_slash.tps(cast(Shell, app), "invalid") + + assert "Invalid argument" in str(print_mock.call_args.args[0]) + + +def test_tps_same_as_current(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps with same value should show 'already' message.""" + app = _make_shell_app(runtime, tmp_path) + print_mock = Mock() + monkeypatch.setattr(shell_slash.console, "print", print_mock) + + set_show_tps_meter(False) + shell_slash.tps(cast(Shell, app), "off") + + assert "already" in str(print_mock.call_args.args[0]).lower() + + +def test_tps_on_enables_and_persists(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps on should enable meter and persist to config.""" + config_path = (tmp_path / "config.toml").resolve() + runtime.config.source_file = config_path + app = _make_shell_app(runtime, tmp_path) + + config_for_save = get_default_config() + load_mock = Mock(return_value=config_for_save) + save_mock = Mock() + monkeypatch.setattr(shell_slash, "load_config", load_mock) + monkeypatch.setattr(shell_slash, "save_config", save_mock) + monkeypatch.setattr(shell_slash.console, "print", Mock()) + + set_show_tps_meter(False) + shell_slash.tps(cast(Shell, app), "on") + + load_mock.assert_called_once_with(config_path) + save_mock.assert_called_once() + assert config_for_save.show_tps_meter is True + assert get_show_tps_meter() is True # In-memory state updated + + +def test_tps_off_disables_and_persists(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps off should disable meter and persist to config.""" + config_path = (tmp_path / "config.toml").resolve() + runtime.config.source_file = config_path + app = _make_shell_app(runtime, tmp_path) + + config_for_save = get_default_config() + config_for_save.show_tps_meter = True + load_mock = Mock(return_value=config_for_save) + save_mock = Mock() + monkeypatch.setattr(shell_slash, "load_config", load_mock) + monkeypatch.setattr(shell_slash, "save_config", save_mock) + monkeypatch.setattr(shell_slash.console, "print", Mock()) + + set_show_tps_meter(True) + shell_slash.tps(cast(Shell, app), "off") + + assert config_for_save.show_tps_meter is False + assert get_show_tps_meter() is False + + +def test_tps_save_failure_no_state_change(runtime: Runtime, tmp_path: Path, monkeypatch): + """If save fails, in-memory state should not change.""" + config_path = (tmp_path / "config.toml").resolve() + runtime.config.source_file = config_path + app = _make_shell_app(runtime, tmp_path) + + set_show_tps_meter(False) + + load_mock = Mock(side_effect=ConfigError("Disk full")) + monkeypatch.setattr(shell_slash, "load_config", load_mock) + save_mock = Mock() + monkeypatch.setattr(shell_slash, "save_config", save_mock) + print_mock = Mock() + monkeypatch.setattr(shell_slash.console, "print", print_mock) + + shell_slash.tps(cast(Shell, app), "on") + + assert get_show_tps_meter() is False # Unchanged + save_mock.assert_not_called() + assert "Failed" in str(print_mock.call_args.args[0]) + + +def test_tps_rejects_inline_config(runtime: Runtime, tmp_path: Path, monkeypatch): + """/tps should warn when config file is None (inline config).""" + runtime.config.source_file = None + app = _make_shell_app(runtime, tmp_path) + print_mock = Mock() + monkeypatch.setattr(shell_slash.console, "print", print_mock) + + shell_slash.tps(cast(Shell, app), "on") + + assert "config file" in str(print_mock.call_args.args[0]).lower() + + +def test_tps_whitespace_and_case_handling(runtime: Runtime, tmp_path: Path, monkeypatch): + """Arguments are stripped and lowercased: ' ON ' should work.""" + config_path = (tmp_path / "config.toml").resolve() + runtime.config.source_file = config_path + app = _make_shell_app(runtime, tmp_path) + + config_for_save = get_default_config() + load_mock = Mock(return_value=config_for_save) + save_mock = Mock() + monkeypatch.setattr(shell_slash, "load_config", load_mock) + monkeypatch.setattr(shell_slash, "save_config", save_mock) + monkeypatch.setattr(shell_slash.console, "print", Mock()) + + set_show_tps_meter(False) + # Test with extra whitespace and uppercase + shell_slash.tps(cast(Shell, app), " ON ") + + assert config_for_save.show_tps_meter is True + assert get_show_tps_meter() is True + + # Test OFF with mixed case + shell_slash.tps(cast(Shell, app), "Off") + assert config_for_save.show_tps_meter is False + assert get_show_tps_meter() is False diff --git a/tests_e2e/test_wire_protocol.py b/tests_e2e/test_wire_protocol.py index b1e042e87..1163cf587 100644 --- a/tests_e2e/test_wire_protocol.py +++ b/tests_e2e/test_wire_protocol.py @@ -325,8 +325,7 @@ def handle_request(msg: dict[str, Any]) -> dict[str, Any]: "token_usage": None, "message_id": None, "plan_mode": False, - "mcp_status": None, - }, + "mcp_status": None, "tps": None}, }, { "method": "request", @@ -367,8 +366,7 @@ def handle_request(msg: dict[str, Any]) -> dict[str, Any]: "token_usage": None, "message_id": None, "plan_mode": False, - "mcp_status": None, - }, + "mcp_status": None, "tps": None}, }, {"method": "event", "type": "TurnEnd", "payload": {}}, ] @@ -419,8 +417,7 @@ def test_prompt_without_initialize(tmp_path) -> None: "token_usage": None, "message_id": None, "plan_mode": False, - "mcp_status": None, - }, + "mcp_status": None, "tps": None}, }, {"method": "event", "type": "TurnEnd", "payload": {}}, ]