Skip to content

Commit ea2e89e

Browse files
Add tool_choice + ForceTool (proposal 0025)
Provider.complete() gains an optional tool_choice parameter — one of "auto", "required", "none", or a ForceTool record — constraining the model's tool-calling behavior. Pre-send validation routes the three §5 failure modes through ProviderInvalidRequest (§7's existing category; no new category per the proposal's framing). ForceTool is a frozen Pydantic model with type: Literal["tool"] matching the spec discriminator. The OpenAI wire mapping in _build_request_body translates the spec shape to OpenAI's body per §8.1.1: string literals pass through verbatim; ForceTool renames type to "function" and nests the name under a function sub-object. None / omit preserves pre-0025 behavior — the field is absent on the wire and the provider's own default applies. 15 unit tests cover the three validation rules, ForceTool shape constraints (frozen, extras-forbid, Literal type), and the wire mapping rows from §8.1.1.
1 parent a3a22c6 commit ea2e89e

5 files changed

Lines changed: 336 additions & 6 deletions

File tree

src/openarmature/llm/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .messages import (
4848
AssistantMessage,
4949
ContentBlock,
50+
ForceTool,
5051
ImageBlock,
5152
ImageSource,
5253
ImageSourceInline,
@@ -56,6 +57,7 @@
5657
TextBlock,
5758
Tool,
5859
ToolCall,
60+
ToolChoice,
5961
ToolMessage,
6062
UserMessage,
6163
)
@@ -64,6 +66,7 @@
6466
strict_mode_supported,
6567
validate_message_list,
6668
validate_response_schema,
69+
validate_tool_choice,
6770
validate_tools,
6871
)
6972
from .providers import OpenAIProvider, classify_http_error, parse_retry_after
@@ -83,6 +86,7 @@
8386
"AssistantMessage",
8487
"ContentBlock",
8588
"FinishReason",
89+
"ForceTool",
8690
"ImageBlock",
8791
"ImageSource",
8892
"ImageSourceInline",
@@ -107,6 +111,7 @@
107111
"TextBlock",
108112
"Tool",
109113
"ToolCall",
114+
"ToolChoice",
110115
"ToolMessage",
111116
"Usage",
112117
"UserMessage",
@@ -115,5 +120,6 @@
115120
"strict_mode_supported",
116121
"validate_message_list",
117122
"validate_response_schema",
123+
"validate_tool_choice",
118124
"validate_tools",
119125
]

src/openarmature/llm/messages.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,46 @@ class Tool(BaseModel):
6868
parameters: dict[str, Any]
6969

7070

71+
# Spec: realizes llm-provider §5 `tool_choice` discriminated-union
72+
# (proposal 0025). The string-literal modes (`"auto"`, `"required"`,
73+
# `"none"`) and the `ForceTool` record share the `ToolChoice` alias.
74+
# Implementations validate `tool_choice` against `tools` before send
75+
# (see ``validate_tool_choice`` in :mod:`provider`); violations raise
76+
# ``ProviderInvalidRequest`` per §7.
77+
class ForceTool(BaseModel):
    """Record form of ``tool_choice``: require a call to one named tool.

    This is the non-string member of the spec §5 `tool_choice`
    discriminated union. ``type`` carries the spec-level discriminator
    and is always ``"tool"``; the OpenAI wire mapping (§8.1.1) emits it
    as ``"function"`` instead. ``name`` is the tool to force and MUST
    equal the ``Tool.name`` of an entry in the ``tools`` list passed
    alongside it — ``validate_tool_choice`` checks this before any
    request is sent and raises ``ProviderInvalidRequest`` otherwise.
    """

    # Immutable and closed to unknown fields: freezing makes instances
    # hashable, and forbidding extras keeps the record's shape pinned
    # to exactly ``type`` + ``name``.
    model_config = ConfigDict(frozen=True, extra="forbid")

    # The discriminator has a default so call sites can write
    # ``ForceTool(name="search")`` without spelling out ``type``; the
    # ``Literal`` annotation still rejects any other value.
    type: Literal["tool"] = "tool"
    # Name of the tool the model is required to call.
    name: str
97+
98+
99+
# Per spec §5: `tool_choice` is one of:
100+
# - ``"auto"`` — the model decides.
101+
# - ``"required"`` — the model MUST call at least one tool.
102+
# - ``"none"`` — the model MUST NOT call tools.
103+
# - ``ForceTool(name=X)`` — the model MUST call the named tool.
104+
# A union of the three string literals plus the record form.
105+
# Callers pass ``tool_choice=None`` (the default) to omit the field
106+
# from the wire — the provider's own default applies, preserving
107+
# pre-0025 behavior.
108+
ToolChoice = Literal["auto", "required", "none"] | ForceTool
109+
110+
71111
# ---------------------------------------------------------------------------
72112
# Per-role message classes
73113
# ---------------------------------------------------------------------------
@@ -274,6 +314,7 @@ class ToolMessage(_MessageBase):
274314
__all__ = [
275315
"AssistantMessage",
276316
"ContentBlock",
317+
"ForceTool",
277318
"ImageBlock",
278319
"ImageSource",
279320
"ImageSourceInline",
@@ -283,6 +324,7 @@ class ToolMessage(_MessageBase):
283324
"TextBlock",
284325
"Tool",
285326
"ToolCall",
327+
"ToolChoice",
286328
"ToolMessage",
287329
"UserMessage",
288330
]

src/openarmature/llm/provider.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
from .errors import ProviderInvalidRequest
4949
from .messages import (
5050
AssistantMessage,
51+
ForceTool,
5152
Message,
5253
SystemMessage,
5354
Tool,
55+
ToolChoice,
5456
ToolMessage,
5557
UserMessage,
5658
)
@@ -75,6 +77,7 @@ async def complete(
7577
tools: Sequence[Tool] | None = None,
7678
config: RuntimeConfig | None = None,
7779
response_schema: dict[str, Any] | type[BaseModel] | None = None,
80+
tool_choice: ToolChoice | None = None,
7881
) -> Response:
7982
"""Perform a single completion call.
8083
@@ -93,6 +96,12 @@ async def complete(
9396
supplied, the implementation constrains the model's
9497
output to the schema and populates ``Response.parsed``
9598
with the validated value.
99+
tool_choice: Optional tool-choice constraint (spec §5). One
100+
of ``"auto"``, ``"required"``, ``"none"``, or a
101+
:class:`ForceTool` record. When ``None`` (the default)
102+
the wire ``tool_choice`` field is omitted and the
103+
provider's own default applies. Pre-send validation
104+
routes through ``provider_invalid_request``.
96105
"""
97106
...
98107

@@ -174,6 +183,53 @@ def validate_tools(tools: Sequence[Tool] | None) -> None:
174183
seen.add(t.name)
175184

176185

186+
# Spec: realizes llm-provider §5 `tool_choice` pre-send validation
187+
# rules (proposal 0025). The three failure modes route through the
188+
# existing §7 ``provider_invalid_request`` category; no new error
189+
# categories per the spec's "no new category" framing. Validation
190+
# fires BEFORE any HTTP request is sent (fixture 031's mock_provider
191+
# returns an empty response list on these cases to fail the test
192+
# if a request escapes the validation gate).
193+
def validate_tool_choice(
    tool_choice: ToolChoice | None,
    tools: Sequence[Tool] | None,
) -> None:
    """Check a ``tool_choice`` value against the supplied ``tools`` (spec §5).

    Args:
        tool_choice: ``None``, one of the string modes (``"auto"``,
            ``"required"``, ``"none"``), or a :class:`ForceTool` record.
        tools: The tools offered to the model; may be ``None`` or empty.

    Raises:
        ProviderInvalidRequest: (the §7 ``provider_invalid_request``
            category) in exactly three situations:

            - ``tool_choice="required"`` with empty / absent ``tools``;
            - a ``ForceTool`` with empty / absent ``tools``;
            - a ``ForceTool`` whose ``name`` matches no ``Tool.name`` in
              ``tools``.

    ``tool_choice=None`` is accepted unconditionally, as are ``"auto"``
    and ``"none"`` — neither of those modes places any requirement on
    ``tools``.
    """
    # Nothing to validate when the caller did not set a tool choice.
    if tool_choice is None:
        return

    tools_missing = not tools

    if tools_missing and tool_choice == "required":
        raise ProviderInvalidRequest('tool_choice="required" requires non-empty tools')

    # The remaining rules only concern the record form.
    if not isinstance(tool_choice, ForceTool):
        return

    if tools_missing:
        raise ProviderInvalidRequest(
            f"tool_choice ForceTool(name={tool_choice.name!r}) requires non-empty tools"
        )

    # ``tools`` is known to be non-empty at this point; the ``or ()``
    # fallback only helps a type checker narrow away ``None``.
    names = set()
    names.update(tool.name for tool in tools or ())
    if tool_choice.name in names:
        return
    raise ProviderInvalidRequest(
        f"tool_choice name {tool_choice.name!r} not in tools (declared: {sorted(names)})"
    )
231+
232+
177233
# ---------------------------------------------------------------------------
178234
# Schema helpers — used by structured-output Provider implementations
179235
# ---------------------------------------------------------------------------
@@ -485,5 +541,6 @@ def _resolve_ref(ref: str, root: dict[str, Any]) -> Any:
485541
"strict_mode_supported",
486542
"validate_message_list",
487543
"validate_response_schema",
544+
"validate_tool_choice",
488545
"validate_tools",
489546
]

src/openarmature/llm/providers/openai.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,22 @@
7676
from ..messages import (
7777
AssistantMessage,
7878
ContentBlock,
79+
ForceTool,
7980
ImageBlock,
8081
ImageSourceInline,
8182
Message,
8283
SystemMessage,
8384
TextBlock,
8485
Tool,
8586
ToolCall,
87+
ToolChoice,
8688
UserMessage,
8789
)
8890
from ..provider import (
8991
strict_mode_supported,
9092
validate_message_list,
9193
validate_response_schema,
94+
validate_tool_choice,
9295
validate_tools,
9396
)
9497
from ..response import FinishReason, ParsedValue, Response, RuntimeConfig, Usage
@@ -232,24 +235,35 @@ async def complete(
232235
tools: Sequence[Tool] | None = None,
233236
config: RuntimeConfig | None = None,
234237
response_schema: dict[str, Any] | type[BaseModel] | None = None,
238+
tool_choice: ToolChoice | None = None,
235239
) -> Response:
236240
"""Single completion call.
237241
238242
Pre-send validation runs first (per-message Pydantic +
239-
list-level invariants + response_schema shape check). HTTP
240-
errors map to canonical provider-error categories. The
241-
successful 200 body is parsed into a :class:`Response`;
242-
failure to parse raises ``provider_invalid_response``; failure
243-
to validate the response content against ``response_schema``
244-
raises ``structured_output_invalid``.
243+
list-level invariants + response_schema shape check +
244+
``tool_choice`` validation). HTTP errors map to canonical
245+
provider-error categories. The successful 200 body is parsed
246+
into a :class:`Response`; failure to parse raises
247+
``provider_invalid_response``; failure to validate the response
248+
content against ``response_schema`` raises
249+
``structured_output_invalid``.
245250
246251
When ``response_schema`` is supplied as a Pydantic BaseModel
247252
subclass, ``Response.parsed`` is a validated instance of that
248253
class; when supplied as a JSON Schema dict,
249254
``Response.parsed`` is the deserialized dict.
255+
256+
``tool_choice`` is validated against ``tools`` per spec §5:
257+
``"required"`` and the ``ForceTool`` record both demand
258+
non-empty ``tools``, and ``ForceTool.name`` must appear in the
259+
supplied list. Violations raise ``provider_invalid_request``
260+
BEFORE any HTTP request is sent.
250261
"""
251262
validate_message_list(messages)
252263
validate_tools(tools)
264+
# ``validate_tool_choice`` runs after ``validate_tools`` so the
265+
# name-membership check sees a structurally valid tools list.
266+
validate_tool_choice(tool_choice, tools)
253267
schema_dict, schema_class = _normalize_response_schema(response_schema)
254268
# On the fallback path, the wire-side messages list is an
255269
# augmented COPY of the caller's messages — original messages
@@ -268,6 +282,7 @@ async def complete(
268282
# form calls (schema_dict is None) must preserve any
269283
# caller-supplied response_format from RuntimeConfig extras.
270284
include_response_format=(schema_dict is None or not self._force_prompt_augmentation_fallback),
285+
tool_choice=tool_choice,
271286
)
272287

273288
# Spec observability §5.5 LLM provider span: when an
@@ -399,6 +414,7 @@ def _build_request_body(
399414
config: RuntimeConfig | None,
400415
schema_dict: dict[str, Any] | None,
401416
include_response_format: bool = True,
417+
tool_choice: ToolChoice | None = None,
402418
) -> dict[str, Any]:
403419
body: dict[str, Any] = {
404420
"model": self.model,
@@ -439,6 +455,22 @@ def _build_request_body(
439455
# loop above; strip it here so the fallback contract holds
440456
# regardless of caller-supplied extras.
441457
body.pop("response_format", None)
458+
# Per §8.1.1 (proposal 0025): map the spec-level `tool_choice`
459+
# shape onto the OpenAI wire shape. ``None`` omits the field
460+
# entirely so the OpenAI provider's own default applies —
461+
# load-bearing for backward compat with pre-0025 callers. The
462+
# string-literal modes pass through verbatim; the ``ForceTool``
463+
# record renames ``type: "tool"`` → ``type: "function"`` and
464+
# nests the name under a ``function`` sub-object per OpenAI's
465+
# request shape.
466+
if tool_choice is not None:
467+
if isinstance(tool_choice, ForceTool):
468+
body["tool_choice"] = {
469+
"type": "function",
470+
"function": {"name": tool_choice.name},
471+
}
472+
else:
473+
body["tool_choice"] = tool_choice
442474
return body
443475

444476
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)