Skip to content

Commit 8353fe1

Browse files
NayukiChibaSoulter
andauthored
fix(anthropic): Anthropic API tool_choice schema conversion (#8328)
* fix(anthropic): 修复 Anthropic API tool_choice 格式转换及参数支持 - 将 tool_choice 从简单的 auto/required 逻辑改为遵循 Anthropic API 规范,支持 auto/any/none/tool 四种原生值 - 兼容 OpenAI 风格的 tool_choice="required",自动映射为 {"type": "any"} - 允许直接传入 dict 类型的 tool_choice 以实现指定工具调用 - 更新 text_chat 和 stream_chat 入口的参数类型标注,扩大可接收的 tool_choice 类型 - 新增 tool_choice 格式转换的单元测试,覆盖各类输入场景 Closes #8319 * Clean up test cases and remove unused mocks Removed unused mock classes and tests for tool_choice conversion. * fix(anthropic): 修复 Anthropic API tool_choice="tool" 参数处理及重构格式转换逻辑 - 提取静态方法 _normalize_tool_choice 统一处理 tool_choice 格式转换,消除重复代码 - 处理字符串 "tool" 值时,因无法指定具体工具名而回退为 auto 并记录警告,避免无效请求 - 在 _query 和 _stream_query 中采用默认值 auto 并应用规范化逻辑,确保一致性 * test(anthropic): 添加空工具集时跳过工具参数设置的测试 - 新增 _EmptyToolSet 模拟类,模拟无工具场景 - 新增测试用例 test_tool_choice_empty_tool_list_skips_tool_choice - 验证当 ToolSet 存在但工具列表为空时,请求不包含 tools 和 tool_choice 参数 - 完善边缘情况测试覆盖,确保与现有逻辑一致 * style: ruff 格式化一下 --------- Co-authored-by: Weilong Liao <37870767+Soulter@users.noreply.github.com>
1 parent 01a47b8 commit 8353fe1

2 files changed

Lines changed: 238 additions & 16 deletions

File tree

astrbot/core/provider/sources/anthropic_source.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,44 @@ def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> No
318318
if usage.output_tokens is not None:
319319
token_usage.output = usage.output_tokens
320320

321+
@staticmethod
322+
def _normalize_tool_choice(tool_choice) -> dict:
323+
"""将 tool_choice 转换为 Anthropic API 要求的格式
324+
325+
参考: https://platform.claude.com/docs/en/agents-and-tools/tool-use/define-tools#controlling-claudes-output
326+
327+
Args:
328+
tool_choice: 原始 tool_choice 值,支持 str 或 dict
329+
330+
Returns:
331+
Anthropic API 格式的 tool_choice 字典
332+
"""
333+
if isinstance(tool_choice, dict):
334+
return tool_choice
335+
336+
if tool_choice == "required":
337+
# 兼容 OpenAI 命名:required → any
338+
return {"type": "any"}
339+
340+
if tool_choice in ("auto", "any", "none"):
341+
return {"type": tool_choice}
342+
343+
if tool_choice == "tool":
344+
# {"type": "tool"} 必须配合 name 字段指定具体工具
345+
# 纯字符串 "tool" 无法指定工具名,回退为 auto
346+
logger.warning("tool_choice='tool' 无法指定工具名,已回退为 'auto'")
347+
return {"type": "auto"}
348+
349+
logger.warning(f"未知的 tool_choice 值: {tool_choice},已回退为 'auto'")
350+
return {"type": "auto"}
351+
321352
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
322353
if tools:
323354
if tool_list := tools.get_func_desc_anthropic_style():
324355
payloads["tools"] = tool_list
325-
payloads["tool_choice"] = {
326-
"type": "any"
327-
if payloads.get("tool_choice") == "required"
328-
else "auto"
329-
}
356+
payloads["tool_choice"] = self._normalize_tool_choice(
357+
payloads.get("tool_choice", "auto")
358+
)
330359

331360
extra_body = self.provider_config.get("custom_extra_body", {})
332361

@@ -409,11 +438,9 @@ async def _query_stream(
409438
if tools:
410439
if tool_list := tools.get_func_desc_anthropic_style():
411440
payloads["tools"] = tool_list
412-
payloads["tool_choice"] = {
413-
"type": "any"
414-
if payloads.get("tool_choice") == "required"
415-
else "auto"
416-
}
441+
payloads["tool_choice"] = self._normalize_tool_choice(
442+
payloads.get("tool_choice", "auto")
443+
)
417444

418445
# 用于累积工具调用信息
419446
tool_use_buffer = {}
@@ -569,7 +596,7 @@ async def text_chat(
569596
tool_calls_result=None,
570597
model=None,
571598
extra_user_content_parts=None,
572-
tool_choice: Literal["auto", "required"] = "auto",
599+
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
573600
**kwargs,
574601
) -> LLMResponse:
575602
if contexts is None:
@@ -598,8 +625,8 @@ async def text_chat(
598625
if not isinstance(tool_calls_result, list):
599626
context_query.extend(tool_calls_result.to_openai_messages())
600627
else:
601-
for tcr in tool_calls_result:
602-
context_query.extend(tcr.to_openai_messages())
628+
for tool_call_result in tool_calls_result:
629+
context_query.extend(tool_call_result.to_openai_messages())
603630

604631
system_prompt, new_messages = self._prepare_payload(context_query)
605632

@@ -637,7 +664,7 @@ async def text_chat_stream(
637664
tool_calls_result=None,
638665
model=None,
639666
extra_user_content_parts=None,
640-
tool_choice: Literal["auto", "required"] = "auto",
667+
tool_choice: Literal["auto", "any", "tool", "none"] | dict[str, str] = "auto",
641668
**kwargs,
642669
):
643670
if contexts is None:
@@ -665,8 +692,8 @@ async def text_chat_stream(
665692
if not isinstance(tool_calls_result, list):
666693
context_query.extend(tool_calls_result.to_openai_messages())
667694
else:
668-
for tcr in tool_calls_result:
669-
context_query.extend(tcr.to_openai_messages())
695+
for tool_call_result in tool_calls_result:
696+
context_query.extend(tool_call_result.to_openai_messages())
670697

671698
system_prompt, new_messages = self._prepare_payload(context_query)
672699

tests/test_anthropic_kimi_code_provider.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,198 @@ def test_prepare_payload_does_not_merge_non_consecutive_tool_results():
416416
],
417417
},
418418
]
419+
420+
421+
# ---- tool_choice 转换测试 ----
422+
423+
424+
class _FakeToolSet:
425+
"""模拟包含工具的 ToolSet"""
426+
427+
def get_func_desc_anthropic_style(self):
428+
return [{"name": "get_weather", "description": "Get weather"}]
429+
430+
def empty(self):
431+
return False
432+
433+
434+
class _EmptyToolSet:
435+
"""模拟空工具列表的 ToolSet,用于验证无工具时不设置 tool_choice"""
436+
437+
def get_func_desc_anthropic_style(self):
438+
return []
439+
440+
def empty(self):
441+
return True
442+
443+
444+
class _FakeMessages:
445+
"""模拟 AsyncAnthropic.messages 命名空间"""
446+
447+
448+
async def _capture_payloads_create(**kwargs):
449+
"""捕获 payloads 并返回一个真实的 Message 实例"""
450+
from anthropic.types import Message, TextBlock, Usage
451+
452+
_capture_payloads_create.last_kwargs = kwargs
453+
return Message(
454+
id="msg_fake",
455+
content=[TextBlock(type="text", text="Hello")],
456+
model="claude-test",
457+
role="assistant",
458+
stop_reason=None,
459+
stop_sequence=None,
460+
type="message",
461+
usage=Usage(input_tokens=10, output_tokens=5),
462+
)
463+
464+
465+
def _setup_provider_with_mock_client(monkeypatch) -> anthropic_source.ProviderAnthropic:
466+
"""创建 provider 并 mock 底层 API 调用"""
467+
monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic)
468+
469+
provider = anthropic_source.ProviderAnthropic(
470+
provider_config={
471+
"id": "anthropic-test",
472+
"type": "anthropic_chat_completion",
473+
"model": "claude-test",
474+
"key": ["test-key"],
475+
},
476+
provider_settings={},
477+
)
478+
479+
fakeMessages = _FakeMessages()
480+
fakeMessages.create = _capture_payloads_create
481+
provider.client.messages = fakeMessages
482+
483+
return provider
484+
485+
486+
@pytest.mark.asyncio
487+
async def test_tool_choice_auto_converts_to_dict(monkeypatch):
488+
"""tool_choice='auto' 应转换为 {'type': 'auto'}"""
489+
provider = _setup_provider_with_mock_client(monkeypatch)
490+
491+
await provider.text_chat(
492+
prompt="hello",
493+
func_tool=_FakeToolSet(),
494+
tool_choice="auto",
495+
)
496+
497+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"}
498+
499+
500+
@pytest.mark.asyncio
501+
async def test_tool_choice_any_converts_to_dict(monkeypatch):
502+
"""tool_choice='any' 应转换为 {'type': 'any'}"""
503+
provider = _setup_provider_with_mock_client(monkeypatch)
504+
505+
await provider.text_chat(
506+
prompt="hello",
507+
func_tool=_FakeToolSet(),
508+
tool_choice="any",
509+
)
510+
511+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"}
512+
513+
514+
@pytest.mark.asyncio
515+
async def test_tool_choice_none_converts_to_dict(monkeypatch):
516+
"""tool_choice='none' 应转换为 {'type': 'none'}"""
517+
provider = _setup_provider_with_mock_client(monkeypatch)
518+
519+
await provider.text_chat(
520+
prompt="hello",
521+
func_tool=_FakeToolSet(),
522+
tool_choice="none",
523+
)
524+
525+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "none"}
526+
527+
528+
@pytest.mark.asyncio
529+
async def test_tool_choice_required_legacy_compat(monkeypatch):
530+
"""tool_choice='required'(OpenAI 命名) 应兼容转换为 {'type': 'any'}"""
531+
provider = _setup_provider_with_mock_client(monkeypatch)
532+
533+
await provider.text_chat(
534+
prompt="hello",
535+
func_tool=_FakeToolSet(),
536+
tool_choice="required",
537+
)
538+
539+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "any"}
540+
541+
542+
@pytest.mark.asyncio
543+
async def test_tool_choice_dict_passthrough(monkeypatch):
544+
"""tool_choice 为 dict 时应直接透传"""
545+
provider = _setup_provider_with_mock_client(monkeypatch)
546+
547+
await provider.text_chat(
548+
prompt="hello",
549+
func_tool=_FakeToolSet(),
550+
tool_choice={"type": "tool", "name": "get_weather"},
551+
)
552+
553+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {
554+
"type": "tool",
555+
"name": "get_weather",
556+
}
557+
558+
559+
@pytest.mark.asyncio
560+
async def test_tool_choice_default_when_not_set(monkeypatch):
561+
"""未传 tool_choice 时,默认应为 {'type': 'auto'}"""
562+
provider = _setup_provider_with_mock_client(monkeypatch)
563+
564+
await provider.text_chat(
565+
prompt="hello",
566+
func_tool=_FakeToolSet(),
567+
)
568+
569+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"}
570+
571+
572+
@pytest.mark.asyncio
573+
async def test_tool_choice_invalid_string_falls_back_to_auto(monkeypatch):
574+
"""无效的 tool_choice 字符串应回退为 {'type': 'auto'}"""
575+
provider = _setup_provider_with_mock_client(monkeypatch)
576+
577+
await provider.text_chat(
578+
prompt="hello",
579+
func_tool=_FakeToolSet(),
580+
tool_choice="invalid_value",
581+
)
582+
583+
assert _capture_payloads_create.last_kwargs["tool_choice"] == {"type": "auto"}
584+
585+
586+
@pytest.mark.asyncio
587+
async def test_tool_choice_no_tools_skips_tool_choice(monkeypatch):
588+
"""无工具时不应设置 tool_choice"""
589+
provider = _setup_provider_with_mock_client(monkeypatch)
590+
591+
await provider.text_chat(
592+
prompt="hello",
593+
func_tool=None,
594+
tool_choice="any",
595+
)
596+
597+
assert "tool_choice" not in _capture_payloads_create.last_kwargs
598+
599+
600+
@pytest.mark.asyncio
601+
async def test_tool_choice_empty_tool_list_skips_tool_choice(monkeypatch):
602+
"""ToolSet 存在但工具列表为空时,不应设置 tools 和 tool_choice"""
603+
provider = _setup_provider_with_mock_client(monkeypatch)
604+
605+
await provider.text_chat(
606+
prompt="hello",
607+
func_tool=_EmptyToolSet(),
608+
tool_choice="any",
609+
)
610+
611+
kwargs = _capture_payloads_create.last_kwargs
612+
assert "tools" not in kwargs
613+
assert "tool_choice" not in kwargs

0 commit comments

Comments
 (0)