Skip to content

Commit a0c6f00

Browse files
Extend conformance harness for tool_choice fixtures
Adds assert_tool_choice_absent matcher mirroring the existing response_format_absent pattern, _build_tool_choice parser handling both YAML shapes (string for the three modes, dict for the ForceTool record form), and tool_choice passthrough on both call sites (raises path and success path). The expected_wire_request_checks dispatcher gains a tool_choice_absent key for fixture 029's default case, where the wire body MUST omit the field entirely (preserves pre-0025 caller behavior — the OpenAI provider's own default applies). LlmCallSpec uses _AllowExtras so tool_choice parses without an explicit pydantic field; no fixture-parsing model changes needed.
1 parent ea2e89e commit a0c6f00

3 files changed

Lines changed: 62 additions & 2 deletions

File tree

tests/conformance/harness/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
assert_error_carries,
2929
assert_response_format_absent,
3030
assert_system_references_schema,
31+
assert_tool_choice_absent,
3132
match_wire_body,
3233
request_body,
3334
)
@@ -41,6 +42,7 @@
4142
"assert_error_carries",
4243
"assert_response_format_absent",
4344
"assert_system_references_schema",
45+
"assert_tool_choice_absent",
4446
"discover_fixtures",
4547
"load_fixture",
4648
"match_wire_body",

tests/conformance/harness/wire.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ def assert_response_format_absent(body: Mapping[str, Any]) -> None:
9797
)
9898

9999

100+
def assert_tool_choice_absent(body: Mapping[str, Any]) -> None:
101+
"""Assert the wire body has no ``tool_choice`` key.
102+
103+
Per spec §8.1.1 (proposal 0025): when the caller omits
104+
``tool_choice`` from the ``complete()`` call, the wire body MUST
105+
omit the field entirely so the OpenAI provider's own default
106+
applies. Mirrors :func:`assert_response_format_absent`'s pattern.
107+
"""
108+
if "tool_choice" in body:
109+
raise AssertionError(
110+
f"wire check failed: tool_choice present (value={body['tool_choice']!r}), expected absent"
111+
)
112+
113+
100114
def assert_system_references_schema(body: Mapping[str, Any], schema: Mapping[str, Any]) -> None:
101115
"""Assert the first wire message is a system message whose content
102116
references the supplied JSON Schema (via substring match of the

tests/conformance/test_llm_provider.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from openarmature.llm import (
3232
TRANSIENT_CATEGORIES,
3333
AssistantMessage,
34+
ForceTool,
3435
LlmProviderError,
3536
Message,
3637
OpenAIProvider,
@@ -41,6 +42,7 @@
4142
SystemMessage,
4243
Tool,
4344
ToolCall,
45+
ToolChoice,
4446
ToolMessage,
4547
UserMessage,
4648
)
@@ -49,6 +51,7 @@
4951
assert_error_carries,
5052
assert_response_format_absent,
5153
assert_system_references_schema,
54+
assert_tool_choice_absent,
5255
match_wire_body,
5356
request_body,
5457
)
@@ -200,6 +203,30 @@ def _build_tools(raw_list: list[Mapping[str, Any]] | None) -> list[Tool] | None:
200203
]
201204

202205

206+
def _build_tool_choice(raw: Any) -> ToolChoice | None:
207+
"""Translate a fixture's ``tool_choice:`` value into the
208+
:class:`ToolChoice` discriminated-union value.
209+
210+
Two YAML shapes per spec proposal 0025:
211+
212+
- String: ``auto`` / ``required`` / ``none`` — passes through
213+
verbatim.
214+
- Dict: ``{type: tool, name: X}`` — constructed into a
215+
:class:`ForceTool` record. The wire-side rename (``tool`` →
216+
``function``) happens inside the provider, not at parse time.
217+
218+
Returns ``None`` when the fixture omits ``tool_choice``; the
219+
provider's own default applies on the wire.
220+
"""
221+
if raw is None:
222+
return None
223+
if isinstance(raw, str):
224+
return cast("ToolChoice", raw)
225+
if isinstance(raw, dict):
226+
return ForceTool.model_validate(raw)
227+
raise AssertionError(f"unrecognized tool_choice shape in fixture: {raw!r}")
228+
229+
203230
# ---------------------------------------------------------------------------
204231
# Assertion helpers
205232
# ---------------------------------------------------------------------------
@@ -265,6 +292,9 @@ def _assert_wire_expectations(
265292
if key == "response_format_absent":
266293
if value is True:
267294
assert_response_format_absent(body)
295+
elif key == "tool_choice_absent":
296+
if value is True:
297+
assert_tool_choice_absent(body)
268298
elif key == "system_message_content_references_schema":
269299
if value is True:
270300
if not isinstance(response_schema, dict):
@@ -468,10 +498,17 @@ async def _run_one_call(
468498
_build_message(m) for m in cast("list[Mapping[str, Any]]", call_spec["messages"])
469499
]
470500
tools = _build_tools(cast("list[Mapping[str, Any]] | None", call_spec.get("tools")))
501+
tool_choice = _build_tool_choice(call_spec.get("tool_choice"))
471502
except ValidationError as ve:
472503
raise ProviderInvalidRequest(str(ve)) from ve
473504
await _maybe_with_retry(
474-
lambda: provider.complete(messages, tools, config, response_schema=response_schema),
505+
lambda: provider.complete(
506+
messages,
507+
tools,
508+
config,
509+
response_schema=response_schema,
510+
tool_choice=tool_choice,
511+
),
475512
retry_mw_cfg,
476513
)
477514
_assert_raises_matches(excinfo, expected["raises"])
@@ -485,8 +522,15 @@ async def _run_one_call(
485522
messages = [_build_message(m) for m in cast("list[Mapping[str, Any]]", call_spec["messages"])]
486523
messages_snapshot = [m.model_dump(mode="json") for m in messages]
487524
tools = _build_tools(cast("list[Mapping[str, Any]] | None", call_spec.get("tools")))
525+
tool_choice = _build_tool_choice(call_spec.get("tool_choice"))
488526
response = await _maybe_with_retry(
489-
lambda: provider.complete(messages, tools, config, response_schema=response_schema),
527+
lambda: provider.complete(
528+
messages,
529+
tools,
530+
config,
531+
response_schema=response_schema,
532+
tool_choice=tool_choice,
533+
),
490534
retry_mw_cfg,
491535
)
492536
_assert_response_matches(response, cast("Mapping[str, Any]", expected.get("response") or {}))

0 commit comments

Comments
 (0)