Skip to content

Commit e89b3ad

Browse files
authored
Merge pull request #1849 from asimurka/refactor_get_topic_summary
LCORE-2310: Refactor get topic summary utility
2 parents 4920d68 + cc63a6f commit e89b3ad

4 files changed

Lines changed: 49 additions & 35 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@
9090
extract_text_from_response_items,
9191
extract_token_usage,
9292
extract_vector_store_ids_from_tools,
93-
get_topic_summary,
9493
get_zero_usage,
9594
is_server_deployed_output,
95+
maybe_get_topic_summary,
9696
parse_rag_chunks,
9797
parse_referenced_documents,
9898
resolve_client_tool_choice,
@@ -254,25 +254,6 @@ async def _append_previous_response_turn(
254254
)
255255

256256

257-
async def _maybe_get_topic_summary(
258-
api_params: ResponsesApiParams,
259-
context: ResponsesContext,
260-
) -> Optional[str]:
261-
"""Generate a topic summary when requested for the current response.
262-
263-
Args:
264-
api_params: Responses API parameters containing the selected model.
265-
context: Request-scoped Responses API context.
266-
267-
Returns:
268-
Generated topic summary, or None when topic summaries are disabled.
269-
"""
270-
if not context.generate_topic_summary:
271-
return None
272-
logger.debug("Generating topic summary for new conversation")
273-
return await get_topic_summary(context.input_text, context.client, api_params.model)
274-
275-
276257
def _store_response_query_results(
277258
api_params: ResponsesApiParams,
278259
context: ResponsesContext,
@@ -981,7 +962,12 @@ async def generate_response(
981962
async for event in generator:
982963
yield event
983964

984-
topic_summary = await _maybe_get_topic_summary(api_params, context)
965+
topic_summary = await maybe_get_topic_summary(
966+
generate_topic_summary=context.generate_topic_summary,
967+
input_text=context.input_text,
968+
client=context.client,
969+
model_id=api_params.model,
970+
)
985971
completed_at = datetime.now(UTC)
986972
_store_response_query_results(
987973
api_params,
@@ -1083,7 +1069,12 @@ async def handle_non_streaming_response(
10831069
available_quotas = get_available_quotas(
10841070
quota_limiters=configuration.quota_limiters, user_id=user_id
10851071
)
1086-
topic_summary = await _maybe_get_topic_summary(api_params, context)
1072+
topic_summary = await maybe_get_topic_summary(
1073+
generate_topic_summary=context.generate_topic_summary,
1074+
input_text=context.input_text,
1075+
client=context.client,
1076+
model_id=api_params.model,
1077+
)
10871078

10881079
vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools)
10891080
turn_summary = build_turn_summary(

src/utils/responses.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,29 @@ async def get_topic_summary( # pylint: disable=too-many-nested-blocks
205205
return extract_text_from_response_items(response.output)
206206

207207

208+
async def maybe_get_topic_summary(
209+
generate_topic_summary: bool,
210+
input_text: str,
211+
client: AsyncLlamaStackClient,
212+
model_id: str,
213+
) -> Optional[str]:
214+
"""Generate a topic summary when requested for the current response.
215+
216+
Args:
217+
generate_topic_summary: Whether topic summary generation is enabled.
218+
input_text: User input text to summarize.
219+
client: Llama Stack client for the summary request.
220+
model_id: Model identifier in provider/model format.
221+
222+
Returns:
223+
Generated topic summary, or None when topic summaries are disabled.
224+
"""
225+
if not generate_topic_summary:
226+
return None
227+
logger.debug("Generating topic summary for new conversation")
228+
return await get_topic_summary(input_text, client, model_id)
229+
230+
208231
async def prepare_tools( # pylint: disable=too-many-arguments,too-many-positional-arguments
209232
client: AsyncLlamaStackClient,
210233
vector_store_ids: Optional[list[str]],

tests/unit/app/endpoints/test_responses.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _patch_handle_non_streaming_common(
233233
mocker.patch(f"{MODULE}.configuration", config)
234234
mocker.patch(f"{MODULE}.get_available_quotas", return_value={})
235235
mocker.patch(
236-
f"{MODULE}.get_topic_summary",
236+
f"{MODULE}.maybe_get_topic_summary",
237237
new=mocker.AsyncMock(return_value=None),
238238
)
239239
mocker.patch(f"{MODULE}.store_query_results")
@@ -1177,7 +1177,7 @@ async def test_handle_streaming_blocked_returns_sse_consumes_shield_generator(
11771177
return_value=VALID_CONV_ID_NORMALIZED,
11781178
)
11791179
mocker.patch(
1180-
f"{MODULE}.get_topic_summary",
1180+
f"{MODULE}.maybe_get_topic_summary",
11811181
new=mocker.AsyncMock(return_value=None),
11821182
)
11831183
mocker.patch(f"{MODULE}.store_query_results")
@@ -1256,7 +1256,7 @@ async def mock_stream() -> Any:
12561256
return_value=TurnSummary(referenced_documents=[]),
12571257
)
12581258
mocker.patch(
1259-
f"{MODULE}.get_topic_summary",
1259+
f"{MODULE}.maybe_get_topic_summary",
12601260
new=mocker.AsyncMock(return_value=None),
12611261
)
12621262
mocker.patch(f"{MODULE}.store_query_results")
@@ -1343,7 +1343,7 @@ async def mock_stream() -> Any:
13431343
return_value=TurnSummary(referenced_documents=[]),
13441344
)
13451345
mocker.patch(
1346-
f"{MODULE}.get_topic_summary",
1346+
f"{MODULE}.maybe_get_topic_summary",
13471347
new=mocker.AsyncMock(return_value=None),
13481348
)
13491349
mocker.patch(f"{MODULE}.store_query_results")
@@ -1427,7 +1427,7 @@ async def mock_stream() -> Any:
14271427
return_value=(mocker.Mock(), mocker.Mock()),
14281428
)
14291429
mocker.patch(
1430-
f"{MODULE}.get_topic_summary",
1430+
f"{MODULE}.maybe_get_topic_summary",
14311431
new=mocker.AsyncMock(return_value=None),
14321432
)
14331433
mocker.patch(f"{MODULE}.store_query_results")
@@ -1509,7 +1509,7 @@ async def mock_stream() -> Any:
15091509
return_value=TurnSummary(referenced_documents=[]),
15101510
)
15111511
mocker.patch(
1512-
f"{MODULE}.get_topic_summary",
1512+
f"{MODULE}.maybe_get_topic_summary",
15131513
new=mocker.AsyncMock(return_value=None),
15141514
)
15151515
mocker.patch(f"{MODULE}.store_query_results")
@@ -2346,7 +2346,7 @@ async def test_non_streaming_sanitizes_mcp_output_and_model(
23462346
return_value=[],
23472347
)
23482348
mocker.patch(
2349-
f"{MODULE}.get_topic_summary",
2349+
f"{MODULE}.maybe_get_topic_summary",
23502350
new=mocker.AsyncMock(return_value=None),
23512351
)
23522352
mocker.patch(f"{MODULE}.store_query_results")
@@ -2461,7 +2461,7 @@ async def mock_stream() -> Any:
24612461
return_value=TurnSummary(referenced_documents=[]),
24622462
)
24632463
mocker.patch(
2464-
f"{MODULE}.get_topic_summary",
2464+
f"{MODULE}.maybe_get_topic_summary",
24652465
new=mocker.AsyncMock(return_value=None),
24662466
)
24672467
mocker.patch(f"{MODULE}.store_query_results")
@@ -2587,7 +2587,7 @@ async def mock_stream() -> Any:
25872587
return_value=TurnSummary(referenced_documents=[]),
25882588
)
25892589
mocker.patch(
2590-
f"{MODULE}.get_topic_summary",
2590+
f"{MODULE}.maybe_get_topic_summary",
25912591
new=mocker.AsyncMock(return_value=None),
25922592
)
25932593
mocker.patch(f"{MODULE}.store_query_results")
@@ -2683,7 +2683,7 @@ async def mock_stream() -> Any:
26832683
return_value=TurnSummary(referenced_documents=[]),
26842684
)
26852685
mocker.patch(
2686-
f"{MODULE}.get_topic_summary",
2686+
f"{MODULE}.maybe_get_topic_summary",
26872687
new=mocker.AsyncMock(return_value=None),
26882688
)
26892689
mocker.patch(f"{MODULE}.store_query_results")

tests/unit/app/endpoints/test_responses_splunk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _patch_handle_non_streaming_common(
6262
mocker.patch(f"{MODULE}.configuration", config)
6363
mocker.patch(f"{MODULE}.get_available_quotas", return_value={})
6464
mocker.patch(
65-
f"{MODULE}.get_topic_summary",
65+
f"{MODULE}.maybe_get_topic_summary",
6666
new=mocker.AsyncMock(return_value=None),
6767
)
6868
mocker.patch(f"{MODULE}.store_query_results")
@@ -485,7 +485,7 @@ async def test_streaming_shield_blocked(
485485
return_value=VALID_CONV_ID_NORMALIZED,
486486
)
487487
mocker.patch(
488-
f"{MODULE}.get_topic_summary",
488+
f"{MODULE}.maybe_get_topic_summary",
489489
new=mocker.AsyncMock(return_value=None),
490490
)
491491
mocker.patch(f"{MODULE}.store_query_results")
@@ -646,7 +646,7 @@ async def mock_stream() -> Any:
646646
mock_turn_summary.token_usage = mock_token_usage
647647
mocker.patch(f"{MODULE}.build_turn_summary", return_value=mock_turn_summary)
648648
mocker.patch(
649-
f"{MODULE}.get_topic_summary",
649+
f"{MODULE}.maybe_get_topic_summary",
650650
new=mocker.AsyncMock(return_value=None),
651651
)
652652
mocker.patch(f"{MODULE}.store_query_results")

0 commit comments

Comments
 (0)