Skip to content

Commit 6c8c606

Browse files
committed
Fix Tool RAG always configured
1 parent 5375422 commit 6c8c606

7 files changed

Lines changed: 176 additions & 40 deletions

File tree

docs/byok_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ rag:
307307
- okp # include OKP context inline
308308
309309
# Tool RAG: the LLM can call file_search to retrieve context on demand
310-
# Omit to use all registered BYOK stores (backward compatibility)
310+
# If omitted, tool RAG is disabled. If both tool and inline are omitted, all registered stores are used as fallback
311311
tool:
312312
- my-docs # expose this BYOK store as the file_search tool
313313
- okp # expose OKP as the file_search tool

docs/openapi.json

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5500,7 +5500,6 @@
55005500
},
55015501
"db_path": {
55025502
"type": "string",
5503-
"format": "file-path",
55045503
"title": "DB path",
55055504
"description": "Path to RAG database."
55065505
},
@@ -7312,6 +7311,21 @@
73127311
"system",
73137312
"developer"
73147313
]
7314+
},
7315+
"referenced_documents": {
7316+
"anyOf": [
7317+
{
7318+
"items": {
7319+
"$ref": "#/components/schemas/ReferencedDocument"
7320+
},
7321+
"type": "array"
7322+
},
7323+
{
7324+
"type": "null"
7325+
}
7326+
],
7327+
"title": "Referenced Documents",
7328+
"description": "List of documents referenced in the response (assistant messages only)"
73157329
}
73167330
},
73177331
"type": "object",
@@ -7320,7 +7334,7 @@
73207334
"type"
73217335
],
73227336
"title": "Message",
7323-
"description": "Model representing a message in a conversation turn.\n\nAttributes:\n content: The message content.\n type: The type of message."
7337+
"description": "Model representing a message in a conversation turn.\n\nAttributes:\n content: The message content.\n type: The type of message.\n referenced_documents: Optional list of documents referenced in an assistant response."
73247338
},
73257339
"ModelContextProtocolServer": {
73267340
"properties": {
@@ -8140,7 +8154,7 @@
81408154
}
81418155
],
81428156
"title": "Shield Ids",
8143-
"description": "Optional list of safety shield IDs to apply. If None, all configured shields are used. If provided, must contain at least one valid shield ID (empty list raises 422 error).",
8157+
"description": "Optional list of safety shield IDs to apply. If None, all configured shields are used. ",
81448158
"examples": [
81458159
"llama-guard",
81468160
"custom-shield"
@@ -8785,25 +8799,18 @@
87858799
"description": "RAG IDs whose sources are injected as context before the LLM call. Use 'okp' to enable OKP inline RAG. Empty by default (no inline RAG)."
87868800
},
87878801
"tool": {
8788-
"anyOf": [
8789-
{
8790-
"items": {
8791-
"type": "string"
8792-
},
8793-
"type": "array"
8794-
},
8795-
{
8796-
"type": "null"
8797-
}
8798-
],
8802+
"items": {
8803+
"type": "string"
8804+
},
8805+
"type": "array",
87998806
"title": "Tool RAG IDs",
88008807
"description": "RAG IDs made available to the LLM as a file_search tool. Use 'okp' to include the OKP vector store. When omitted, all registered BYOK vector stores are used (backward compatibility)."
88018808
}
88028809
},
88038810
"additionalProperties": false,
88048811
"type": "object",
88058812
"title": "RagConfiguration",
8806-
"description": "RAG strategy configuration.\n\nControls which RAG sources are used for inline and tool-based retrieval.\n\nEach strategy lists RAG IDs to include. The special ID ``\"okp\"`` defined in constants,\nactivates the OKP provider; all other IDs refer to entries in ``byok_rag``.\n\nBackward compatibility:\n - ``inline`` defaults to ``[]`` (no inline RAG).\n - ``tool`` defaults to ``None`` which means all registered vector stores\n are used (identical to the previous ``tool.byok.enabled = True`` default)."
8813+
"description": "RAG strategy configuration.\n\nControls which RAG sources are used for inline and tool-based retrieval.\n\nEach strategy lists RAG IDs to include. The special ID ``\"okp\"`` defined in constants,\nactivates the OKP provider; all other IDs refer to entries in ``byok_rag``.\n\nBackward compatibility:\n - ``inline`` defaults to ``[]`` (no inline RAG).\n - ``tool`` defaults to ``[]`` (no tool RAG).\n\nIf no RAG strategy is defined (inline and tool are empty),\nthe RAG tool will register all stores available to llama-stack."
88078814
},
88088815
"ReadinessResponse": {
88098816
"properties": {
@@ -9854,4 +9861,4 @@
98549861
}
98559862
}
98569863
}
9857-
}
9864+
}

src/models/config.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from enum import Enum
77
from functools import cached_property
88
from pathlib import Path
9-
from typing import Any, Optional, Literal, Self
109
from re import Pattern
10+
from typing import Any, Literal, Optional, Self
1111

1212
import jsonpath_ng
1313
import yaml
@@ -1704,8 +1704,10 @@ class RagConfiguration(ConfigurationBase):
17041704
17051705
Backward compatibility:
17061706
- ``inline`` defaults to ``[]`` (no inline RAG).
1707-
- ``tool`` defaults to ``None`` which means all registered vector stores
1708-
are used (identical to the previous ``tool.byok.enabled = True`` default).
1707+
- ``tool`` defaults to ``[]`` (no tool RAG).
1708+
1709+
If no RAG strategy is defined (inline and tool are empty),
1710+
the RAG tool will register all stores available to llama-stack.
17091711
"""
17101712

17111713
inline: list[str] = Field(
@@ -1715,8 +1717,8 @@ class RagConfiguration(ConfigurationBase):
17151717
f"Use '{constants.OKP_RAG_ID}' to enable OKP inline RAG. Empty by default (no inline RAG).",
17161718
)
17171719

1718-
tool: Optional[list[str]] = Field(
1719-
default=None,
1720+
tool: list[str] = Field(
1721+
default_factory=list,
17201722
title="Tool RAG IDs",
17211723
description="RAG IDs made available to the LLM as a file_search tool. "
17221724
f"Use '{constants.OKP_RAG_ID}' to include the OKP vector store. "

src/utils/responses.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,26 @@ async def prepare_tools( # pylint: disable=too-many-arguments,too-many-position
166166
return None
167167

168168
toolgroups: list[InputTool] = []
169-
170-
# Priority: per-request IDs > rag.tool config > all registered stores.
171-
# In all cases, customer-facing rag_ids are translated to internal vector_db_ids.
172-
# IDs fetched from llama-stack are already internal and need no translation.
169+
effective_ids: list[str] = []
170+
171+
# Vector store ID resolution priority:
172+
# 1. Per-request IDs: highest prio; customer-facing rag_ids are translated to vector_db_ids.
173+
# 2. rag.tool config IDs: used when no per-request IDs provided, and rag.tool is configured.
174+
# If rag.inline is configured, but not rag.tool, tool RAG is disabled.
175+
# 3. All registered vector DBs: fallback when neither rag.tool nor rag.inline are configured.
176+
# IDs fetched from llama-stack are already internal and need no translation.
173177
byok_rags = configuration.configuration.byok_rag
178+
179+
is_tool_rag_enabled = len(configuration.configuration.rag.tool) > 0
180+
is_inline_rag_enabled = len(configuration.configuration.rag.inline) > 0
181+
174182
if vector_store_ids is not None:
175-
effective_ids: list[str] = resolve_vector_store_ids(vector_store_ids, byok_rags)
176-
elif configuration.configuration.rag.tool is not None:
183+
effective_ids = resolve_vector_store_ids(vector_store_ids, byok_rags)
184+
elif is_tool_rag_enabled:
177185
effective_ids = resolve_vector_store_ids(
178186
configuration.configuration.rag.tool, byok_rags
179187
)
180-
else:
188+
elif not is_inline_rag_enabled:
181189
effective_ids = await get_vector_store_ids(client, None)
182190

183191
# Add RAG tools if vector stores are available

tests/unit/models/config/test_dump_configuration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_dump_configuration(tmp_path: Path) -> None:
208208
"azure_entra_id": None,
209209
"rag": {
210210
"inline": [],
211-
"tool": None,
211+
"tool": [],
212212
},
213213
"okp": {
214214
"offline": True,
@@ -559,7 +559,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None:
559559
"azure_entra_id": None,
560560
"rag": {
561561
"inline": [],
562-
"tool": None,
562+
"tool": [],
563563
},
564564
"okp": {
565565
"offline": True,
@@ -788,7 +788,7 @@ def test_dump_configuration_with_quota_limiters_different_values(
788788
"azure_entra_id": None,
789789
"rag": {
790790
"inline": [],
791-
"tool": None,
791+
"tool": [],
792792
},
793793
"okp": {
794794
"offline": True,
@@ -992,7 +992,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
992992
"azure_entra_id": None,
993993
"rag": {
994994
"inline": [],
995-
"tool": None,
995+
"tool": [],
996996
},
997997
"okp": {
998998
"offline": True,
@@ -1181,7 +1181,7 @@ def test_dump_configuration_pg_namespace(tmp_path: Path) -> None:
11811181
"azure_entra_id": None,
11821182
"rag": {
11831183
"inline": [],
1184-
"tool": None,
1184+
"tool": [],
11851185
},
11861186
"okp": {
11871187
"offline": True,

tests/unit/models/config/test_rag_configuration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ def test_default_values(self) -> None:
1717
"""Test that RagConfiguration has correct default values."""
1818
config = RagConfiguration()
1919
assert config.inline == []
20-
assert config.tool is None
20+
assert config.tool == []
2121

2222
def test_inline_with_byok_ids(self) -> None:
2323
"""Test inline list with BYOK rag IDs."""
2424
config = RagConfiguration(inline=["store-1", "store-2"])
2525
assert config.inline == ["store-1", "store-2"]
26-
assert config.tool is None
26+
assert config.tool == []
2727

2828
def test_inline_with_okp_rag(self) -> None:
2929
"""Test inline list including the special OKP ID."""
@@ -45,10 +45,10 @@ def test_tool_empty_list(self) -> None:
4545
config = RagConfiguration(tool=[])
4646
assert config.tool == []
4747

48-
def test_tool_none_means_all_stores(self) -> None:
49-
"""Test that tool=None (default) means all registered stores are used."""
48+
def test_tool_default_is_empty_list(self) -> None:
49+
"""Test that tool defaults to an empty list."""
5050
config = RagConfiguration()
51-
assert config.tool is None
51+
assert config.tool == []
5252

5353
def test_no_unknown_fields_allowed(self) -> None:
5454
"""Test that RagConfiguration rejects unknown fields."""

tests/unit/utils/test_responses.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,8 @@ async def test_translates_byok_ids_in_prepare_tools(
10311031
mock_byok_rag.vector_db_id = "vs-001"
10321032
mock_config = mocker.Mock()
10331033
mock_config.configuration.byok_rag = [mock_byok_rag]
1034+
mock_config.configuration.rag.tool = []
1035+
mock_config.configuration.rag.inline = []
10341036
mocker.patch("utils.responses.configuration", mock_config)
10351037

10361038
result = await prepare_tools(mock_client, ["ocp_docs"], False, "token")
@@ -1050,6 +1052,8 @@ async def test_passes_through_unknown_ids_in_prepare_tools(
10501052
# Configure empty BYOK RAG
10511053
mock_config = mocker.Mock()
10521054
mock_config.configuration.byok_rag = []
1055+
mock_config.configuration.rag.tool = []
1056+
mock_config.configuration.rag.inline = []
10531057
mocker.patch("utils.responses.configuration", mock_config)
10541058

10551059
result = await prepare_tools(mock_client, ["raw-internal-id"], False, "token")
@@ -1078,7 +1082,8 @@ async def test_does_not_translate_when_ids_fetched_from_llama_stack(
10781082
mock_byok_rag.vector_db_id = "vs-translated"
10791083
mock_config = mocker.Mock()
10801084
mock_config.configuration.byok_rag = [mock_byok_rag]
1081-
mock_config.configuration.rag.tool = None
1085+
mock_config.configuration.rag.tool = []
1086+
mock_config.configuration.rag.inline = []
10821087
mocker.patch("utils.responses.configuration", mock_config)
10831088

10841089
result = await prepare_tools(mock_client, None, False, "token")
@@ -1087,6 +1092,120 @@ async def test_does_not_translate_when_ids_fetched_from_llama_stack(
10871092
assert result[0].vector_store_ids == ["vs-internal"]
10881093

10891094

1095+
class TestPrepareToolsVectorStoreResolution:
1096+
"""Tests for vector store ID resolution priority in prepare_tools."""
1097+
1098+
@pytest.mark.asyncio
1099+
async def test_uses_rag_tool_config_when_no_per_request_ids(
1100+
self, mocker: MockerFixture
1101+
) -> None:
1102+
"""Test that rag.tool config IDs are used when no per-request IDs are provided."""
1103+
mock_client = mocker.AsyncMock()
1104+
mocker.patch("utils.responses.get_mcp_tools", return_value=None)
1105+
1106+
mock_config = mocker.Mock()
1107+
mock_config.configuration.byok_rag = []
1108+
mock_config.configuration.rag.tool = ["rag-tool-id-1", "rag-tool-id-2"]
1109+
mock_config.configuration.rag.inline = []
1110+
mocker.patch("utils.responses.configuration", mock_config)
1111+
1112+
result = await prepare_tools(mock_client, None, False, "token")
1113+
1114+
assert result is not None
1115+
assert len(result) == 1
1116+
assert result[0].type == "file_search"
1117+
assert result[0].vector_store_ids == ["rag-tool-id-1", "rag-tool-id-2"]
1118+
mock_client.vector_stores.list.assert_not_called()
1119+
1120+
@pytest.mark.asyncio
1121+
async def test_rag_tool_config_ids_are_translated(
1122+
self, mocker: MockerFixture
1123+
) -> None:
1124+
"""Test that rag.tool config IDs are translated from rag_ids to vector_db_ids."""
1125+
mock_client = mocker.AsyncMock()
1126+
mocker.patch("utils.responses.get_mcp_tools", return_value=None)
1127+
1128+
mock_byok_rag = mocker.Mock()
1129+
mock_byok_rag.rag_id = "ocp_docs"
1130+
mock_byok_rag.vector_db_id = "vs-001"
1131+
mock_config = mocker.Mock()
1132+
mock_config.configuration.byok_rag = [mock_byok_rag]
1133+
mock_config.configuration.rag.tool = ["ocp_docs"]
1134+
mock_config.configuration.rag.inline = []
1135+
mocker.patch("utils.responses.configuration", mock_config)
1136+
1137+
result = await prepare_tools(mock_client, None, False, "token")
1138+
1139+
assert result is not None
1140+
assert result[0].vector_store_ids == ["vs-001"]
1141+
mock_client.vector_stores.list.assert_not_called()
1142+
1143+
@pytest.mark.asyncio
1144+
async def test_inline_rag_disables_tool_rag(self, mocker: MockerFixture) -> None:
1145+
"""Test that configuring rag.inline without rag.tool disables tool RAG."""
1146+
mock_client = mocker.AsyncMock()
1147+
mocker.patch("utils.responses.get_mcp_tools", return_value=None)
1148+
1149+
mock_config = mocker.Mock()
1150+
mock_config.configuration.byok_rag = []
1151+
mock_config.configuration.rag.tool = []
1152+
mock_config.configuration.rag.inline = [
1153+
"inline-store-id"
1154+
] # inline is configured
1155+
mocker.patch("utils.responses.configuration", mock_config)
1156+
1157+
result = await prepare_tools(mock_client, None, False, "token")
1158+
1159+
# Tool RAG should be disabled — no RAG tool in result, no llama-stack fetch
1160+
assert result is None
1161+
mock_client.vector_stores.list.assert_not_called()
1162+
1163+
@pytest.mark.asyncio
1164+
async def test_per_request_ids_override_rag_tool_config(
1165+
self, mocker: MockerFixture
1166+
) -> None:
1167+
"""Test that per-request vector_store_ids take priority over rag.tool config."""
1168+
mock_client = mocker.AsyncMock()
1169+
mocker.patch("utils.responses.get_mcp_tools", return_value=None)
1170+
1171+
mock_config = mocker.Mock()
1172+
mock_config.configuration.byok_rag = []
1173+
mock_config.configuration.rag.tool = ["config-id-1"]
1174+
mock_config.configuration.rag.inline = []
1175+
mocker.patch("utils.responses.configuration", mock_config)
1176+
1177+
result = await prepare_tools(mock_client, ["request-id-1"], False, "token")
1178+
1179+
assert result is not None
1180+
assert result[0].vector_store_ids == ["request-id-1"]
1181+
mock_client.vector_stores.list.assert_not_called()
1182+
1183+
@pytest.mark.asyncio
1184+
async def test_all_registered_dbs_used_when_neither_tool_nor_inline_configured(
1185+
self, mocker: MockerFixture
1186+
) -> None:
1187+
"""Test fallback to all registered vector DBs when neither rag.tool nor rag.inline are set."""
1188+
mock_client = mocker.AsyncMock()
1189+
mock_vs = mocker.Mock()
1190+
mock_vs.id = "vs-registered"
1191+
mock_list = mocker.Mock()
1192+
mock_list.data = [mock_vs]
1193+
mock_client.vector_stores.list = mocker.AsyncMock(return_value=mock_list)
1194+
mocker.patch("utils.responses.get_mcp_tools", return_value=None)
1195+
1196+
mock_config = mocker.Mock()
1197+
mock_config.configuration.byok_rag = []
1198+
mock_config.configuration.rag.tool = []
1199+
mock_config.configuration.rag.inline = []
1200+
mocker.patch("utils.responses.configuration", mock_config)
1201+
1202+
result = await prepare_tools(mock_client, None, False, "token")
1203+
1204+
assert result is not None
1205+
assert result[0].vector_store_ids == ["vs-registered"]
1206+
mock_client.vector_stores.list.assert_called_once()
1207+
1208+
10901209
class TestPrepareResponsesParams:
10911210
"""Tests for prepare_responses_params function."""
10921211

0 commit comments

Comments
 (0)