|
17 | 17 | from models.responses import QueryResponse |
18 | 18 | from utils.token_counter import TokenCounter |
19 | 19 | from utils.types import ( |
| 20 | + RAGChunk, |
| 21 | + RAGContext, |
| 22 | + ReferencedDocument, |
20 | 23 | ResponsesApiParams, |
21 | 24 | ShieldModerationPassed, |
22 | 25 | ToolCallSummary, |
@@ -174,6 +177,93 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: |
174 | 177 | assert response.conversation_id == "123" |
175 | 178 | assert response.response == "Kubernetes is a container orchestration platform" |
176 | 179 |
|
| 180 | + @pytest.mark.asyncio |
| 181 | + async def test_query_merges_inline_and_tool_rag_chunks_and_documents( |
| 182 | + self, |
| 183 | + dummy_request: Request, |
| 184 | + setup_configuration: AppConfig, |
| 185 | + mocker: MockerFixture, |
| 186 | + ) -> None: |
| 187 | + """Test that inline RAG and tool-based RAG chunks/docs are correctly merged.""" |
| 188 | + query_request = QueryRequest( |
| 189 | + query="What is Kubernetes?" |
| 190 | + ) # pyright: ignore[reportCallIssue] |
| 191 | + |
| 192 | + mocker.patch("app.endpoints.query.configuration", setup_configuration) |
| 193 | + mocker.patch("app.endpoints.query.check_configuration_loaded") |
| 194 | + mocker.patch("app.endpoints.query.check_tokens_available") |
| 195 | + mocker.patch("app.endpoints.query.validate_model_provider_override") |
| 196 | + |
| 197 | + mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) |
| 198 | + mock_response_obj = mocker.Mock() |
| 199 | + mock_response_obj.output = [] |
| 200 | + mock_client.responses = mocker.Mock() |
| 201 | + mock_client.responses.create = mocker.AsyncMock(return_value=mock_response_obj) |
| 202 | + mock_client_holder = mocker.Mock() |
| 203 | + mock_client_holder.get_client.return_value = mock_client |
| 204 | + mocker.patch( |
| 205 | + "app.endpoints.query.AsyncLlamaStackClientHolder", |
| 206 | + return_value=mock_client_holder, |
| 207 | + ) |
| 208 | + mocker.patch( |
| 209 | + "app.endpoints.query.run_shield_moderation", |
| 210 | + new=mocker.AsyncMock(return_value=ShieldModerationPassed()), |
| 211 | + ) |
| 212 | + |
| 213 | + inline_chunk = RAGChunk(content="inline chunk content", source="byok") |
| 214 | + inline_doc = ReferencedDocument(doc_title="Inline Doc") |
| 215 | + inline_rag = RAGContext( |
| 216 | + context_text="", |
| 217 | + rag_chunks=[inline_chunk], |
| 218 | + referenced_documents=[inline_doc], |
| 219 | + ) |
| 220 | + mocker.patch( |
| 221 | + "app.endpoints.query.build_rag_context", |
| 222 | + new=mocker.AsyncMock(return_value=inline_rag), |
| 223 | + ) |
| 224 | + |
| 225 | + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) |
| 226 | + mock_responses_params.model = "provider1/model1" |
| 227 | + mock_responses_params.conversation = "conv_123" |
| 228 | + mock_responses_params.tools = None |
| 229 | + mock_responses_params.model_dump.return_value = { |
| 230 | + "input": "test", |
| 231 | + "model": "provider1/model1", |
| 232 | + } |
| 233 | + mocker.patch( |
| 234 | + "app.endpoints.query.prepare_responses_params", |
| 235 | + new=mocker.AsyncMock(return_value=mock_responses_params), |
| 236 | + ) |
| 237 | + |
| 238 | + tool_chunk = RAGChunk(content="tool chunk content", source="vs-1") |
| 239 | + tool_doc = ReferencedDocument(doc_title="Tool Doc") |
| 240 | + mock_turn_summary = TurnSummary() |
| 241 | + mock_turn_summary.rag_chunks = [tool_chunk] |
| 242 | + mock_turn_summary.referenced_documents = [tool_doc] |
| 243 | + |
| 244 | + mocker.patch( |
| 245 | + "app.endpoints.query.retrieve_response", |
| 246 | + new=mocker.AsyncMock(return_value=mock_turn_summary), |
| 247 | + ) |
| 248 | + mocker.patch("app.endpoints.query.store_query_results") |
| 249 | + mocker.patch("app.endpoints.query.consume_query_tokens") |
| 250 | + mocker.patch("app.endpoints.query.get_available_quotas", return_value={}) |
| 251 | + |
| 252 | + response = await query_endpoint_handler( |
| 253 | + request=dummy_request, |
| 254 | + query_request=query_request, |
| 255 | + auth=MOCK_AUTH, |
| 256 | + mcp_headers={}, |
| 257 | + ) |
| 258 | + |
| 259 | + assert isinstance(response, QueryResponse) |
| 260 | + assert len(response.rag_chunks) == 2 |
| 261 | + assert response.rag_chunks[0].content == "inline chunk content" |
| 262 | + assert response.rag_chunks[1].content == "tool chunk content" |
| 263 | + assert len(response.referenced_documents) == 2 |
| 264 | + assert response.referenced_documents[0].doc_title == "Inline Doc" |
| 265 | + assert response.referenced_documents[1].doc_title == "Tool Doc" |
| 266 | + |
177 | 267 | @pytest.mark.asyncio |
178 | 268 | async def test_successful_query_with_conversation( |
179 | 269 | self, |
|
0 commit comments