Skip to content

Commit aa2b77f

Browse files
authored
Merge pull request #99 from Serverless-Devs/fix/sanitize-non-identifier-tool-field-names
fix: sanitize non-identifier field names in MCP/OpenAPI tool schemas
2 parents 1435314 + e97f1ba commit aa2b77f

2 files changed

Lines changed: 265 additions & 8 deletions

File tree

agentrun/integration/utils/tool.py

Lines changed: 96 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from pydantic import (
4242
AliasChoices,
4343
BaseModel,
44+
ConfigDict,
4445
create_model,
4546
Field,
4647
ValidationError,
@@ -1396,6 +1397,14 @@ def _create_function_with_signature(
13961397
args_schema, "__agentrun_argument_aliases__", {}
13971398
)
13981399
if alias_map:
1400+
existing_param_names = {p.name for p in parameters}
1401+
# 防御性 sanitize: alias 要落到 inspect.Parameter 上, 非法字符
1402+
# (如 ``x-access-id``)会触发 ValueError。当前 alias 仅由
1403+
# ``_maybe_add_body_alias`` 写入 "query", 但未来可能扩展。
1404+
# 若 alias 被 sanitize, 同时把 sanitized 名字加进 alias_map 指向同一
1405+
# canonical, 以便 _normalize_tool_arguments 在调用方使用签名暴露的
1406+
# sanitized 名字时也能正确翻译。
1407+
extra_alias_entries: Dict[str, str] = {}
13991408
for alias, canonical in alias_map.items():
14001409
canonical_field = args_schema.model_fields.get(canonical)
14011410
alias_annotation = (
@@ -1408,14 +1417,29 @@ def _create_function_with_signature(
14081417
and alias_annotation is not None
14091418
):
14101419
alias_annotation = Optional[alias_annotation]
1420+
alias_name = (
1421+
alias
1422+
if alias.isidentifier()
1423+
else _sanitize_python_identifier(alias)
1424+
)
1425+
if alias_name != alias and alias_name not in alias_map:
1426+
extra_alias_entries[alias_name] = canonical
1427+
if alias_name in existing_param_names:
1428+
continue
1429+
existing_param_names.add(alias_name)
14111430
parameters.append(
14121431
inspect.Parameter(
1413-
alias,
1432+
alias_name,
14141433
inspect.Parameter.KEYWORD_ONLY,
14151434
default=None,
14161435
annotation=alias_annotation,
14171436
)
14181437
)
1438+
if extra_alias_entries:
1439+
# 合并到 args_schema 的 alias map (避免就地改动原 dict)
1440+
merged = dict(alias_map)
1441+
merged.update(extra_alias_entries)
1442+
setattr(args_schema, "__agentrun_argument_aliases__", merged)
14191443

14201444
# 创建实际执行函数
14211445
def impl(**kwargs):
@@ -1425,7 +1449,9 @@ def impl(**kwargs):
14251449
if args_schema is not None:
14261450
try:
14271451
parsed = args_schema(**normalized_kwargs)
1428-
payload = parsed.model_dump(mode="python", exclude_unset=True)
1452+
payload = parsed.model_dump(
1453+
mode="python", exclude_unset=True, by_alias=True
1454+
)
14291455
except ValidationError as exc:
14301456
raise ValueError(
14311457
f"Invalid arguments for tool '{tool_name}': {exc}"
@@ -1674,6 +1700,34 @@ def _build_openapi_schema(
16741700
return schema, tuple(body_field_names), alias_map
16751701

16761702

1703+
_PY_KEYWORDS: Set[str] = set()
1704+
1705+
1706+
def _sanitize_python_identifier(name: str) -> str:
1707+
"""将任意字符串转换为合法的 Python 标识符
1708+
1709+
用于把 JSON Schema 中含 ``-`` / ``.`` 等字符的字段名(例如 ``x-access-id``)
1710+
映射成 Pydantic / ``inspect.Parameter`` 都能接受的字段名。原始名通过 alias
1711+
继续保留在 JSON Schema 和实际调用中。
1712+
"""
1713+
import keyword
1714+
1715+
if not _PY_KEYWORDS:
1716+
_PY_KEYWORDS.update(keyword.kwlist)
1717+
1718+
sanitized = re.sub(r"[^0-9a-zA-Z_]", "_", name)
1719+
sanitized = sanitized.lstrip("_")
1720+
if not sanitized:
1721+
sanitized = "field"
1722+
if sanitized[0].isdigit():
1723+
# 数字开头不是合法 Python 标识符; 又因为 Pydantic 不允许字段名以
1724+
# 下划线开头, 这里只能加字母前缀 "field_" 而不是直接补 "_".
1725+
sanitized = "field_" + sanitized
1726+
if sanitized in _PY_KEYWORDS:
1727+
sanitized = sanitized + "_"
1728+
return sanitized
1729+
1730+
16771731
def _json_schema_to_pydantic(
16781732
name: str,
16791733
schema: Optional[Dict[str, Any]],
@@ -1688,40 +1742,74 @@ def _json_schema_to_pydantic(
16881742

16891743
required_fields = set(schema.get("required", []))
16901744
fields = {}
1745+
needs_populate_by_name = False
1746+
used_py_names: Set[str] = set()
16911747

16921748
for field_name, field_schema in properties.items():
16931749
if not isinstance(field_schema, dict):
16941750
continue
16951751

1752+
# 把含非法字符(如 ``x-access-id``)或保留字(``class``)的字段名映射到
1753+
# 合法的 Python 标识符, 通过 alias 保留原名以便 JSON Schema 输出和
1754+
# 调用真实 MCP 工具时使用。
1755+
import keyword as _kw
1756+
1757+
if field_name.isidentifier() and not _kw.iskeyword(field_name):
1758+
py_name = field_name
1759+
else:
1760+
py_name = _sanitize_python_identifier(field_name)
1761+
if py_name in used_py_names:
1762+
suffix = 2
1763+
while f"{py_name}_{suffix}" in used_py_names:
1764+
suffix += 1
1765+
py_name = f"{py_name}_{suffix}"
1766+
used_py_names.add(py_name)
1767+
if py_name != field_name:
1768+
needs_populate_by_name = True
1769+
16961770
# 映射类型
16971771
field_type = _json_type_to_python(field_schema)
16981772
description = field_schema.get("description", "")
16991773
default = field_schema.get("default")
17001774
aliases = field_schema.get("x-aliases")
17011775
field_kwargs: Dict[str, Any] = {"description": description}
1776+
1777+
# 用 ``alias`` 同时作用于 JSON Schema 输出和 by_alias dump,
1778+
# 让 LLM/调用端看到的字段名仍是原始名(如 ``x-access-id``)。
1779+
if py_name != field_name:
1780+
field_kwargs["alias"] = field_name
17021781
if aliases:
17031782
if not isinstance(aliases, (list, tuple)):
17041783
aliases = [aliases]
1705-
field_kwargs["validation_alias"] = AliasChoices(
1706-
field_name, *aliases
1707-
)
1784+
alias_choices: List[str] = [field_name]
1785+
if py_name != field_name:
1786+
alias_choices.append(py_name)
1787+
for alias in aliases:
1788+
if alias and alias not in alias_choices:
1789+
alias_choices.append(alias)
1790+
field_kwargs["validation_alias"] = AliasChoices(*alias_choices)
17081791

17091792
# 构建字段定义
17101793
if field_name in required_fields:
17111794
# 必填字段
1712-
fields[field_name] = (field_type, Field(**field_kwargs))
1795+
fields[py_name] = (field_type, Field(**field_kwargs))
17131796
else:
17141797
# 可选字段
17151798
from typing import Optional as TypingOptional
17161799

1717-
fields[field_name] = (
1800+
fields[py_name] = (
17181801
TypingOptional[field_type],
17191802
Field(default=default, **field_kwargs),
17201803
)
17211804

17221805
# 创建模型,清理名称
17231806
model_name = re.sub(r"[^0-9a-zA-Z]", "", name.title())
1724-
return create_model(model_name or "Args", **fields) # type: ignore
1807+
model_kwargs: Dict[str, Any] = {}
1808+
if needs_populate_by_name:
1809+
model_kwargs["__config__"] = ConfigDict(populate_by_name=True)
1810+
return create_model( # type: ignore
1811+
model_name or "Args", **model_kwargs, **fields
1812+
)
17251813

17261814

17271815
def _json_type_to_python(field_schema: Dict[str, Any]) -> type:

tests/unittests/integration/test_tool_utils.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
import pytest
1111

1212
from agentrun.integration.utils.tool import (
13+
_build_tool_from_meta,
14+
_create_function_with_signature,
1315
_extract_core_schema,
16+
_json_schema_to_pydantic,
1417
_load_json,
1518
_merge_schema_dicts,
1619
_normalize_tool_arguments,
20+
_sanitize_python_identifier,
1721
_to_dict,
1822
CommonToolSet,
1923
from_pydantic,
@@ -702,3 +706,168 @@ def test_get_schema_from_parameters(self):
702706
assert "name" in schema["properties"]
703707
assert "age" in schema["properties"]
704708
assert "name" in schema.get("required", [])
709+
710+
711+
class TestSanitizePythonIdentifier:
712+
"""测试字段名 sanitizer"""
713+
714+
def test_valid_identifier_unchanged(self):
715+
assert _sanitize_python_identifier("normal_name") == "normal_name"
716+
717+
def test_hyphenated_name(self):
718+
assert _sanitize_python_identifier("x-access-id") == "x_access_id"
719+
720+
def test_dotted_name(self):
721+
assert _sanitize_python_identifier("a.b.c") == "a_b_c"
722+
723+
def test_leading_digit_prefixed(self):
724+
assert _sanitize_python_identifier("123abc") == "field_123abc"
725+
726+
def test_keyword_suffixed(self):
727+
assert _sanitize_python_identifier("class") == "class_"
728+
729+
def test_empty_string(self):
730+
assert _sanitize_python_identifier("") == "field"
731+
732+
def test_only_invalid_chars(self):
733+
assert _sanitize_python_identifier("---") == "field"
734+
735+
736+
class TestJsonSchemaToPydanticInvalidFieldNames:
737+
"""覆盖 _json_schema_to_pydantic 对非法 Python 标识符字段名的处理"""
738+
739+
def test_hyphenated_field_name_builds_model(self):
740+
"""字段名含 '-' 时不应抛错, 且 JSON Schema 仍以原名暴露"""
741+
schema = {
742+
"type": "object",
743+
"properties": {
744+
"x-access-id": {"type": "string", "description": "id"},
745+
},
746+
"required": ["x-access-id"],
747+
}
748+
749+
model = _json_schema_to_pydantic("Args", schema)
750+
751+
assert model is not None
752+
assert "x_access_id" in model.model_fields
753+
json_schema = model.model_json_schema()
754+
assert "x-access-id" in json_schema["properties"]
755+
assert "x-access-id" in json_schema["required"]
756+
757+
def test_keyword_field_name_sanitized(self):
758+
schema = {
759+
"type": "object",
760+
"properties": {
761+
"class": {"type": "string", "description": "py keyword"},
762+
},
763+
}
764+
765+
model = _json_schema_to_pydantic("Args", schema)
766+
767+
assert model is not None
768+
assert "class_" in model.model_fields
769+
assert "class" in model.model_json_schema()["properties"]
770+
771+
def test_accepts_both_original_and_sanitized_name(self):
772+
schema = {
773+
"type": "object",
774+
"properties": {
775+
"x-access-id": {"type": "string"},
776+
},
777+
"required": ["x-access-id"],
778+
}
779+
780+
model = _json_schema_to_pydantic("Args", schema)
781+
782+
# 原名: 通过 alias
783+
m1 = model(**{"x-access-id": "v1"})
784+
assert m1.model_dump(by_alias=True) == {"x-access-id": "v1"}
785+
# 沙化名: 通过 populate_by_name
786+
m2 = model(x_access_id="v2")
787+
assert m2.model_dump(by_alias=True) == {"x-access-id": "v2"}
788+
789+
790+
class TestCreateFunctionWithSignatureAliasSanitization:
791+
"""覆盖 _create_function_with_signature 对非法 alias 名的防御处理"""
792+
793+
def test_alias_with_hyphen_sanitized(self):
794+
"""`__agentrun_argument_aliases__` 含非法标识符 alias 时不应崩溃"""
795+
from pydantic import BaseModel as _BM
796+
797+
class _Args(_BM):
798+
query: str
799+
800+
setattr(_Args, "__agentrun_argument_aliases__", {"x-alias": "query"})
801+
802+
toolset = MagicMock()
803+
func = _create_function_with_signature("demo", _Args, toolset, None)
804+
805+
import inspect as _inspect
806+
807+
sig = _inspect.signature(func)
808+
# 主字段保留, alias 被 sanitize
809+
assert "query" in sig.parameters
810+
assert "x_alias" in sig.parameters
811+
812+
def test_call_via_sanitized_alias_name_routes_to_canonical(self):
813+
"""用签名暴露的 sanitized alias 名调用时也应翻译到 canonical 字段
814+
815+
回归 Copilot review 提出的: 仅 sanitize 签名不够, 还要让
816+
_normalize_tool_arguments 认识 sanitized alias.
817+
"""
818+
from pydantic import BaseModel as _BM
819+
from pydantic import Field as _Field
820+
821+
class _Args(_BM):
822+
query: str = _Field()
823+
824+
setattr(_Args, "__agentrun_argument_aliases__", {"x-alias": "query"})
825+
826+
toolset = MagicMock()
827+
toolset.call_tool = MagicMock(return_value={"ok": True})
828+
func = _create_function_with_signature("demo", _Args, toolset, None)
829+
830+
# 用沙化后的 alias 名 (签名暴露的形式) 调用
831+
func(x_alias="hello")
832+
833+
call_kwargs = toolset.call_tool.call_args.kwargs
834+
assert call_kwargs["arguments"] == {"query": "hello"}
835+
# 同时验证: alias_map 已被扩展, 包含 sanitized 形式
836+
merged_map = getattr(_Args, "__agentrun_argument_aliases__")
837+
assert merged_map.get("x_alias") == "query"
838+
assert merged_map.get("x-alias") == "query"
839+
840+
841+
class TestBuildToolFromMetaInvalidFieldNames:
842+
"""覆盖 _build_tool_from_meta 完整链路 (回归 'x-access-id' 加载失败)"""
843+
844+
def test_mcp_input_schema_with_hyphen_field(self):
845+
"""模拟 MCP 工具元数据包含 'x-access-id' 入参时仍可成功构造 Tool"""
846+
toolset = MagicMock()
847+
toolset.call_tool = MagicMock(return_value={"status": "ok"})
848+
849+
meta = {
850+
"name": "demo-tool",
851+
"description": "demo",
852+
"input_schema": {
853+
"type": "object",
854+
"properties": {
855+
"x-access-id": {
856+
"type": "string",
857+
"description": "id",
858+
},
859+
"value": {"type": "integer"},
860+
},
861+
"required": ["x-access-id"],
862+
},
863+
}
864+
865+
tool_obj = _build_tool_from_meta(toolset, meta, None)
866+
867+
assert tool_obj is not None
868+
# 调用工具时, MCP 应收到原始字段名 'x-access-id'
869+
tool_obj.func(**{"x-access-id": "abc", "value": 1})
870+
toolset.call_tool.assert_called_once()
871+
call_kwargs = toolset.call_tool.call_args.kwargs
872+
assert call_kwargs["arguments"]["x-access-id"] == "abc"
873+
assert call_kwargs["arguments"]["value"] == 1

0 commit comments

Comments
 (0)