diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index 1be8cff8..1340638f 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -731,22 +731,37 @@ class SandboxSettings(TypedDict, total=False): # Content block types -@dataclass + + +def _truncate(text: str, max_length: int = 80) -> str: + """Truncate text for repr display.""" + if len(text) <= max_length: + return text + return text[: max_length - 3] + "..." + + +@dataclass(repr=False) class TextBlock: """Text content block.""" text: str + def __repr__(self) -> str: + return f"TextBlock(text={_truncate(self.text)!r})" -@dataclass + +@dataclass(repr=False) class ThinkingBlock: """Thinking content block.""" thinking: str signature: str + def __repr__(self) -> str: + return f"ThinkingBlock(thinking={_truncate(self.thinking)!r})" -@dataclass + +@dataclass(repr=False) class ToolUseBlock: """Tool use content block.""" @@ -754,8 +769,11 @@ class ToolUseBlock: name: str input: dict[str, Any] + def __repr__(self) -> str: + return f"ToolUseBlock(id={self.id!r}, name={self.name!r})" -@dataclass + +@dataclass(repr=False) class ToolResultBlock: """Tool result content block.""" @@ -763,6 +781,17 @@ class ToolResultBlock: content: str | list[dict[str, Any]] | None = None is_error: bool | None = None + def __repr__(self) -> str: + parts = [f"tool_use_id={self.tool_use_id!r}"] + if self.content is not None: + if isinstance(self.content, str): + parts.append(f"content={_truncate(self.content)!r}") + else: + parts.append(f"content={self.content!r}") + if self.is_error: + parts.append("is_error=True") + return f"ToolResultBlock({', '.join(parts)})" + ContentBlock = TextBlock | ThinkingBlock | ToolUseBlock | ToolResultBlock @@ -778,7 +807,7 @@ class ToolResultBlock: ] -@dataclass +@dataclass(repr=False) class UserMessage: """User message.""" @@ -787,8 +816,13 @@ class UserMessage: parent_tool_use_id: str | None = None tool_use_result: dict[str, Any] | None = None + def __repr__(self) -> str: + if isinstance(self.content, str): + return f"UserMessage(content={_truncate(self.content)!r})" + return f"UserMessage(content={self.content!r})" -@dataclass + +@dataclass(repr=False) class AssistantMessage: """Assistant message with content blocks.""" @@ -798,14 +832,23 @@ class AssistantMessage: error: AssistantMessageError | None = None usage: dict[str, Any] | None = None + def __repr__(self) -> str: + parts = [f"model={self.model!r}", f"content={self.content!r}"] + if self.error is not None: + parts.append(f"error={self.error!r}") + return f"AssistantMessage({', '.join(parts)})" -@dataclass + +@dataclass(repr=False) class SystemMessage: """System message with metadata.""" subtype: str data: dict[str, Any] + def __repr__(self) -> str: + return f"SystemMessage(subtype={self.subtype!r})" + class TaskUsage(TypedDict): """Usage statistics reported in task_progress and task_notification messages.""" @@ -819,7 +862,7 @@ class TaskUsage(TypedDict): TaskNotificationStatus = Literal["completed", "failed", "stopped"] -@dataclass +@dataclass(repr=False) class TaskStartedMessage(SystemMessage): """System message emitted when a task starts. @@ -835,8 +878,11 @@ class TaskStartedMessage(SystemMessage): tool_use_id: str | None = None task_type: str | None = None + def __repr__(self) -> str: + return f"TaskStartedMessage(task_id={self.task_id!r}, description={_truncate(self.description)!r})" -@dataclass + +@dataclass(repr=False) class TaskProgressMessage(SystemMessage): """System message emitted while a task is in progress. @@ -853,8 +899,11 @@ class TaskProgressMessage(SystemMessage): tool_use_id: str | None = None last_tool_name: str | None = None + def __repr__(self) -> str: + return f"TaskProgressMessage(task_id={self.task_id!r}, description={_truncate(self.description)!r})" -@dataclass + +@dataclass(repr=False) class TaskNotificationMessage(SystemMessage): """System message emitted when a task completes, fails, or is stopped. @@ -872,8 +921,13 @@ class TaskNotificationMessage(SystemMessage): tool_use_id: str | None = None usage: TaskUsage | None = None + def __repr__(self) -> str: + return ( + f"TaskNotificationMessage(task_id={self.task_id!r}, status={self.status!r})" + ) -@dataclass + +@dataclass(repr=False) class ResultMessage: """Result message with cost and usage information.""" @@ -889,8 +943,18 @@ class ResultMessage: result: str | None = None structured_output: Any = None + def __repr__(self) -> str: + parts = [f"num_turns={self.num_turns}"] + if self.is_error: + parts.append("is_error=True") + if self.total_cost_usd is not None: + parts.append(f"total_cost_usd={self.total_cost_usd}") + if self.stop_reason is not None: + parts.append(f"stop_reason={self.stop_reason!r}") + return f"ResultMessage({', '.join(parts)})" -@dataclass + +@dataclass(repr=False) class StreamEvent: """Stream event for partial message updates during streaming.""" @@ -899,6 +963,9 @@ class StreamEvent: event: dict[str, Any] # The raw Anthropic API stream event parent_tool_use_id: str | None = None + def __repr__(self) -> str: + return f"StreamEvent(session_id={self.session_id!r})" + # Rate limit types — see https://docs.claude.com/en/docs/claude-code/rate-limits RateLimitStatus = Literal["allowed", "allowed_warning", "rejected"] @@ -907,7 +974,7 @@ class StreamEvent: ] -@dataclass +@dataclass(repr=False) class RateLimitInfo: """Rate limit status emitted by the CLI when rate limit state changes. @@ -932,8 +999,16 @@ class RateLimitInfo: overage_disabled_reason: str | None = None raw: dict[str, Any] = field(default_factory=dict) + def __repr__(self) -> str: + parts = [f"status={self.status!r}"] + if self.utilization is not None: + parts.append(f"utilization={self.utilization}") + if self.rate_limit_type is not None: + parts.append(f"rate_limit_type={self.rate_limit_type!r}") + return f"RateLimitInfo({', '.join(parts)})" -@dataclass + +@dataclass(repr=False) class RateLimitEvent: """Rate limit event emitted when rate limit info changes. @@ -946,6 +1021,9 @@ class RateLimitEvent: uuid: str session_id: str + def __repr__(self) -> str: + return f"RateLimitEvent(status={self.rate_limit_info.status!r})" + Message = ( UserMessage diff --git a/tests/test_types.py b/tests/test_types.py index 3b1bfb67..aed2f798 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -483,3 +483,258 @@ def test_mcp_servers_serializes_as_camelcase(self): assert "mcp_servers" not in payload assert payload["mcpServers"][0] == "slack" assert payload["mcpServers"][1]["local"]["command"] == "python" + + +class TestTruncateHelper: + def test_short_text_unchanged(self): + from claude_agent_sdk.types import _truncate + + assert _truncate("hello") == "hello" + + def test_exact_limit_unchanged(self): + from claude_agent_sdk.types import _truncate + + assert _truncate("a" * 80) == "a" * 80 + + def test_long_text_truncated(self): + from claude_agent_sdk.types import _truncate + + result = _truncate("a" * 100) + assert len(result) == 80 + assert result.endswith("...") + assert result == "a" * 77 + "..." + + def test_empty_string(self): + from claude_agent_sdk.types import _truncate + + assert _truncate("") == "" + + def test_custom_max_length(self): + from claude_agent_sdk.types import _truncate + + result = _truncate("a" * 20, max_length=10) + assert result == "a" * 7 + "..." + + +class TestContentBlockRepr: + def test_text_block_short(self): + block = TextBlock(text="Hello world") + assert repr(block) == "TextBlock(text='Hello world')" + + def test_text_block_truncated(self): + block = TextBlock(text="a" * 100) + r = repr(block) + assert r.startswith("TextBlock(text='") + assert "..." in r + assert r.endswith("')") + + def test_thinking_block(self): + block = ThinkingBlock(thinking="deep thoughts", signature="sig123") + r = repr(block) + assert "ThinkingBlock(thinking=" in r + assert "signature" not in r + + def test_tool_use_block(self): + block = ToolUseBlock( + id="toolu_abc", name="Read", input={"file_path": "/test.py"} + ) + r = repr(block) + assert r == "ToolUseBlock(id='toolu_abc', name='Read')" + + def test_tool_result_block(self): + block = ToolResultBlock(tool_use_id="toolu_abc", content="result text") + r = repr(block) + assert "ToolResultBlock(tool_use_id='toolu_abc'" in r + assert "content='result text'" in r + + def test_tool_result_block_error(self): + block = ToolResultBlock(tool_use_id="toolu_abc", content="err", is_error=True) + r = repr(block) + assert "is_error=True" in r + + def test_tool_result_block_no_content(self): + block = ToolResultBlock(tool_use_id="toolu_abc") + r = repr(block) + assert "None" not in r + + def test_tool_result_block_list_content(self): + block = ToolResultBlock( + tool_use_id="toolu_abc", + content=[{"type": "text", "text": "hi"}], + ) + r = repr(block) + assert "ToolResultBlock(tool_use_id='toolu_abc'" in r + assert "content=[" in r + + def test_text_block_with_quotes(self): + """Ensure repr properly escapes quotes in text.""" + block = TextBlock(text='it\'s a "test"') + r = repr(block) + assert "TextBlock(text=" in r + # !r formatting should produce a valid Python string literal + assert r.count("TextBlock") == 1 + + def test_text_block_with_newlines(self): + """Ensure repr properly handles newlines.""" + block = TextBlock(text="line1\nline2") + r = repr(block) + assert "\\n" in r + + def test_text_block_with_backslashes(self): + """Ensure repr properly handles backslashes.""" + block = TextBlock(text="path\\to\\file") + r = repr(block) + assert "TextBlock(text=" in r + + +class TestMessageRepr: + def test_user_message_string_content(self): + msg = UserMessage(content="Hello") + r = repr(msg) + assert r == "UserMessage(content='Hello')" + + def test_user_message_blocks(self): + msg = UserMessage(content=[TextBlock(text="Hi")]) + r = repr(msg) + assert r == "UserMessage(content=[TextBlock(text='Hi')])" + + def test_assistant_message(self): + msg = AssistantMessage( + content=[TextBlock(text="Hello!")], + model="claude-opus-4-6", + ) + r = repr(msg) + assert "AssistantMessage(model='claude-opus-4-6'" in r + assert "TextBlock(text='Hello!')" in r + + def test_assistant_message_with_error(self): + msg = AssistantMessage( + content=[TextBlock(text="err")], + model="claude-opus-4-6", + error="server_error", + ) + r = repr(msg) + assert "error='server_error'" in r + + def test_assistant_message_omits_none_fields(self): + msg = AssistantMessage( + content=[TextBlock(text="Hi")], + model="claude-opus-4-6", + ) + r = repr(msg) + assert "parent_tool_use_id" not in r + assert "error" not in r + assert "usage" not in r + + +class TestSystemMessageRepr: + def test_system_message(self): + from claude_agent_sdk.types import SystemMessage + + msg = SystemMessage(subtype="init", data={"key": "value"}) + assert repr(msg) == "SystemMessage(subtype='init')" + + def test_task_started(self): + from claude_agent_sdk.types import TaskStartedMessage + + msg = TaskStartedMessage( + subtype="task_started", + data={}, + task_id="t1", + description="Running tests", + uuid="u1", + session_id="s1", + ) + r = repr(msg) + assert "TaskStartedMessage(task_id='t1'" in r + assert "description='Running tests'" in r + + def test_task_progress(self): + from claude_agent_sdk.types import TaskProgressMessage + + msg = TaskProgressMessage( + subtype="task_progress", + data={}, + task_id="t1", + description="In progress", + usage={"total_tokens": 100, "tool_uses": 1, "duration_ms": 500}, + uuid="u1", + session_id="s1", + ) + r = repr(msg) + assert "TaskProgressMessage(task_id='t1'" in r + + def test_task_notification(self): + from claude_agent_sdk.types import TaskNotificationMessage + + msg = TaskNotificationMessage( + subtype="task_notification", + data={}, + task_id="t1", + status="completed", + output_file="/tmp/out", + summary="Done", + uuid="u1", + session_id="s1", + ) + r = repr(msg) + assert "TaskNotificationMessage(task_id='t1'" in r + assert "status='completed'" in r + + +class TestResultAndEventRepr: + def test_result_message(self): + msg = ResultMessage( + subtype="result", + duration_ms=1234, + duration_api_ms=1000, + is_error=False, + num_turns=3, + session_id="s1", + total_cost_usd=0.05, + stop_reason="end_turn", + ) + r = repr(msg) + assert "ResultMessage(" in r + assert "num_turns=3" in r + assert "total_cost_usd=0.05" in r + assert "stop_reason='end_turn'" in r + + def test_result_message_error(self): + msg = ResultMessage( + subtype="result", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=1, + session_id="s1", + ) + r = repr(msg) + assert "is_error=True" in r + + def test_stream_event(self): + from claude_agent_sdk.types import StreamEvent + + msg = StreamEvent( + uuid="u1", + session_id="s1", + event={"type": "content_block_delta"}, + ) + r = repr(msg) + assert r == "StreamEvent(session_id='s1')" + + def test_rate_limit_info(self): + from claude_agent_sdk.types import RateLimitInfo + + info = RateLimitInfo(status="allowed", utilization=0.5) + r = repr(info) + assert "RateLimitInfo(status='allowed'" in r + assert "utilization=0.5" in r + + def test_rate_limit_event(self): + from claude_agent_sdk.types import RateLimitEvent, RateLimitInfo + + info = RateLimitInfo(status="rejected") + event = RateLimitEvent(rate_limit_info=info, uuid="u1", session_id="s1") + r = repr(event) + assert r == "RateLimitEvent(status='rejected')"