Skip to content

Commit 656baf8

Browse files
authored
fix: #3357 output schema names for Literal types (#3358)
1 parent eca794c commit 656baf8

2 files changed

Lines changed: 17 additions & 4 deletions

File tree

src/agents/agent_output.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,16 @@ def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
180180
return issubclass(t, BaseModel | dict)
181181

182182

183-
def _type_to_str(t: type[Any]) -> str:
183+
def _type_to_str(t: Any) -> str:
184184
origin = get_origin(t)
185185
args = get_args(t)
186186

187187
if origin is None:
188188
# It's a simple type like `str`, `int`, etc.
189-
return t.__name__
189+
return getattr(t, "__name__", repr(t))
190190
elif args:
191191
args_str = ", ".join(_type_to_str(arg) for arg in args)
192-
return f"{origin.__name__}[{args_str}]"
192+
origin_name = getattr(origin, "__name__", str(origin))
193+
return f"{origin_name}[{args_str}]"
193194
else:
194195
return str(t)

tests/test_output_tool.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any
2+
from typing import Any, Literal, cast
33

44
import pytest
55
from pydantic import BaseModel
@@ -77,6 +77,18 @@ def test_structured_output_list():
7777
assert validated == ["foo", "bar"]
7878

7979

80+
def test_structured_output_literal_name_handles_literal_values():
81+
output_schema = AgentOutputSchema(output_type=cast(type[Any], Literal["ok"]))
82+
83+
assert output_schema.name() == "Literal['ok']"
84+
85+
86+
def test_structured_output_nested_literal_name_handles_literal_values():
87+
output_schema = AgentOutputSchema(output_type=list[Literal["ok", "done"]])
88+
89+
assert output_schema.name() == "list[Literal['ok', 'done']]"
90+
91+
8092
def test_structured_output_generic_dict_is_not_wrapped():
8193
output_schema = AgentOutputSchema(output_type=dict[str, int], strict_json_schema=False)
8294
assert output_schema.output_type == dict[str, int]

0 commit comments

Comments
 (0)