Skip to content

Commit 26ce9b9

Browse files
committed
refactor(responses): reuse shared response helpers
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 6782625 commit 26ce9b9

2 files changed

Lines changed: 67 additions & 221 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 63 additions & 221 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from llama_stack_api import (
1414
OpenAIResponseObject,
1515
OpenAIResponseObjectStream,
16+
OpenAIResponseOutput,
1617
)
1718
from llama_stack_api import (
1819
OpenAIResponseObjectStreamResponseOutputItemAdded as OutputItemAddedChunk,
@@ -350,7 +351,7 @@ def _queue_blocked_response_event(
350351
async def _append_previous_response_turn(
351352
api_params: ResponsesApiParams,
352353
context: ResponsesContext,
353-
output: Sequence[Any],
354+
output: Sequence[OpenAIResponseOutput],
354355
) -> None:
355356
"""Append response output when continuing from a previous response id.
356357
@@ -654,23 +655,11 @@ async def handle_streaming_response(
654655
turn_summary.id = context.moderation_result.moderation_id
655656
turn_summary.llm_response = context.moderation_result.message
656657
generator = shield_violation_generator(api_params, context)
657-
if api_params.store:
658-
await append_turn_items_to_conversation(
659-
client=context.client,
660-
conversation_id=api_params.conversation,
661-
user_input=api_params.input,
662-
llm_output=[context.moderation_result.refusal_response],
663-
)
664-
_queue_responses_splunk_event(
665-
background_tasks=context.background_tasks,
666-
input_text=context.input_text,
667-
response_text=context.moderation_result.message,
668-
conversation_id=normalize_conversation_id(api_params.conversation),
669-
model=api_params.model,
670-
rh_identity_context=context.rh_identity_context,
671-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
672-
sourcetype="responses_shield_blocked",
673-
user_agent=context.user_agent,
658+
await _persist_blocked_response_turn(api_params, context)
659+
_queue_blocked_response_event(
660+
api_params,
661+
context,
662+
context.moderation_result.message,
674663
)
675664
else:
676665
try:
@@ -684,58 +673,13 @@ async def handle_streaming_response(
684673
context=context,
685674
turn_summary=turn_summary,
686675
)
687-
except RuntimeError as e: # library mode wraps 413 into runtime error
688-
if is_context_length_error(str(e)):
689-
_queue_responses_splunk_event(
690-
background_tasks=context.background_tasks,
691-
input_text=context.input_text,
692-
response_text=str(e),
693-
conversation_id=normalize_conversation_id(api_params.conversation),
694-
model=api_params.model,
695-
rh_identity_context=context.rh_identity_context,
696-
inference_time=(
697-
datetime.now(UTC) - context.started_at
698-
).total_seconds(),
699-
sourcetype="responses_error",
700-
fire_and_forget=True,
701-
user_agent=context.user_agent,
702-
)
703-
error_response = PromptTooLongResponse(model=api_params.model)
704-
raise HTTPException(**error_response.model_dump()) from e
705-
raise e
706-
except APIConnectionError as e:
707-
_queue_responses_splunk_event(
708-
background_tasks=context.background_tasks,
709-
input_text=context.input_text,
710-
response_text=str(e),
711-
conversation_id=normalize_conversation_id(api_params.conversation),
712-
model=api_params.model,
713-
rh_identity_context=context.rh_identity_context,
714-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
715-
sourcetype="responses_error",
716-
fire_and_forget=True,
717-
user_agent=context.user_agent,
718-
)
719-
error_response = ServiceUnavailableResponse(
720-
backend_name="Llama Stack",
721-
cause=str(e),
722-
)
723-
raise HTTPException(**error_response.model_dump()) from e
724-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
725-
_queue_responses_splunk_event(
726-
background_tasks=context.background_tasks,
727-
input_text=context.input_text,
728-
response_text=str(e),
729-
conversation_id=normalize_conversation_id(api_params.conversation),
730-
model=api_params.model,
731-
rh_identity_context=context.rh_identity_context,
732-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
733-
sourcetype="responses_error",
734-
fire_and_forget=True,
735-
user_agent=context.user_agent,
736-
)
737-
error_response = handle_known_apistatus_errors(e, api_params.model)
738-
raise HTTPException(**error_response.model_dump()) from e
676+
except (
677+
RuntimeError,
678+
APIConnectionError,
679+
LLSApiStatusError,
680+
OpenAIAPIStatusError,
681+
) as e:
682+
_raise_response_api_http_exception(e, api_params, context)
739683

740684
return StreamingResponse(
741685
generate_response(
@@ -1088,11 +1032,10 @@ async def response_generator(
10881032
)
10891033

10901034
# Explicitly append the turn to conversation if context passed by previous response
1091-
if api_params.store and api_params.previous_response_id and latest_response_object:
1092-
await append_turn_items_to_conversation(
1093-
context.client,
1094-
api_params.conversation,
1095-
api_params.input,
1035+
if latest_response_object:
1036+
await _append_previous_response_turn(
1037+
api_params,
1038+
context,
10961039
latest_response_object.output,
10971040
)
10981041

@@ -1118,45 +1061,25 @@ async def generate_response(
11181061
Yields:
11191062
SSE-formatted strings from the generator
11201063
"""
1121-
user_id, _, skip_userid_check, _ = context.auth
11221064
async for event in generator:
11231065
yield event
11241066

1125-
# Get topic summary for new conversation
1126-
topic_summary = None
1127-
if context.generate_topic_summary:
1128-
logger.debug("Generating topic summary for new conversation")
1129-
topic_summary = await get_topic_summary(
1130-
context.input_text, context.client, api_params.model
1131-
)
1132-
1067+
topic_summary = await _maybe_get_topic_summary(api_params, context)
11331068
completed_at = datetime.now(UTC)
1134-
if api_params.store:
1135-
store_query_results(
1136-
user_id=user_id,
1137-
conversation_id=normalize_conversation_id(api_params.conversation),
1138-
model=api_params.model,
1139-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1140-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1141-
summary=turn_summary,
1142-
query=context.input_text,
1143-
attachments=[],
1144-
skip_userid_check=skip_userid_check,
1145-
topic_summary=topic_summary,
1146-
)
1147-
if context.moderation_result.decision == "passed":
1148-
_queue_responses_splunk_event(
1149-
background_tasks=context.background_tasks,
1150-
input_text=context.input_text,
1151-
response_text=turn_summary.llm_response,
1152-
conversation_id=normalize_conversation_id(api_params.conversation),
1153-
model=api_params.model,
1154-
rh_identity_context=context.rh_identity_context,
1155-
inference_time=(completed_at - context.started_at).total_seconds(),
1156-
sourcetype="responses_completed",
1157-
input_tokens=turn_summary.token_usage.input_tokens,
1158-
output_tokens=turn_summary.token_usage.output_tokens,
1159-
)
1069+
_store_response_query_results(
1070+
api_params,
1071+
context,
1072+
turn_summary,
1073+
completed_at,
1074+
topic_summary,
1075+
)
1076+
_queue_completed_response_event(
1077+
api_params,
1078+
context,
1079+
turn_summary,
1080+
completed_at,
1081+
turn_summary.llm_response,
1082+
)
11601083

11611084

11621085
async def handle_non_streaming_response(
@@ -1173,7 +1096,7 @@ async def handle_non_streaming_response(
11731096
Returns:
11741097
ResponsesResponse with the completed response
11751098
"""
1176-
user_id, _, skip_userid_check, _ = context.auth
1099+
user_id = context.auth[0]
11771100

11781101
# Fork: Get response object (blocked vs normal)
11791102
if context.moderation_result.decision == "blocked":
@@ -1186,24 +1109,8 @@ async def handle_non_streaming_response(
11861109
usage=get_zero_usage(),
11871110
**api_params.echoed_params(configuration.rag_id_mapping),
11881111
)
1189-
if api_params.store:
1190-
await append_turn_items_to_conversation(
1191-
client=context.client,
1192-
conversation_id=api_params.conversation,
1193-
user_input=api_params.input,
1194-
llm_output=[context.moderation_result.refusal_response],
1195-
)
1196-
_queue_responses_splunk_event(
1197-
background_tasks=context.background_tasks,
1198-
input_text=context.input_text,
1199-
response_text=output_text,
1200-
conversation_id=normalize_conversation_id(api_params.conversation),
1201-
model=api_params.model,
1202-
rh_identity_context=context.rh_identity_context,
1203-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1204-
sourcetype="responses_shield_blocked",
1205-
user_agent=context.user_agent,
1206-
)
1112+
await _persist_blocked_response_turn(api_params, context)
1113+
_queue_blocked_response_event(api_params, context, output_text)
12071114
else:
12081115
try:
12091116
api_response = cast(
@@ -1223,79 +1130,26 @@ async def handle_non_streaming_response(
12231130
)
12241131
output_text = extract_text_from_response_items(api_response.output)
12251132
# Explicitly append the turn to conversation if context passed by previous response
1226-
if api_params.store and api_params.previous_response_id:
1227-
await append_turn_items_to_conversation(
1228-
context.client,
1229-
api_params.conversation,
1230-
api_params.input,
1231-
api_response.output,
1232-
)
1233-
1234-
except RuntimeError as e:
1235-
if is_context_length_error(str(e)):
1236-
_queue_responses_splunk_event(
1237-
background_tasks=context.background_tasks,
1238-
input_text=context.input_text,
1239-
response_text=str(e),
1240-
conversation_id=normalize_conversation_id(api_params.conversation),
1241-
model=api_params.model,
1242-
rh_identity_context=context.rh_identity_context,
1243-
inference_time=(
1244-
datetime.now(UTC) - context.started_at
1245-
).total_seconds(),
1246-
sourcetype="responses_error",
1247-
fire_and_forget=True,
1248-
user_agent=context.user_agent,
1249-
)
1250-
error_response = PromptTooLongResponse(model=api_params.model)
1251-
raise HTTPException(**error_response.model_dump()) from e
1252-
raise e
1253-
except APIConnectionError as e:
1254-
_queue_responses_splunk_event(
1255-
background_tasks=context.background_tasks,
1256-
input_text=context.input_text,
1257-
response_text=str(e),
1258-
conversation_id=normalize_conversation_id(api_params.conversation),
1259-
model=api_params.model,
1260-
rh_identity_context=context.rh_identity_context,
1261-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1262-
sourcetype="responses_error",
1263-
fire_and_forget=True,
1264-
user_agent=context.user_agent,
1133+
await _append_previous_response_turn(
1134+
api_params,
1135+
context,
1136+
api_response.output,
12651137
)
1266-
error_response = ServiceUnavailableResponse(
1267-
backend_name="Llama Stack",
1268-
cause=str(e),
1269-
)
1270-
raise HTTPException(**error_response.model_dump()) from e
1271-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
1272-
_queue_responses_splunk_event(
1273-
background_tasks=context.background_tasks,
1274-
input_text=context.input_text,
1275-
response_text=str(e),
1276-
conversation_id=normalize_conversation_id(api_params.conversation),
1277-
model=api_params.model,
1278-
rh_identity_context=context.rh_identity_context,
1279-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1280-
sourcetype="responses_error",
1281-
fire_and_forget=True,
1282-
user_agent=context.user_agent,
1283-
)
1284-
error_response = handle_known_apistatus_errors(e, api_params.model)
1285-
raise HTTPException(**error_response.model_dump()) from e
1138+
1139+
except (
1140+
RuntimeError,
1141+
APIConnectionError,
1142+
LLSApiStatusError,
1143+
OpenAIAPIStatusError,
1144+
) as e:
1145+
_raise_response_api_http_exception(e, api_params, context)
12861146

12871147
# Get available quotas
12881148
logger.info("Getting available quotas")
12891149
available_quotas = get_available_quotas(
12901150
quota_limiters=configuration.quota_limiters, user_id=user_id
12911151
)
1292-
# Get topic summary for new conversation
1293-
topic_summary = None
1294-
if context.generate_topic_summary:
1295-
logger.debug("Generating topic summary for new conversation")
1296-
topic_summary = await get_topic_summary(
1297-
context.input_text, context.client, api_params.model
1298-
)
1152+
topic_summary = await _maybe_get_topic_summary(api_params, context)
12991153

13001154
vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools)
13011155
turn_summary = build_turn_summary(
@@ -1312,32 +1166,20 @@ async def handle_non_streaming_response(
13121166
)
13131167
turn_summary.rag_chunks.extend(context.inline_rag_context.rag_chunks)
13141168
completed_at = datetime.now(UTC)
1315-
if context.moderation_result.decision == "passed":
1316-
_queue_responses_splunk_event(
1317-
background_tasks=context.background_tasks,
1318-
input_text=context.input_text,
1319-
response_text=output_text,
1320-
conversation_id=normalize_conversation_id(api_params.conversation),
1321-
model=api_params.model,
1322-
rh_identity_context=context.rh_identity_context,
1323-
inference_time=(completed_at - context.started_at).total_seconds(),
1324-
sourcetype="responses_completed",
1325-
input_tokens=turn_summary.token_usage.input_tokens,
1326-
output_tokens=turn_summary.token_usage.output_tokens,
1327-
)
1328-
if api_params.store:
1329-
store_query_results(
1330-
user_id=user_id,
1331-
conversation_id=normalize_conversation_id(api_params.conversation),
1332-
model=api_params.model,
1333-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1334-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1335-
summary=turn_summary,
1336-
query=context.input_text,
1337-
attachments=[],
1338-
skip_userid_check=skip_userid_check,
1339-
topic_summary=topic_summary,
1340-
)
1169+
_store_response_query_results(
1170+
api_params,
1171+
context,
1172+
turn_summary,
1173+
completed_at,
1174+
topic_summary,
1175+
)
1176+
_queue_completed_response_event(
1177+
api_params,
1178+
context,
1179+
turn_summary,
1180+
completed_at,
1181+
output_text,
1182+
)
13411183
configured_mcp_labels = {s.name for s in configuration.mcp_servers}
13421184
response_dict = api_response.model_dump(exclude_none=True)
13431185
_sanitize_response_dict(

tests/unit/app/endpoints/test_responses_splunk.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -432,6 +432,7 @@ async def test_non_streaming_success(
432432
inline_rag_context=RAGContext(),
433433
background_tasks=mock_background_tasks,
434434
rh_identity_context=("org1", "sys1"),
435+
user_agent="test-agent/1.0",
435436
)
436437
await handle_non_streaming_response(
437438
original_request=request,
@@ -445,6 +446,7 @@ async def test_non_streaming_success(
445446
assert call_kwargs["sourcetype"] == "responses_completed"
446447
assert call_kwargs["input_tokens"] == 100
447448
assert call_kwargs["output_tokens"] == 50
449+
assert call_kwargs["user_agent"] == "test-agent/1.0"
448450

449451
# -- Streaming paths ----------------------------------------------------
450452

@@ -660,6 +662,7 @@ async def mock_stream() -> Any:
660662
inline_rag_context=RAGContext(),
661663
background_tasks=mock_background_tasks,
662664
rh_identity_context=("org1", "sys1"),
665+
user_agent="test-agent/1.0",
663666
)
664667
response = await handle_streaming_response(
665668
original_request=request,
@@ -678,6 +681,7 @@ async def mock_stream() -> Any:
678681
assert call_kwargs["sourcetype"] == "responses_completed"
679682
assert call_kwargs["input_tokens"] == 100
680683
assert call_kwargs["output_tokens"] == 50
684+
assert call_kwargs["user_agent"] == "test-agent/1.0"
681685

682686
# -- Splunk disabled (no BackgroundTasks) --------------------------------
683687

0 commit comments

Comments
 (0)