Skip to content

Commit 757290c

Browse files
committed
More thorough handling of empty legacy fields (not just missing keys)
1 parent 3e5e30e commit 757290c

2 files changed

Lines changed: 141 additions & 12 deletions

File tree

src/app/endpoints/tools.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ def _input_schema_to_parameters(
6262
]
6363

6464

65+
def _normalize_tool_dict(tool_dict: dict[str, Any], toolgroup: Any) -> None:
66+
"""Normalize a ToolDef dict to the endpoint's response format.
67+
68+
Remaps field names (``name`` -> ``identifier``, ``input_schema`` ->
69+
``parameters``) and propagates ``provider_id``/``type`` from the
70+
parent toolgroup. Handles both missing keys and empty legacy
71+
placeholders.
72+
"""
73+
if "name" in tool_dict and not tool_dict.get("identifier"):
74+
tool_dict["identifier"] = tool_dict["name"]
75+
tool_dict.pop("name", None)
76+
77+
if "input_schema" in tool_dict and not tool_dict.get("parameters"):
78+
tool_dict["parameters"] = _input_schema_to_parameters(tool_dict["input_schema"])
79+
tool_dict.pop("input_schema", None)
80+
81+
if not tool_dict.get("provider_id"):
82+
tool_dict["provider_id"] = toolgroup.provider_id
83+
if not tool_dict.get("type"):
84+
tool_dict["type"] = getattr(toolgroup, "type", None) or "tool"
85+
86+
6587
tools_responses: dict[int | str, dict[str, Any]] = {
6688
200: ToolsResponse.openapi_response(),
6789
401: UnauthorizedResponse.openapi_response(
@@ -153,18 +175,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
153175
for tool in tools_response:
154176
tool_dict = dict(tool)
155177

156-
# Normalize Llama Stack ToolDef field names to the endpoint's
157-
# response format ('name' -> 'identifier', 'input_schema' -> 'parameters')
158-
if "name" in tool_dict and "identifier" not in tool_dict:
159-
tool_dict["identifier"] = tool_dict.pop("name")
160-
if "input_schema" in tool_dict and "parameters" not in tool_dict:
161-
tool_dict["parameters"] = _input_schema_to_parameters(
162-
tool_dict.pop("input_schema")
163-
)
164-
165-
# Propagate toolgroup-level fields to individual tools
166-
tool_dict.setdefault("provider_id", toolgroup.provider_id)
167-
tool_dict.setdefault("type", getattr(toolgroup, "type", None) or "tool")
178+
_normalize_tool_dict(tool_dict, toolgroup)
168179

169180
# Determine server source based on toolgroup type
170181
if toolgroup.identifier in mcp_server_names:

tests/unit/app/endpoints/test_tools.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,3 +1079,121 @@ async def test_tools_endpoint_rag_builtin_toolgroup(mocker: MockerFixture) -> No
10791079
assert tool["parameters"][0]["name"] == "query"
10801080
assert tool["parameters"][0]["parameter_type"] == "string"
10811081
assert tool["parameters"][0]["required"] is True
1082+
1083+
1084+
@pytest.mark.asyncio
1085+
async def test_tools_endpoint_empty_legacy_fields_overridden(
1086+
mocker: MockerFixture,
1087+
) -> None:
1088+
"""Test that empty legacy fields are overridden by ToolDef fields.
1089+
1090+
Regression variant: when a tool dict contains both new fields (name,
1091+
input_schema) AND empty legacy fields (identifier="", parameters=[],
1092+
provider_id="", type=""), the endpoint must populate from the new sources.
1093+
"""
1094+
mock_config = Configuration(
1095+
name="test",
1096+
service=ServiceConfiguration(
1097+
tls_config=TLSConfiguration(
1098+
tls_certificate_path=Path("tests/configuration/server.crt"),
1099+
tls_key_path=Path("tests/configuration/server.key"),
1100+
tls_key_password=Path("tests/configuration/password"),
1101+
),
1102+
cors=CORSConfiguration(
1103+
allow_origins=["*"],
1104+
allow_credentials=False,
1105+
allow_methods=["*"],
1106+
allow_headers=["*"],
1107+
),
1108+
host="localhost",
1109+
port=8080,
1110+
base_url=".",
1111+
auth_enabled=False,
1112+
workers=1,
1113+
color_log=True,
1114+
access_log=True,
1115+
root_path="/.",
1116+
),
1117+
llama_stack=LlamaStackConfiguration(
1118+
url=AnyHttpUrl("http://localhost:8321"),
1119+
api_key=SecretStr("xyzzy"),
1120+
use_as_library_client=False,
1121+
library_client_config_path=".",
1122+
timeout=10,
1123+
),
1124+
user_data_collection=UserDataCollection(
1125+
transcripts_enabled=False,
1126+
feedback_enabled=False,
1127+
transcripts_storage=".",
1128+
feedback_storage=".",
1129+
),
1130+
mcp_servers=[],
1131+
customization=None,
1132+
authorization=None,
1133+
deployment_environment=".",
1134+
)
1135+
app_config = AppConfig()
1136+
app_config._configuration = mock_config
1137+
mocker.patch("app.endpoints.tools.configuration", app_config)
1138+
mocker.patch("app.endpoints.tools.authorize", lambda _: lambda func: func)
1139+
1140+
mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder")
1141+
mock_client = mocker.AsyncMock()
1142+
mock_client_holder.return_value.get_client.return_value = mock_client
1143+
1144+
mock_toolgroup = mocker.Mock()
1145+
mock_toolgroup.identifier = "builtin::rag"
1146+
mock_toolgroup.provider_id = "rag-runtime"
1147+
mock_toolgroup.type = "tool_group"
1148+
mock_client.toolgroups.list.return_value = [mock_toolgroup]
1149+
1150+
# Tool with both new fields AND empty legacy fields
1151+
rag_tool = _make_tool_def_mock(
1152+
mocker,
1153+
{
1154+
"name": "knowledge_search",
1155+
"identifier": "",
1156+
"description": "Search for information in a database.",
1157+
"input_schema": {
1158+
"type": "object",
1159+
"properties": {
1160+
"query": {
1161+
"type": "string",
1162+
"description": "The query to search for.",
1163+
}
1164+
},
1165+
"required": ["query"],
1166+
},
1167+
"parameters": [],
1168+
"provider_id": "",
1169+
"type": "",
1170+
"toolgroup_id": "builtin::rag",
1171+
"metadata": None,
1172+
"output_schema": None,
1173+
},
1174+
)
1175+
mock_client.tools.list.return_value = [rag_tool]
1176+
1177+
mock_request = mocker.Mock()
1178+
mock_auth = MOCK_AUTH
1179+
1180+
response = await tools.tools_endpoint_handler.__wrapped__( # pyright: ignore
1181+
mock_request, mock_auth, {}
1182+
)
1183+
1184+
assert isinstance(response, ToolsResponse)
1185+
assert len(response.tools) == 1
1186+
1187+
tool = response.tools[0]
1188+
# Empty legacy fields must be overridden by new sources
1189+
assert tool["identifier"] == "knowledge_search"
1190+
assert tool["provider_id"] == "rag-runtime"
1191+
assert tool["type"] == "tool_group"
1192+
assert tool["server_source"] == "builtin"
1193+
assert tool["toolgroup_id"] == "builtin::rag"
1194+
1195+
# Parameters populated from input_schema, not empty legacy list
1196+
assert len(tool["parameters"]) == 1
1197+
assert tool["parameters"][0]["name"] == "query"
1198+
assert tool["parameters"][0]["parameter_type"] == "string"
1199+
assert tool["parameters"][0]["required"] is True

0 commit comments

Comments
 (0)