Skip to content

Commit 5e31ae4

Browse files
committed
test: add more unit tests for uncovered patterns
1 parent ceb238f commit 5e31ae4

6 files changed

Lines changed: 538 additions & 2 deletions

File tree

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from agents.extensions.experimental.codex.items import AgentMessageItem, TodoItem, TodoListItem
6+
7+
8+
def test_dict_like_supports_mapping_access_for_dataclass_fields() -> None:
    """Dataclass items expose their fields through a read-only mapping API."""
    message = AgentMessageItem(id="item-1", text="hello")

    # Subscript lookup covers every declared field, including the type tag.
    assert message["id"] == "item-1"
    assert message["text"] == "hello"
    assert message["type"] == "agent_message"
    # `get` mirrors dict.get(), falling back to the supplied default.
    assert message.get("text") == "hello"
    assert message.get("missing", "fallback") == "fallback"
    # Membership accepts field-name strings and rejects arbitrary objects.
    assert "id" in message
    assert "missing" not in message
    assert object() not in message
    # keys() enumerates the fields in declaration order.
    assert list(message.keys()) == ["id", "text", "type"]
20+
21+
22+
def test_dict_like_raises_key_error_for_unknown_fields() -> None:
    """Subscripting a name that is not a field raises KeyError."""
    message = AgentMessageItem(id="item-1", text="hello")

    with pytest.raises(KeyError, match="missing"):
        _ = message["missing"]
27+
28+
29+
def test_dict_like_as_dict_recursively_converts_nested_dataclasses() -> None:
    """as_dict() converts nested dataclass children into plain dicts."""
    todo_list = TodoListItem(
        id="todo-list-1",
        items=[
            TodoItem(text="write tests", completed=True),
            TodoItem(text="run tests", completed=False),
        ],
    )

    expected = {
        "id": "todo-list-1",
        "items": [
            {"text": "write tests", "completed": True},
            {"text": "run tests", "completed": False},
        ],
        "type": "todo_list",
    }
    assert todo_list.as_dict() == expected

tests/sandbox/test_session_state_roundtrip.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from pathlib import Path
1313
from typing import Literal
1414

15+
import pytest
16+
from pydantic import ValidationError
17+
1518
from agents.sandbox import Manifest
1619
from agents.sandbox.session import SandboxSessionState
1720
from agents.sandbox.snapshot import LocalSnapshot
@@ -27,6 +30,21 @@ class _StubSessionState(SandboxSessionState):
2730
custom_field: str
2831

2932

33+
class _PlainTypeSessionState(SandboxSessionState):
    """State whose ``type`` default is a plain ``str`` rather than a ``Literal``.

    Exercises the registration path that skips non-Literal type defaults.
    """

    __test__ = False  # keep pytest from collecting this helper as a test class
    type: str = "plain-type"
36+
37+
38+
class _EmptyDefaultSessionState(SandboxSessionState):
    """State whose ``Literal`` ``type`` default is the empty string.

    Exercises the registration path that skips empty type defaults.
    """

    __test__ = False  # keep pytest from collecting this helper as a test class
    type: Literal[""] = ""
41+
42+
43+
class _SimpleSessionState(SandboxSessionState):
    """Minimal registered subclass used for ``parse()`` round-trip tests."""

    __test__ = False  # keep pytest from collecting this helper as a test class
    type: Literal["simple-roundtrip"] = "simple-roundtrip"
46+
47+
3048
# ---------------------------------------------------------------------------
3149
# Helpers
3250
# ---------------------------------------------------------------------------
@@ -93,3 +111,80 @@ def test_model_dump_preserves_snapshot_subclass_fields(self) -> None:
93111
dumped = state.model_dump()
94112

95113
assert "base_path" in dumped["snapshot"]
114+
115+
def test_parse_returns_subclass_instances_as_is(self) -> None:
    """parse() is an identity operation for existing subclass instances."""
    existing = _make_session_state()

    assert SandboxSessionState.parse(existing) is existing
119+
120+
def test_parse_upgrades_base_instance_through_registry(self) -> None:
    """parse() re-materializes a base-class instance as its registered subclass."""
    session_id = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
    original = _SimpleSessionState(
        session_id=session_id,
        snapshot=LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots")),
        manifest=Manifest(),
    )
    # Re-validating the dump against the base class drops the subclass type.
    downgraded = SandboxSessionState.model_validate(original.model_dump())

    upgraded = SandboxSessionState.parse(downgraded)

    assert type(upgraded) is _SimpleSessionState
    assert upgraded.session_id == session_id
132+
133+
@pytest.mark.parametrize(
    ("payload", "error_type", "message"),
    [
        ({}, ValueError, "must include a string `type`"),
        ({"type": "missing"}, ValueError, "unknown sandbox session state type `missing`"),
        ("not-a-state", TypeError, "session state payload must be"),
    ],
)
def test_parse_rejects_invalid_payloads(
    self, payload: object, error_type: type[Exception], message: str
) -> None:
    """parse() raises a descriptive error for malformed payloads."""
    with pytest.raises(error_type, match=message):
        SandboxSessionState.parse(payload)
149+
150+
def test_subclass_registration_skips_non_literal_or_empty_type_defaults(self) -> None:
    """Subclasses without a usable Literal `type` default stay unregistered."""
    registry = SandboxSessionState._subclass_registry

    assert "plain-type" not in registry
    assert "" not in registry
153+
154+
@pytest.mark.parametrize(
    ("raw_ports", "expected"),
    [
        (None, ()),
        (8080, (8080,)),
        ([8080, 9000, 8080], (8080, 9000)),
    ],
)
def test_exposed_ports_are_normalized(
    self, raw_ports: object, expected: tuple[int, ...]
) -> None:
    """exposed_ports is coerced to an order-preserving, de-duplicated int tuple."""
    built = _StubSessionState(
        snapshot=LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots")),
        manifest=Manifest(),
        custom_field="my-value",
        exposed_ports=raw_ports,  # type: ignore[arg-type]
    )

    assert built.exposed_ports == expected
173+
174+
@pytest.mark.parametrize(
    ("raw_ports", "message"),
    [
        ("8080", "exposed_ports must be an iterable"),
        ([8080, "9000"], "exposed_ports must contain integers"),
        ([0], "exposed_ports entries must be between 1 and 65535"),
        ([65536], "exposed_ports entries must be between 1 and 65535"),
    ],
)
def test_exposed_ports_reject_invalid_values(self, raw_ports: object, message: str) -> None:
    """Non-iterable, non-int, or out-of-range ports are rejected at construction."""
    snapshot = LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots"))

    with pytest.raises((TypeError, ValidationError), match=message):
        _StubSessionState(
            snapshot=snapshot,
            manifest=Manifest(),
            custom_field="my-value",
            exposed_ports=raw_ports,  # type: ignore[arg-type]
        )
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
from agents.sandbox.util.token_truncation import (
4+
TruncationPolicy,
5+
approx_bytes_for_tokens,
6+
approx_token_count,
7+
approx_tokens_from_byte_count,
8+
format_truncation_marker,
9+
formatted_truncate_text,
10+
formatted_truncate_text_with_token_count,
11+
removed_units_for_source,
12+
split_budget,
13+
split_string,
14+
truncate_text,
15+
truncate_with_byte_estimate,
16+
truncate_with_token_budget,
17+
)
18+
19+
20+
def test_truncation_policy_clamps_negative_limits_and_converts_budgets() -> None:
    """Negative limits clamp to zero in both byte and token modes."""
    # Same checks for both modes; byte policy first to match the original order.
    for policy in (TruncationPolicy.bytes(-10), TruncationPolicy.tokens(-2)):
        assert policy.limit == 0
        assert policy.token_budget() == 0
        assert policy.byte_budget() == 0
30+
31+
32+
def test_formatted_truncate_text_returns_short_content_unchanged() -> None:
    """Content below the byte budget passes through untouched."""
    roomy_policy = TruncationPolicy.bytes(20)
    assert formatted_truncate_text("short", roomy_policy) == "short"
34+
35+
36+
def test_formatted_truncate_text_adds_line_count_when_truncated() -> None:
    """Truncated output is prefixed with the original line count."""
    truncated = formatted_truncate_text("alpha\nbeta\ngamma", TruncationPolicy.bytes(8))

    assert truncated.startswith("Total output lines: 3\n\n")
    assert "chars truncated" in truncated
41+
42+
43+
def test_formatted_truncate_text_with_token_count_handles_none_and_short_content() -> None:
    """A missing budget or within-budget content yields the input and no count."""
    no_budget = formatted_truncate_text_with_token_count("short", None)
    roomy_budget = formatted_truncate_text_with_token_count("short", 10)

    assert no_budget == ("short", None)
    assert roomy_budget == ("short", None)
46+
47+
48+
def test_formatted_truncate_text_with_token_count_reports_original_count() -> None:
    """When truncation happens, the pre-truncation token count is reported."""
    text, reported_tokens = formatted_truncate_text_with_token_count("abcdefghi", 1)

    assert text.startswith("Total output lines: 1\n\n")
    assert "tokens truncated" in text
    assert reported_tokens == approx_token_count("abcdefghi")
54+
55+
56+
def test_truncate_text_dispatches_byte_and_token_modes() -> None:
    """truncate_text routes to the byte or token path based on the policy mode."""
    byte_result = truncate_text("abcdef", TruncationPolicy.bytes(4))
    token_result = truncate_text("abcdefghi", TruncationPolicy.tokens(1))

    assert byte_result.startswith("a")
    assert "tokens truncated" in token_result
59+
60+
61+
def test_truncate_with_token_budget_handles_empty_and_short_content() -> None:
    """Empty or within-budget content is returned unchanged with no count."""
    one_token = TruncationPolicy.tokens(1)

    assert truncate_with_token_budget("", one_token) == ("", None)
    assert truncate_with_token_budget("abc", one_token) == ("abc", None)
64+
65+
66+
def test_truncate_with_byte_estimate_handles_empty_zero_and_short_content() -> None:
    """Covers empty input, a zero budget, and content below the budget."""
    empty = truncate_with_byte_estimate("", TruncationPolicy.bytes(0))
    over_budget = truncate_with_byte_estimate("abc", TruncationPolicy.bytes(0))
    under_budget = truncate_with_byte_estimate("abc", TruncationPolicy.bytes(10))

    assert empty == ""
    assert "chars truncated" in over_budget
    assert under_budget == "abc"
70+
71+
72+
def test_split_string_preserves_utf8_boundaries() -> None:
    """The split never slices through a multi-byte UTF-8 character."""
    removed, head, tail = split_string("aあbいc", 2, 4)

    assert head == "a"
    assert tail == "いc"
    # Dropping "あb" removes two characters.
    assert removed == 2
78+
79+
80+
def test_split_string_handles_empty_content() -> None:
    """Empty input yields zero removed characters and two empty halves."""
    result = split_string("", 10, 10)
    assert result == (0, "", "")
82+
83+
84+
def test_formatting_and_estimate_helpers() -> None:
    """Smoke-checks the marker, budget-split, and approximation helpers."""
    in_bytes = TruncationPolicy.bytes(8)
    in_tokens = TruncationPolicy.tokens(2)

    # Markers name the unit that was removed.
    assert "chars truncated" in format_truncation_marker(in_bytes, 3)
    assert "tokens truncated" in format_truncation_marker(in_tokens, 2)
    # An odd budget splits with the larger half second.
    assert split_budget(5) == (2, 3)
    # Byte policies report removed characters; token policies convert bytes.
    assert removed_units_for_source(in_bytes, removed_bytes=10, removed_chars=4) == 4
    assert removed_units_for_source(in_tokens, removed_bytes=9, removed_chars=4) == 3
    assert approx_token_count("abcde") == 2
    assert approx_bytes_for_tokens(-1) == 0
    assert approx_tokens_from_byte_count(0) == 0
    assert approx_tokens_from_byte_count(5) == 2
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
import io
4+
from pathlib import Path
5+
from typing import Any, cast
6+
7+
import pytest
8+
9+
from agents.sandbox.errors import ErrorCode, WorkspaceWriteTypeError
10+
from agents.sandbox.session.workspace_payloads import coerce_write_payload
11+
12+
13+
class _Headers:
    """Minimal stand-in for an HTTP-style headers object exposing Content-Length."""

    def __init__(self, value: str | None) -> None:
        self._value = value

    def get(self, name: str) -> str | None:
        # Only Content-Length is expected to be requested by the code under test.
        assert name == "Content-Length"
        return self._value
20+
21+
22+
class _HeaderStream(io.BytesIO):
    """BytesIO that also carries a ``headers`` object with a Content-Length value."""

    def __init__(self, data: bytes, content_length: str | None) -> None:
        super().__init__(data)
        self.headers = _Headers(content_length)
26+
27+
28+
class _LengthStream(io.BytesIO):
    """BytesIO exposing an explicit ``length`` attribute."""

    def __init__(self, data: bytes, length: int) -> None:
        super().__init__(data)
        self.length = length
32+
33+
34+
class _NoneReadStream:
    """Stream whose read() returns ``None`` instead of bytes."""

    def read(self, size: int = -1) -> Any:
        _ = size
        return None
38+
39+
40+
class _BytearrayReadStream:
    """Stream whose read() returns a ``bytearray`` rather than ``bytes``."""

    def read(self, size: int = -1) -> Any:
        _ = size
        return bytearray(b"abc")
44+
45+
46+
class _TextReadStream:
    """Stream whose read() returns ``str``, used to trigger the type-error path."""

    def read(self, size: int = -1) -> Any:
        _ = size
        return "not-bytes"
50+
51+
52+
class _UnseekableStream(io.BytesIO):
    """BytesIO whose tell() raises OSError, simulating an unseekable stream."""

    def tell(self) -> int:
        raise OSError("not seekable")
55+
56+
57+
def test_coerce_write_payload_adapts_binary_reads() -> None:
    """A plain BytesIO becomes a readable payload with a known content length."""
    source = io.BytesIO(b"abc")
    payload = coerce_write_payload(path=Path("/workspace/file.bin"), data=source)

    assert payload.content_length == 3
    assert payload.stream.readable() is True
    # Partial then exhaustive reads both work on the adapted stream.
    assert payload.stream.read(1) == b"a"
    assert payload.stream.read() == b"bc"
64+
65+
66+
def test_coerce_write_payload_adapts_bytearray_and_none_reads() -> None:
    """bytearray chunks are surfaced as bytes and None reads as empty bytes."""
    from_bytearray = coerce_write_payload(
        path=Path("/workspace/file.bin"),
        data=cast(io.IOBase, _BytearrayReadStream()),
    )
    from_none = coerce_write_payload(
        path=Path("/workspace/empty.bin"),
        data=cast(io.IOBase, _NoneReadStream()),
    )

    assert from_bytearray.stream.read() == b"abc"
    assert from_none.stream.read() == b""
78+
79+
80+
def test_coerce_write_payload_supports_readinto_seek_and_tell() -> None:
    """The adapted stream supports readinto, tell, and seek round-trips."""
    payload = coerce_write_payload(path=Path("/workspace/file.bin"), data=io.BytesIO(b"abcdef"))
    scratch = bytearray(3)

    assert cast(Any, payload.stream).readinto(scratch) == 3
    assert bytes(scratch) == b"abc"
    # Position reflects the readinto, and seeking back allows re-reading.
    assert payload.stream.tell() == 3
    assert payload.stream.seek(1) == 1
    assert payload.stream.read(2) == b"bc"
89+
90+
91+
def test_coerce_write_payload_rejects_text_chunks() -> None:
    """Reading str chunks raises WorkspaceWriteTypeError with path context."""
    payload = coerce_write_payload(
        path=Path("/workspace/file.txt"),
        data=cast(io.IOBase, _TextReadStream()),
    )

    with pytest.raises(WorkspaceWriteTypeError) as caught:
        payload.stream.read()

    error = caught.value
    assert error.error_code is ErrorCode.WORKSPACE_WRITE_TYPE_ERROR
    assert error.context == {"path": "/workspace/file.txt", "actual_type": "str"}
105+
106+
107+
@pytest.mark.parametrize(
    ("stream", "expected"),
    [
        (_LengthStream(b"abc", 5), 5),
        (_HeaderStream(b"abc", "7"), 7),
        (_HeaderStream(b"abc", "-1"), 3),
        (_HeaderStream(b"abc", "invalid"), 3),
        (_UnseekableStream(b"abc"), None),
    ],
)
def test_coerce_write_payload_uses_best_effort_content_length(
    stream: io.IOBase,
    expected: int | None,
) -> None:
    """content_length falls back across `length`, Content-Length, and seek-based
    size, and is None when none of those sources is usable."""
    payload = coerce_write_payload(path=Path("/workspace/file.bin"), data=stream)

    assert payload.content_length == expected

0 commit comments

Comments
 (0)