Skip to content

Commit ecd4ca6

Browse files
author
Radovan Fuchs
committed
add option to disable topic summary
1 parent db45d8c commit ecd4ca6

7 files changed

Lines changed: 283 additions & 64 deletions

File tree

src/app/endpoints/query.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,19 @@ async def query_endpoint_handler_base( # pylint: disable=R0914
318318
session.query(UserConversation).filter_by(id=conversation_id).first()
319319
)
320320
if not existing_conversation:
321-
topic_summary = await get_topic_summary_func(
322-
query_request.query, client, llama_stack_model_id
323-
)
321+
# Check if topic summary should be generated (default: True)
322+
should_generate = query_request.generate_topic_summary
323+
324+
if should_generate:
325+
logger.debug("Generating topic summary for new conversation")
326+
topic_summary = await get_topic_summary_func(
327+
query_request.query, client, llama_stack_model_id
328+
)
329+
else:
330+
logger.debug(
331+
"Topic summary generation disabled by request parameter"
332+
)
333+
topic_summary = None
324334
# Convert RAG chunks to dictionary format once for reuse
325335
logger.info("Processing RAG chunks...")
326336
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]

src/app/endpoints/streaming_query.py

Lines changed: 40 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -783,63 +783,47 @@ async def response_generator(
783783

784784
yield stream_end_event(context.metadata_map, summary, token_usage, media_type)
785785

786-
# Perform cleanup tasks (database and cache operations)
787-
await cleanup_after_streaming(
788-
user_id=context.user_id,
789-
conversation_id=context.conversation_id,
790-
model_id=context.model_id,
791-
provider_id=context.provider_id,
792-
llama_stack_model_id=context.llama_stack_model_id,
793-
query_request=context.query_request,
794-
summary=summary,
795-
metadata_map=context.metadata_map,
796-
started_at=context.started_at,
797-
client=context.client,
798-
config=configuration,
799-
skip_userid_check=context.skip_userid_check,
800-
get_topic_summary_func=get_topic_summary,
801-
is_transcripts_enabled_func=is_transcripts_enabled,
802-
store_transcript_func=store_transcript,
803-
persist_user_conversation_details_func=persist_user_conversation_details,
804-
rag_chunks=create_rag_chunks_dict(summary),
805-
)
806-
807-
return response_generator
808-
809-
810-
async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments
811-
request: Request,
812-
query_request: QueryRequest,
813-
auth: AuthTuple,
814-
mcp_headers: dict[str, dict[str, str]],
815-
retrieve_response_func: Callable[..., Any],
816-
create_response_generator_func: Callable[..., Any],
817-
) -> StreamingResponse:
818-
"""
819-
Handle streaming query endpoints with common logic.
820-
821-
This base handler contains all the common logic for streaming query endpoints
822-
and accepts functions for API-specific behavior (Agent API vs Responses API).
823-
824-
Args:
825-
request: The FastAPI request object
826-
query_request: The query request from the user
827-
auth: Authentication tuple (user_id, username, skip_check, token)
828-
mcp_headers: MCP headers for tool integrations
829-
retrieve_response_func: Function to retrieve the streaming response
830-
create_response_generator_func: Function factory that creates the response generator
831-
832-
Returns:
833-
StreamingResponse: An HTTP streaming response yielding SSE-formatted events
834-
835-
Raises:
836-
HTTPException: Returns HTTP 500 if unable to connect to Llama Stack
837-
"""
838-
# Nothing interesting in the request
839-
_ = request
786+
if not is_transcripts_enabled():
787+
logger.debug("Transcript collection is disabled in the configuration")
788+
else:
789+
store_transcript(
790+
user_id=user_id,
791+
conversation_id=conversation_id,
792+
model_id=model_id,
793+
provider_id=provider_id,
794+
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
795+
query=query_request.query,
796+
query_request=query_request,
797+
summary=summary,
798+
rag_chunks=create_rag_chunks_dict(summary),
799+
truncated=False, # TODO(lucasagomes): implement truncation as part
800+
# of quota work
801+
attachments=query_request.attachments or [],
802+
)
840803

841-
check_configuration_loaded(configuration)
842-
started_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
804+
# Get the initial topic summary for the conversation
805+
topic_summary = None
806+
with get_session() as session:
807+
existing_conversation = (
808+
session.query(UserConversation)
809+
.filter_by(id=conversation_id)
810+
.first()
811+
)
812+
if not existing_conversation:
813+
# Check if topic summary should be generated (default: True)
814+
should_generate = query_request.generate_topic_summary
815+
if should_generate:
816+
logger.debug("Generating topic summary for new conversation")
817+
topic_summary = await get_topic_summary(
818+
query_request.query, client, model_id
819+
)
820+
else:
821+
logger.debug(
822+
"Topic summary generation disabled by request parameter"
823+
)
824+
topic_summary = None
825+
826+
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
843827

844828
# Enforce RBAC: optionally disallow overriding model/provider in requests
845829
validate_model_provider_override(query_request, request.state.authorized_actions)

src/models/requests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class QueryRequest(BaseModel):
8181
system_prompt: The optional system prompt.
8282
attachments: The optional attachments.
8383
no_tools: Whether to bypass all tools and MCP servers (default: False).
84+
generate_topic_summary: Whether to generate topic summary for new conversations (default: True).
8485
media_type: The optional media type for response format (application/json or text/plain).
8586
8687
Example:
@@ -146,6 +147,12 @@ class QueryRequest(BaseModel):
146147
examples=[True, False],
147148
)
148149

150+
generate_topic_summary: Optional[bool] = Field(
151+
True,
152+
description="Whether to generate topic summary for new conversations",
153+
examples=[True, False],
154+
)
155+
149156
media_type: Optional[str] = Field(
150157
None,
151158
description="Media type for the response format",
@@ -164,6 +171,7 @@ class QueryRequest(BaseModel):
164171
"model": "model-name",
165172
"system_prompt": "You are a helpful assistant",
166173
"no_tools": False,
174+
"generate_topic_summary": True,
167175
"attachments": [
168176
{
169177
"attachment_type": "log",

src/utils/endpoints.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -671,11 +671,17 @@ async def cleanup_after_streaming(
671671
session.query(UserConversation).filter_by(id=conversation_id).first()
672672
)
673673
if not existing_conversation:
674-
topic_summary = await get_topic_summary_func(
675-
query_request.query,
676-
client,
677-
llama_stack_model_id,
678-
)
674+
# Check if topic summary should be generated (default: True)
675+
should_generate = query_request.generate_topic_summary
676+
677+
if should_generate:
678+
logger.debug("Generating topic summary for new conversation")
679+
topic_summary = await get_topic_summary_func(
680+
query_request.query, client, llama_stack_model_id
681+
)
682+
else:
683+
logger.debug("Topic summary generation disabled by request parameter")
684+
topic_summary = None
679685

680686
completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
681687

tests/unit/app/endpoints/test_query.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2265,6 +2265,7 @@ async def test_get_topic_summary_create_turn_parameters(mocker: MockerFixture) -
22652265

22662266

22672267
@pytest.mark.asyncio
2268+
<<<<<<< HEAD
22682269
async def test_query_endpoint_quota_exceeded(
22692270
mocker: MockerFixture, dummy_request: Request
22702271
) -> None:
@@ -2305,3 +2306,99 @@ async def test_query_endpoint_quota_exceeded(
23052306
assert isinstance(detail, dict)
23062307
assert detail["response"] == "Model quota exceeded" # type: ignore
23072308
assert "gpt-4-turbo" in detail["cause"] # type: ignore
2309+
=======
2310+
async def test_query_endpoint_generate_topic_summary_default_true(
2311+
mocker: MockerFixture, dummy_request: Request
2312+
) -> None:
2313+
"""Test that topic summary is generated by default for new conversations."""
2314+
mock_client = mocker.AsyncMock()
2315+
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
2316+
mock_lsc.return_value = mock_client
2317+
mock_client.models.list.return_value = [
2318+
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
2319+
]
2320+
2321+
mock_config = mocker.Mock()
2322+
mock_config.quota_limiters = []
2323+
mocker.patch("app.endpoints.query.configuration", mock_config)
2324+
2325+
summary = TurnSummary(llm_response="Test response", tool_calls=[])
2326+
mocker.patch(
2327+
"app.endpoints.query.retrieve_response",
2328+
return_value=(
2329+
summary,
2330+
"00000000-0000-0000-0000-000000000000",
2331+
[],
2332+
TokenCounter(),
2333+
),
2334+
)
2335+
2336+
mocker.patch(
2337+
"app.endpoints.query.select_model_and_provider_id",
2338+
return_value=("test_model", "test_model", "test_provider"),
2339+
)
2340+
mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
2341+
2342+
mock_get_topic_summary = mocker.patch(
2343+
"app.endpoints.query.get_topic_summary", return_value="Generated topic"
2344+
)
2345+
mock_database_operations(mocker)
2346+
2347+
await query_endpoint_handler(
2348+
request=dummy_request,
2349+
query_request=QueryRequest(query="test query"),
2350+
auth=("user123", "username", False, "auth_token_123"),
2351+
mcp_headers={},
2352+
)
2353+
2354+
mock_get_topic_summary.assert_called_once()
2355+
2356+
2357+
@pytest.mark.asyncio
2358+
async def test_query_endpoint_generate_topic_summary_explicit_false(
2359+
mocker: MockerFixture, dummy_request: Request
2360+
) -> None:
2361+
"""Test that topic summary is NOT generated when explicitly set to False."""
2362+
mock_client = mocker.AsyncMock()
2363+
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
2364+
mock_lsc.return_value = mock_client
2365+
mock_client.models.list.return_value = [
2366+
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
2367+
]
2368+
2369+
mock_config = mocker.Mock()
2370+
mock_config.quota_limiters = []
2371+
mocker.patch("app.endpoints.query.configuration", mock_config)
2372+
2373+
summary = TurnSummary(llm_response="Test response", tool_calls=[])
2374+
mocker.patch(
2375+
"app.endpoints.query.retrieve_response",
2376+
return_value=(
2377+
summary,
2378+
"00000000-0000-0000-0000-000000000000",
2379+
[],
2380+
TokenCounter(),
2381+
),
2382+
)
2383+
2384+
mocker.patch(
2385+
"app.endpoints.query.select_model_and_provider_id",
2386+
return_value=("test_model", "test_model", "test_provider"),
2387+
)
2388+
mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
2389+
2390+
mock_get_topic_summary = mocker.patch(
2391+
"app.endpoints.query.get_topic_summary", return_value="Generated topic"
2392+
)
2393+
2394+
mock_database_operations(mocker)
2395+
2396+
await query_endpoint_handler(
2397+
request=dummy_request,
2398+
query_request=QueryRequest(query="test query", generate_topic_summary=False),
2399+
auth=("user123", "username", False, "auth_token_123"),
2400+
mcp_headers={},
2401+
)
2402+
2403+
mock_get_topic_summary.assert_not_called()
2404+
>>>>>>> 81b4b90 (added unit tests for the extra logic)

tests/unit/models/requests/test_query_request.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,15 @@ def test_validate_media_type(self, mocker: MockerFixture) -> None:
154154

155155
# Media type is now fully supported, no warning expected
156156
mock_logger.warning.assert_not_called()
157+
158+
def test_generate_topic_summary_explicit_false(self) -> None:
159+
"""Test that generate_topic_summary can be explicitly set to False."""
160+
qr = QueryRequest(
161+
query="Tell me about Kubernetes", generate_topic_summary=False
162+
)
163+
assert qr.generate_topic_summary is False
164+
165+
def test_generate_topic_summary_explicit_true(self) -> None:
166+
"""Test that generate_topic_summary can be explicitly set to True."""
167+
qr = QueryRequest(query="Tell me about Kubernetes", generate_topic_summary=True)
168+
assert qr.generate_topic_summary is True

0 commit comments

Comments
 (0)