Skip to content

Commit 570a4d5

Browse files
giles17CopilotCopilot
authored
Python: Support OpenAI and Gemini allowed_tools tool choice (#5322)
* Support OpenAI allowed_tools in ToolMode (#5309) Add allowed_tools field to ToolMode TypedDict, enabling users to restrict which tools the model may call via the OpenAI allowed_tools tool_choice type. This preserves prompt caching by keeping all tools in the tools list while limiting which ones the model can invoke. - Add allowed_tools: list[str] to ToolMode TypedDict - Add validation in validate_tool_mode() (only valid when mode == "auto") - Convert to OpenAI API format in _prepare_options() - Add tests for validation and API payload generation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Support OpenAI `allowed_tools` tool choice in Python SDK Fixes #5309 * Fix #5309: Validate allowed_tools shape and add Chat Completions client support - validate_tool_mode now checks allowed_tools is a non-string sequence of strings and normalizes to list[str], raising ContentError for invalid types - Add missing allowed_tools branch in _chat_completion_client._prepare_options so allowed_tools is emitted as the OpenAI allowed_tools wire format instead of being silently dropped - Add tests for invalid allowed_tools types (string, int, mixed), empty list, tuple normalization, and Chat Completions client payload generation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: support allowed_tools with mode 'required' in addition to 'auto' OpenAI's allowed_tools tool_choice type supports both mode 'auto' and 'required'. Update validation, client conversion, and tests to allow both modes instead of restricting to 'auto' only. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: use Gemini VALIDATED mode for allowed_tools, warn in unsupported providers - Use FunctionCallingConfigMode.VALIDATED instead of ANY when allowed_tools is set with auto mode in Gemini, preserving optional tool-call semantics. - Handle allowed_tools in required mode with required_function_name precedence. - Fix allowed_names guard to use identity check (is not None) so empty lists are preserved. - Bump google-genai minimum to >=1.32.0 (VALIDATED added in that version). - Add warnings in Anthropic and Bedrock when allowed_tools is set but not supported. - Add Gemini unit tests for allowed_tools with auto, required, empty list, and required_function_name precedence scenarios. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: Chat Completions API does not support allowed_tools, add integration tests - Chat Completions API (_chat_completion_client.py) now warns and falls back to plain mode when allowed_tools is set, since the /chat/completions endpoint does not support the allowed_tools type. - Add allowed_tools integration test param to both OpenAIChatClient (Responses API) and OpenAIChatCompletionClient parametrized option tests. - Update Chat Completions unit tests to reflect the warn-and-fallback behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: remove unused walrus operator variable in chat completion client Remove assigned-but-never-used variable 'allowed' flagged by ruff F841. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent f5419b9 commit 570a4d5

11 files changed

Lines changed: 912 additions & 605 deletions

File tree

python/packages/anthropic/agent_framework_anthropic/_chat_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,8 @@ def _prepare_tools_for_anthropic(self, options: Mapping[str, Any]) -> dict[str,
872872
tool_mode = validate_tool_mode(options.get("tool_choice"))
873873
if tool_mode is None:
874874
return result or None
875+
if "allowed_tools" in tool_mode:
876+
logger.warning("allowed_tools is not supported by Anthropic; the setting will be ignored")
875877
allow_multiple = options.get("allow_multiple_tool_calls")
876878
match tool_mode.get("mode"):
877879
case "auto":

python/packages/bedrock/agent_framework_bedrock/_chat_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ def _prepare_options(
405405

406406
tool_config = self._prepare_tools(options.get("tools"))
407407
if tool_mode := validate_tool_mode(options.get("tool_choice")):
408+
if "allowed_tools" in tool_mode:
409+
logger.warning("allowed_tools is not supported by Bedrock; the setting will be ignored")
408410
match tool_mode.get("mode"):
409411
case "none":
410412
# Bedrock doesn't support toolChoice "none".

python/packages/core/agent_framework/_types.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3246,10 +3246,12 @@ class ToolMode(TypedDict, total=False):
32463246
Fields:
32473247
mode: One of "auto", "required", or "none".
32483248
required_function_name: Optional function name when `mode == "required"`.
3249+
allowed_tools: Optional list of tool names when `mode` is `"auto"` or `"required"`.
32493250
"""
32503251

32513252
mode: Literal["auto", "required", "none"]
32523253
required_function_name: str
3254+
allowed_tools: list[str]
32533255

32543256

32553257
# region TypedDict-based Chat Options
@@ -3482,7 +3484,7 @@ def validate_tool_mode(
34823484
34833485
Returns:
34843486
A ToolMode dict (contains keys: "mode", and optionally
3485-
"required_function_name"), or ``None`` when not provided.
3487+
"required_function_name" or "allowed_tools"), or ``None`` when not provided.
34863488
34873489
Raises:
34883490
ContentError: If the tool_choice string is invalid.
@@ -3499,6 +3501,17 @@ def validate_tool_mode(
34993501
raise ContentError(f"Invalid tool choice: {tool_choice['mode']}")
35003502
if tool_choice["mode"] != "required" and "required_function_name" in tool_choice:
35013503
raise ContentError("tool_choice with mode other than 'required' cannot have 'required_function_name'")
3504+
if tool_choice["mode"] not in ("auto", "required") and "allowed_tools" in tool_choice:
3505+
raise ContentError("tool_choice 'allowed_tools' is only valid when mode is 'auto' or 'required'")
3506+
if "allowed_tools" in tool_choice:
3507+
allowed_tools = tool_choice["allowed_tools"]
3508+
if isinstance(allowed_tools, str) or not isinstance(allowed_tools, Sequence):
3509+
raise ContentError("tool_choice 'allowed_tools' must be a non-string sequence of strings")
3510+
if not all(isinstance(tool_name, str) for tool_name in allowed_tools):
3511+
raise ContentError("tool_choice 'allowed_tools' must contain only strings")
3512+
normalized_tool_choice = dict(tool_choice)
3513+
normalized_tool_choice["allowed_tools"] = list(allowed_tools)
3514+
return cast(ToolMode, normalized_tool_choice)
35023515
return tool_choice
35033516

35043517

python/packages/core/tests/core/test_types.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,16 +1087,20 @@ def test_chat_tool_mode():
10871087
required_any: ToolMode = {"mode": "required"}
10881088
required_mode: ToolMode = {"mode": "required", "required_function_name": "example_function"}
10891089
none_mode: ToolMode = {"mode": "none"}
1090+
allowed_mode: ToolMode = {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}
10901091

10911092
# Check the type and content
10921093
assert auto_mode["mode"] == "auto"
10931094
assert "required_function_name" not in auto_mode
1095+
assert "allowed_tools" not in auto_mode
10941096
assert required_any["mode"] == "required"
10951097
assert "required_function_name" not in required_any
10961098
assert required_mode["mode"] == "required"
10971099
assert required_mode["required_function_name"] == "example_function"
10981100
assert none_mode["mode"] == "none"
10991101
assert "required_function_name" not in none_mode
1102+
assert allowed_mode["mode"] == "auto"
1103+
assert allowed_mode["allowed_tools"] == ["get_weather", "search_docs"]
11001104

11011105
# equality of dicts
11021106
assert {"mode": "required", "required_function_name": "example_function"} == {
@@ -1154,6 +1158,45 @@ def test_chat_options_tool_choice_validation():
11541158
with raises(ContentError):
11551159
validate_tool_mode({"mode": "auto", "required_function_name": "should_not_be_here"})
11561160

1161+
# Valid allowed_tools
1162+
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather"]}) == {
1163+
"mode": "auto",
1164+
"allowed_tools": ["get_weather"],
1165+
}
1166+
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}) == {
1167+
"mode": "auto",
1168+
"allowed_tools": ["get_weather", "search_docs"],
1169+
}
1170+
1171+
# allowed_tools valid with required mode
1172+
assert validate_tool_mode({"mode": "required", "allowed_tools": ["get_weather"]}) == {
1173+
"mode": "required",
1174+
"allowed_tools": ["get_weather"],
1175+
}
1176+
1177+
# allowed_tools invalid with none mode
1178+
with raises(ContentError):
1179+
validate_tool_mode({"mode": "none", "allowed_tools": ["get_weather"]})
1180+
1181+
# allowed_tools must be a non-string sequence of strings
1182+
with raises(ContentError):
1183+
validate_tool_mode({"mode": "auto", "allowed_tools": "get_weather"})
1184+
with raises(ContentError):
1185+
validate_tool_mode({"mode": "auto", "allowed_tools": 123})
1186+
with raises(ContentError):
1187+
validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", 123]})
1188+
1189+
# Empty list is valid (caller explicitly allows no tools)
1190+
assert validate_tool_mode({"mode": "auto", "allowed_tools": []}) == {
1191+
"mode": "auto",
1192+
"allowed_tools": [],
1193+
}
1194+
1195+
# Tuple is normalized to list
1196+
result = validate_tool_mode({"mode": "auto", "allowed_tools": ("get_weather",)})
1197+
assert result is not None
1198+
assert result["allowed_tools"] == ["get_weather"]
1199+
11571200

11581201
def test_chat_options_merge(tool_tool, ai_tool) -> None:
11591202
"""Test merge_chat_options utility function."""

python/packages/gemini/agent_framework_gemini/_chat_client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -823,19 +823,28 @@ def _prepare_tool_config(self, tool_choice: Any) -> types.ToolConfig | None:
823823

824824
match tool_mode.get("mode"):
825825
case "auto":
826-
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
826+
if "allowed_tools" in tool_mode:
827+
function_calling_mode = types.FunctionCallingConfigMode.VALIDATED
828+
allowed_names = list(tool_mode["allowed_tools"])
829+
else:
830+
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
827831
case "none":
828832
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.NONE, None
829833
case "required":
830834
function_calling_mode = types.FunctionCallingConfigMode.ANY
831835
name = tool_mode.get("required_function_name")
832-
allowed_names = [name] if name else None
836+
if name:
837+
allowed_names = [name]
838+
elif "allowed_tools" in tool_mode:
839+
allowed_names = list(tool_mode["allowed_tools"])
840+
else:
841+
allowed_names = None
833842
case unknown_mode:
834843
logger.warning("Unsupported tool_choice mode for Gemini: %s", unknown_mode)
835844
return None
836845

837846
function_calling_kwargs: dict[str, Any] = {"mode": function_calling_mode}
838-
if allowed_names:
847+
if allowed_names is not None:
839848
function_calling_kwargs["allowed_function_names"] = allowed_names
840849

841850
return types.ToolConfig(function_calling_config=types.FunctionCallingConfig(**function_calling_kwargs))

python/packages/gemini/tests/test_gemini_client.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,86 @@ async def test_unknown_tool_choice_mode_is_ignored() -> None:
11571157
assert not hasattr(config, "tool_config") or config.tool_config is None
11581158

11591159

1160+
async def test_tool_choice_auto_with_allowed_tools_uses_VALIDATED() -> None:
1161+
"""Maps auto + allowed_tools to FunctionCallingConfigMode.VALIDATED with allowed_function_names."""
1162+
tool = _make_dummy_tool()
1163+
client, mock = _make_gemini_client()
1164+
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
1165+
1166+
await client.get_response(
1167+
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
1168+
options={
1169+
"tools": [tool],
1170+
"tool_choice": {"mode": "auto", "allowed_tools": ["dummy", "other"]},
1171+
},
1172+
)
1173+
1174+
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
1175+
function_calling_config = config.tool_config.function_calling_config
1176+
assert function_calling_config.mode == "VALIDATED"
1177+
assert function_calling_config.allowed_function_names == ["dummy", "other"]
1178+
1179+
1180+
async def test_tool_choice_auto_with_empty_allowed_tools_uses_VALIDATED() -> None:
1181+
"""Maps auto + empty allowed_tools to VALIDATED with empty allowed_function_names."""
1182+
tool = _make_dummy_tool()
1183+
client, mock = _make_gemini_client()
1184+
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
1185+
1186+
await client.get_response(
1187+
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
1188+
options={
1189+
"tools": [tool],
1190+
"tool_choice": {"mode": "auto", "allowed_tools": []},
1191+
},
1192+
)
1193+
1194+
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
1195+
function_calling_config = config.tool_config.function_calling_config
1196+
assert function_calling_config.mode == "VALIDATED"
1197+
assert function_calling_config.allowed_function_names == []
1198+
1199+
1200+
async def test_tool_choice_required_with_allowed_tools_uses_ANY() -> None:
1201+
"""Maps required + allowed_tools to ANY with allowed_function_names."""
1202+
tool = _make_dummy_tool()
1203+
client, mock = _make_gemini_client()
1204+
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
1205+
1206+
await client.get_response(
1207+
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
1208+
options={
1209+
"tools": [tool],
1210+
"tool_choice": {"mode": "required", "allowed_tools": ["dummy"]},
1211+
},
1212+
)
1213+
1214+
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
1215+
function_calling_config = config.tool_config.function_calling_config
1216+
assert function_calling_config.mode == "ANY"
1217+
assert function_calling_config.allowed_function_names == ["dummy"]
1218+
1219+
1220+
async def test_tool_choice_required_function_name_takes_precedence_over_allowed_tools() -> None:
1221+
"""When both required_function_name and allowed_tools are present, required_function_name wins."""
1222+
tool = _make_dummy_tool()
1223+
client, mock = _make_gemini_client()
1224+
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))
1225+
1226+
await client.get_response(
1227+
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
1228+
options={
1229+
"tools": [tool],
1230+
"tool_choice": {"mode": "required", "required_function_name": "dummy", "allowed_tools": ["other"]},
1231+
},
1232+
)
1233+
1234+
config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
1235+
function_calling_config = config.tool_config.function_calling_config
1236+
assert function_calling_config.mode == "ANY"
1237+
assert function_calling_config.allowed_function_names == ["dummy"]
1238+
1239+
11601240
# built-in tool factories
11611241

11621242

python/packages/openai/agent_framework_openai/_chat_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,12 @@ async def _prepare_options(
12961296
"type": "function",
12971297
"name": func_name,
12981298
}
1299+
elif mode == "auto" and (allowed := tool_mode.get("allowed_tools")) is not None:
1300+
run_options["tool_choice"] = {
1301+
"type": "allowed_tools",
1302+
"mode": "auto",
1303+
"tools": [{"type": "function", "name": name} for name in allowed],
1304+
}
12991305
else:
13001306
run_options["tool_choice"] = mode
13011307
else:

python/packages/openai/agent_framework_openai/_chat_completion_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,12 @@ def _prepare_options(self, messages: Sequence[Message], options: Mapping[str, An
662662
"type": "function",
663663
"function": {"name": func_name},
664664
}
665+
elif mode in ("auto", "required") and tool_mode.get("allowed_tools") is not None:
666+
logger.warning(
667+
"allowed_tools is not supported by the Chat Completions API; "
668+
"the setting will be ignored. Use OpenAIChatClient (Responses API) instead."
669+
)
670+
run_options["tool_choice"] = mode
665671
else:
666672
run_options["tool_choice"] = mode
667673

python/packages/openai/tests/openai/test_openai_chat_client.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4259,6 +4259,12 @@ async def get_api_key() -> str:
42594259
True,
42604260
id="tool_choice_required",
42614261
),
4262+
param(
4263+
"tool_choice",
4264+
{"mode": "auto", "allowed_tools": ["get_weather"]},
4265+
True,
4266+
id="tool_choice_allowed_tools",
4267+
),
42624268
param("response_format", OutputStruct, True, id="response_format_pydantic"),
42634269
param(
42644270
"response_format",
@@ -4813,6 +4819,90 @@ async def test_prepare_options_excludes_continuation_token() -> None:
48134819
assert run_options["background"] is True
48144820

48154821

4822+
async def test_prepare_options_allowed_tools() -> None:
4823+
"""Test that _prepare_options converts allowed_tools to OpenAI API format."""
4824+
client = OpenAIChatClient(model="test-model", api_key="test-key")
4825+
4826+
@tool
4827+
def get_weather(city: str) -> str:
4828+
"""Get the weather for a city."""
4829+
return f"Sunny in {city}"
4830+
4831+
@tool
4832+
def search_docs(query: str) -> str:
4833+
"""Search documentation."""
4834+
return f"Results for {query}"
4835+
4836+
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
4837+
options: dict[str, Any] = {
4838+
"model": "test-model",
4839+
"tools": [get_weather, search_docs],
4840+
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather"]},
4841+
}
4842+
4843+
run_options = await client._prepare_options(messages, options)
4844+
4845+
assert run_options["tool_choice"] == {
4846+
"type": "allowed_tools",
4847+
"mode": "auto",
4848+
"tools": [{"type": "function", "name": "get_weather"}],
4849+
}
4850+
4851+
4852+
async def test_prepare_options_allowed_tools_multiple() -> None:
4853+
"""Test that _prepare_options converts multiple allowed_tools correctly."""
4854+
client = OpenAIChatClient(model="test-model", api_key="test-key")
4855+
4856+
@tool
4857+
def get_weather(city: str) -> str:
4858+
"""Get the weather for a city."""
4859+
return f"Sunny in {city}"
4860+
4861+
@tool
4862+
def search_docs(query: str) -> str:
4863+
"""Search documentation."""
4864+
return f"Results for {query}"
4865+
4866+
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
4867+
options: dict[str, Any] = {
4868+
"model": "test-model",
4869+
"tools": [get_weather, search_docs],
4870+
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]},
4871+
}
4872+
4873+
run_options = await client._prepare_options(messages, options)
4874+
4875+
assert run_options["tool_choice"] == {
4876+
"type": "allowed_tools",
4877+
"mode": "auto",
4878+
"tools": [
4879+
{"type": "function", "name": "get_weather"},
4880+
{"type": "function", "name": "search_docs"},
4881+
],
4882+
}
4883+
4884+
4885+
async def test_prepare_options_auto_without_allowed_tools() -> None:
4886+
"""Test that auto mode without allowed_tools still returns plain 'auto' string."""
4887+
client = OpenAIChatClient(model="test-model", api_key="test-key")
4888+
4889+
@tool
4890+
def get_weather(city: str) -> str:
4891+
"""Get the weather for a city."""
4892+
return f"Sunny in {city}"
4893+
4894+
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
4895+
options: dict[str, Any] = {
4896+
"model": "test-model",
4897+
"tools": [get_weather],
4898+
"tool_choice": {"mode": "auto"},
4899+
}
4900+
4901+
run_options = await client._prepare_options(messages, options)
4902+
4903+
assert run_options["tool_choice"] == "auto"
4904+
4905+
48164906
# endregion
48174907

48184908

0 commit comments

Comments
 (0)