Skip to content

Commit a5e9982

Browse files
authored
Enable RAG tool when a vector DB is available (#166)
* Enable RAG tool when a vector DB is available
* Replace typing.List with list
1 parent b22bfbf commit a5e9982

4 files changed

Lines changed: 125 additions & 0 deletions

File tree

src/app/endpoints/query.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
from llama_stack_client import LlamaStackClient # type: ignore
1212
from llama_stack_client.types import UserMessage # type: ignore
13+
from llama_stack_client.types.agents.turn_create_params import (
14+
ToolgroupAgentToolGroupWithArgs,
15+
Toolgroup,
16+
)
1317
from llama_stack_client.types.model_list_response import ModelListResponse
1418

1519
from fastapi import APIRouter, HTTPException, status, Depends
@@ -182,11 +186,13 @@ def retrieve_response(
182186
)
183187
session_id = agent.create_session("chat_session")
184188
logger.debug("Session ID: %s", session_id)
189+
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
185190
response = agent.create_turn(
186191
messages=[UserMessage(role="user", content=query_request.query)],
187192
session_id=session_id,
188193
documents=query_request.get_documents(),
189194
stream=False,
195+
toolgroups=get_rag_toolgroups(vector_db_ids),
190196
)
191197
return str(response.output_message.content) # type: ignore[union-attr]
192198

@@ -282,3 +288,21 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
282288
json.dump(data_to_store, transcript_file)
283289

284290
logger.info("Transcript successfully stored at: %s", transcript_file_path)
291+
292+
293+
def get_rag_toolgroups(
    vector_db_ids: list[str],
) -> list[Toolgroup] | None:
    """Build the toolgroups value that enables the built-in RAG tool.

    Args:
        vector_db_ids: Identifiers of the vector databases to pass to the tool.

    Returns:
        A one-element list holding the ``builtin::rag/knowledge_search``
        toolgroup whose ``args`` carry ``vector_db_ids``, or ``None`` when
        ``vector_db_ids`` is empty.
    """
    # An empty ID list yields no toolgroups at all rather than an empty list.
    if not vector_db_ids:
        return None

    knowledge_search = ToolgroupAgentToolGroupWithArgs(
        name="builtin::rag/knowledge_search",
        args={"vector_db_ids": vector_db_ids},
    )
    return [knowledge_search]

src/app/endpoints/streaming_query.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
from app.endpoints.query import (
23+
get_rag_toolgroups,
2324
is_transcripts_enabled,
2425
retrieve_conversation_id,
2526
store_transcript,
@@ -202,11 +203,15 @@ async def retrieve_response(
202203
)
203204
session_id = await agent.create_session("chat_session")
204205
logger.debug("Session ID: %s", session_id)
206+
vector_db_ids = [
207+
vector_db.identifier for vector_db in await client.vector_dbs.list()
208+
]
205209
response = await agent.create_turn(
206210
messages=[UserMessage(role="user", content=query_request.query)],
207211
session_id=session_id,
208212
documents=query_request.get_documents(),
209213
stream=True,
214+
toolgroups=get_rag_toolgroups(vector_db_ids),
210215
)
211216

212217
return response

tests/unit/app/endpoints/test_query.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
is_transcripts_enabled,
1313
construct_transcripts_path,
1414
store_transcript,
15+
get_rag_toolgroups,
1516
)
1617
from models.requests import QueryRequest, Attachment
1718
from models.config import ModelContextProtocolServer
@@ -277,12 +278,45 @@ def test_validate_attachments_metadata_invalid_content_type():
277278
)
278279

279280

281+
def test_retrieve_response_vector_db_available(mocker):
    """Check that retrieve_response passes RAG toolgroups when a vector DB is listed."""
    # Client reporting no shields and exactly one vector DB.
    vector_db = mocker.Mock(identifier="VectorDB-1")
    mock_client = mocker.Mock()
    mock_client.shields.list.return_value = []
    mock_client.vector_dbs.list.return_value = [vector_db]

    # Agent whose turn yields a fixed answer.
    mock_agent = mocker.Mock()
    mock_agent.create_turn.return_value.output_message.content = "LLM answer"

    # Mock configuration with empty MCP servers
    mocker.patch("app.endpoints.query.configuration", mocker.Mock(mcp_servers=[]))
    mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)

    answer = retrieve_response(
        mock_client,
        "fake_model_id",
        QueryRequest(query="What is OpenStack?"),
        "test_token",
    )

    assert answer == "LLM answer"
    expected_message = UserMessage(
        content="What is OpenStack?", role="user", context=None
    )
    mock_agent.create_turn.assert_called_once_with(
        messages=[expected_message],
        session_id=mocker.ANY,
        documents=[],
        stream=False,
        toolgroups=get_rag_toolgroups(["VectorDB-1"]),
    )
311+
312+
280313
def test_retrieve_response_no_available_shields(mocker):
281314
"""Test the retrieve_response function."""
282315
mock_agent = mocker.Mock()
283316
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
284317
mock_client = mocker.Mock()
285318
mock_client.shields.list.return_value = []
319+
mock_client.vector_dbs.list.return_value = []
286320

287321
# Mock configuration with empty MCP servers
288322
mock_config = mocker.Mock()
@@ -302,6 +336,7 @@ def test_retrieve_response_no_available_shields(mocker):
302336
session_id=mocker.ANY,
303337
documents=[],
304338
stream=False,
339+
toolgroups=None,
305340
)
306341

307342

@@ -319,6 +354,7 @@ def identifier(self):
319354
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
320355
mock_client = mocker.Mock()
321356
mock_client.shields.list.return_value = [MockShield("shield1")]
357+
mock_client.vector_dbs.list.return_value = []
322358

323359
# Mock configuration with empty MCP servers
324360
mock_config = mocker.Mock()
@@ -338,6 +374,7 @@ def identifier(self):
338374
session_id=mocker.ANY,
339375
documents=[],
340376
stream=False,
377+
toolgroups=None,
341378
)
342379

343380

@@ -358,6 +395,7 @@ def identifier(self):
358395
MockShield("shield1"),
359396
MockShield("shield2"),
360397
]
398+
mock_client.vector_dbs.list.return_value = []
361399

362400
# Mock configuration with empty MCP servers
363401
mock_config = mocker.Mock()
@@ -377,6 +415,7 @@ def identifier(self):
377415
session_id=mocker.ANY,
378416
documents=[],
379417
stream=False,
418+
toolgroups=None,
380419
)
381420

382421

@@ -386,6 +425,7 @@ def test_retrieve_response_with_one_attachment(mocker):
386425
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
387426
mock_client = mocker.Mock()
388427
mock_client.shields.list.return_value = []
428+
mock_client.vector_dbs.list.return_value = []
389429

390430
# Mock configuration with empty MCP servers
391431
mock_config = mocker.Mock()
@@ -418,6 +458,7 @@ def test_retrieve_response_with_one_attachment(mocker):
418458
"mime_type": "text/plain",
419459
},
420460
],
461+
toolgroups=None,
421462
)
422463

423464

@@ -427,6 +468,7 @@ def test_retrieve_response_with_two_attachments(mocker):
427468
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
428469
mock_client = mocker.Mock()
429470
mock_client.shields.list.return_value = []
471+
mock_client.vector_dbs.list.return_value = []
430472

431473
# Mock configuration with empty MCP servers
432474
mock_config = mocker.Mock()
@@ -468,6 +510,7 @@ def test_retrieve_response_with_two_attachments(mocker):
468510
"mime_type": "application/yaml",
469511
},
470512
],
513+
toolgroups=None,
471514
)
472515

473516

@@ -477,6 +520,7 @@ def test_retrieve_response_with_mcp_servers(mocker):
477520
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
478521
mock_client = mocker.Mock()
479522
mock_client.shields.list.return_value = []
523+
mock_client.vector_dbs.list.return_value = []
480524

481525
# Mock configuration with MCP servers
482526
mcp_servers = [
@@ -536,6 +580,7 @@ def test_retrieve_response_with_mcp_servers_empty_token(mocker):
536580
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
537581
mock_client = mocker.Mock()
538582
mock_client.shields.list.return_value = []
583+
mock_client.vector_dbs.list.return_value = []
539584

540585
# Mock configuration with MCP servers
541586
mcp_servers = [
@@ -646,3 +691,16 @@ def test_store_transcript(mocker):
646691
},
647692
mocker.ANY,
648693
)
694+
695+
696+
def test_get_rag_toolgroups(mocker):
    """Check get_rag_toolgroups for an empty and a populated vector DB list."""
    # Empty list of vector DBs: no toolgroups are returned.
    assert get_rag_toolgroups([]) is None

    # Populated list: exactly one RAG toolgroup carrying all the IDs.
    db_ids = ["Vector-DB-1", "Vector-DB-2"]
    toolgroups = get_rag_toolgroups(db_ids)
    assert len(toolgroups) == 1

    rag_toolgroup = toolgroups[0]
    assert rag_toolgroup["name"] == "builtin::rag/knowledge_search"
    assert rag_toolgroup["args"]["vector_db_ids"] == db_ids

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from app.endpoints.query import get_rag_toolgroups
34
from app.endpoints.streaming_query import (
45
streaming_query_endpoint_handler,
56
retrieve_response,
@@ -114,12 +115,41 @@ async def test_streaming_query_endpoint_handler_store_transcript(mocker):
114115
await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)
115116

116117

118+
async def test_retrieve_response_vector_db_available(mocker):
    """Check that the streaming retrieve_response passes RAG toolgroups when a vector DB is listed."""
    # Async client reporting no shields and exactly one vector DB.
    vector_db = mocker.Mock(identifier="VectorDB-1")
    mock_client = mocker.AsyncMock()
    mock_client.shields.list.return_value = []
    mock_client.vector_dbs.list.return_value = [vector_db]

    mock_agent = mocker.AsyncMock()
    mock_agent.create_turn.return_value.output_message.content = "LLM answer"
    mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent)

    stream = await retrieve_response(
        mock_client,
        "fake_model_id",
        QueryRequest(query="What is OpenStack?"),
    )

    # For streaming, the response should be the streaming object
    assert stream is not None
    expected_message = UserMessage(
        content="What is OpenStack?", role="user", context=None
    )
    mock_agent.create_turn.assert_called_once_with(
        messages=[expected_message],
        session_id=mocker.ANY,
        documents=[],
        stream=True,  # Should be True for streaming endpoint
        toolgroups=get_rag_toolgroups(["VectorDB-1"]),
    )
144+
145+
117146
async def test_retrieve_response_no_available_shields(mocker):
118147
"""Test the retrieve_response function."""
119148
mock_agent = mocker.AsyncMock()
120149
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
121150
mock_client = mocker.AsyncMock()
122151
mock_client.shields.list.return_value = []
152+
mock_client.vector_dbs.list.return_value = []
123153

124154
mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent)
125155

@@ -135,6 +165,7 @@ async def test_retrieve_response_no_available_shields(mocker):
135165
session_id=mocker.ANY,
136166
documents=[],
137167
stream=True, # Should be True for streaming endpoint
168+
toolgroups=None,
138169
)
139170

140171

@@ -166,6 +197,7 @@ def identifier(self):
166197
session_id=mocker.ANY,
167198
documents=[],
168199
stream=True, # Should be True for streaming endpoint
200+
toolgroups=None,
169201
)
170202

171203

@@ -186,6 +218,7 @@ def identifier(self):
186218
MockShield("shield1"),
187219
MockShield("shield2"),
188220
]
221+
mock_client.vector_dbs.list.return_value = []
189222

190223
mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent)
191224

@@ -200,6 +233,7 @@ def identifier(self):
200233
session_id=mocker.ANY,
201234
documents=[],
202235
stream=True, # Should be True for streaming endpoint
236+
toolgroups=None,
203237
)
204238

205239

@@ -209,6 +243,7 @@ async def test_retrieve_response_with_one_attachment(mocker):
209243
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
210244
mock_client = mocker.AsyncMock()
211245
mock_client.shields.list.return_value = []
246+
mock_client.vector_dbs.list.return_value = []
212247

213248
attachments = [
214249
Attachment(
@@ -235,6 +270,7 @@ async def test_retrieve_response_with_one_attachment(mocker):
235270
"mime_type": "text/plain",
236271
},
237272
],
273+
toolgroups=None,
238274
)
239275

240276

@@ -244,6 +280,7 @@ async def test_retrieve_response_with_two_attachments(mocker):
244280
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
245281
mock_client = mocker.AsyncMock()
246282
mock_client.shields.list.return_value = []
283+
mock_client.vector_dbs.list.return_value = []
247284

248285
attachments = [
249286
Attachment(
@@ -279,6 +316,7 @@ async def test_retrieve_response_with_two_attachments(mocker):
279316
"mime_type": "application/yaml",
280317
},
281318
],
319+
toolgroups=None,
282320
)
283321

284322

0 commit comments

Comments
 (0)