Skip to content

Commit 64670e6

Browse files
committed
refactor(responses): reuse shared response helpers
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 109274d commit 64670e6

2 files changed

Lines changed: 65 additions & 220 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 61 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -654,23 +654,11 @@ async def handle_streaming_response(
654654
turn_summary.id = context.moderation_result.moderation_id
655655
turn_summary.llm_response = context.moderation_result.message
656656
generator = shield_violation_generator(api_params, context)
657-
if api_params.store:
658-
await append_turn_items_to_conversation(
659-
client=context.client,
660-
conversation_id=api_params.conversation,
661-
user_input=api_params.input,
662-
llm_output=[context.moderation_result.refusal_response],
663-
)
664-
_queue_responses_splunk_event(
665-
background_tasks=context.background_tasks,
666-
input_text=context.input_text,
667-
response_text=context.moderation_result.message,
668-
conversation_id=normalize_conversation_id(api_params.conversation),
669-
model=api_params.model,
670-
rh_identity_context=context.rh_identity_context,
671-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
672-
sourcetype="responses_shield_blocked",
673-
user_agent=context.user_agent,
657+
await _persist_blocked_response_turn(api_params, context)
658+
_queue_blocked_response_event(
659+
api_params,
660+
context,
661+
context.moderation_result.message,
674662
)
675663
else:
676664
try:
@@ -684,58 +672,13 @@ async def handle_streaming_response(
684672
context=context,
685673
turn_summary=turn_summary,
686674
)
687-
except RuntimeError as e: # library mode wraps 413 into runtime error
688-
if is_context_length_error(str(e)):
689-
_queue_responses_splunk_event(
690-
background_tasks=context.background_tasks,
691-
input_text=context.input_text,
692-
response_text=str(e),
693-
conversation_id=normalize_conversation_id(api_params.conversation),
694-
model=api_params.model,
695-
rh_identity_context=context.rh_identity_context,
696-
inference_time=(
697-
datetime.now(UTC) - context.started_at
698-
).total_seconds(),
699-
sourcetype="responses_error",
700-
fire_and_forget=True,
701-
user_agent=context.user_agent,
702-
)
703-
error_response = PromptTooLongResponse(model=api_params.model)
704-
raise HTTPException(**error_response.model_dump()) from e
705-
raise e
706-
except APIConnectionError as e:
707-
_queue_responses_splunk_event(
708-
background_tasks=context.background_tasks,
709-
input_text=context.input_text,
710-
response_text=str(e),
711-
conversation_id=normalize_conversation_id(api_params.conversation),
712-
model=api_params.model,
713-
rh_identity_context=context.rh_identity_context,
714-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
715-
sourcetype="responses_error",
716-
fire_and_forget=True,
717-
user_agent=context.user_agent,
718-
)
719-
error_response = ServiceUnavailableResponse(
720-
backend_name="Llama Stack",
721-
cause=str(e),
722-
)
723-
raise HTTPException(**error_response.model_dump()) from e
724-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
725-
_queue_responses_splunk_event(
726-
background_tasks=context.background_tasks,
727-
input_text=context.input_text,
728-
response_text=str(e),
729-
conversation_id=normalize_conversation_id(api_params.conversation),
730-
model=api_params.model,
731-
rh_identity_context=context.rh_identity_context,
732-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
733-
sourcetype="responses_error",
734-
fire_and_forget=True,
735-
user_agent=context.user_agent,
736-
)
737-
error_response = handle_known_apistatus_errors(e, api_params.model)
738-
raise HTTPException(**error_response.model_dump()) from e
675+
except (
676+
RuntimeError,
677+
APIConnectionError,
678+
LLSApiStatusError,
679+
OpenAIAPIStatusError,
680+
) as e:
681+
_raise_response_api_http_exception(e, api_params, context)
739682

740683
return StreamingResponse(
741684
generate_response(
@@ -1088,11 +1031,10 @@ async def response_generator(
10881031
)
10891032

10901033
# Explicitly append the turn to conversation if context passed by previous response
1091-
if api_params.store and api_params.previous_response_id and latest_response_object:
1092-
await append_turn_items_to_conversation(
1093-
context.client,
1094-
api_params.conversation,
1095-
api_params.input,
1034+
if latest_response_object:
1035+
await _append_previous_response_turn(
1036+
api_params,
1037+
context,
10961038
latest_response_object.output,
10971039
)
10981040

@@ -1118,45 +1060,25 @@ async def generate_response(
11181060
Yields:
11191061
SSE-formatted strings from the generator
11201062
"""
1121-
user_id, _, skip_userid_check, _ = context.auth
11221063
async for event in generator:
11231064
yield event
11241065

1125-
# Get topic summary for new conversation
1126-
topic_summary = None
1127-
if context.generate_topic_summary:
1128-
logger.debug("Generating topic summary for new conversation")
1129-
topic_summary = await get_topic_summary(
1130-
context.input_text, context.client, api_params.model
1131-
)
1132-
1066+
topic_summary = await _maybe_get_topic_summary(api_params, context)
11331067
completed_at = datetime.now(UTC)
1134-
if api_params.store:
1135-
store_query_results(
1136-
user_id=user_id,
1137-
conversation_id=normalize_conversation_id(api_params.conversation),
1138-
model=api_params.model,
1139-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1140-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1141-
summary=turn_summary,
1142-
query=context.input_text,
1143-
attachments=[],
1144-
skip_userid_check=skip_userid_check,
1145-
topic_summary=topic_summary,
1146-
)
1147-
if context.moderation_result.decision == "passed":
1148-
_queue_responses_splunk_event(
1149-
background_tasks=context.background_tasks,
1150-
input_text=context.input_text,
1151-
response_text=turn_summary.llm_response,
1152-
conversation_id=normalize_conversation_id(api_params.conversation),
1153-
model=api_params.model,
1154-
rh_identity_context=context.rh_identity_context,
1155-
inference_time=(completed_at - context.started_at).total_seconds(),
1156-
sourcetype="responses_completed",
1157-
input_tokens=turn_summary.token_usage.input_tokens,
1158-
output_tokens=turn_summary.token_usage.output_tokens,
1159-
)
1068+
_store_response_query_results(
1069+
api_params,
1070+
context,
1071+
turn_summary,
1072+
completed_at,
1073+
topic_summary,
1074+
)
1075+
_queue_completed_response_event(
1076+
api_params,
1077+
context,
1078+
turn_summary,
1079+
completed_at,
1080+
turn_summary.llm_response,
1081+
)
11601082

11611083

11621084
async def handle_non_streaming_response(
@@ -1173,7 +1095,7 @@ async def handle_non_streaming_response(
11731095
Returns:
11741096
ResponsesResponse with the completed response
11751097
"""
1176-
user_id, _, skip_userid_check, _ = context.auth
1098+
user_id = context.auth[0]
11771099

11781100
# Fork: Get response object (blocked vs normal)
11791101
if context.moderation_result.decision == "blocked":
@@ -1186,24 +1108,8 @@ async def handle_non_streaming_response(
11861108
usage=get_zero_usage(),
11871109
**api_params.echoed_params(configuration.rag_id_mapping),
11881110
)
1189-
if api_params.store:
1190-
await append_turn_items_to_conversation(
1191-
client=context.client,
1192-
conversation_id=api_params.conversation,
1193-
user_input=api_params.input,
1194-
llm_output=[context.moderation_result.refusal_response],
1195-
)
1196-
_queue_responses_splunk_event(
1197-
background_tasks=context.background_tasks,
1198-
input_text=context.input_text,
1199-
response_text=output_text,
1200-
conversation_id=normalize_conversation_id(api_params.conversation),
1201-
model=api_params.model,
1202-
rh_identity_context=context.rh_identity_context,
1203-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1204-
sourcetype="responses_shield_blocked",
1205-
user_agent=context.user_agent,
1206-
)
1111+
await _persist_blocked_response_turn(api_params, context)
1112+
_queue_blocked_response_event(api_params, context, output_text)
12071113
else:
12081114
try:
12091115
api_response = cast(
@@ -1223,79 +1129,26 @@ async def handle_non_streaming_response(
12231129
)
12241130
output_text = extract_text_from_response_items(api_response.output)
12251131
# Explicitly append the turn to conversation if context passed by previous response
1226-
if api_params.store and api_params.previous_response_id:
1227-
await append_turn_items_to_conversation(
1228-
context.client,
1229-
api_params.conversation,
1230-
api_params.input,
1231-
api_response.output,
1232-
)
1233-
1234-
except RuntimeError as e:
1235-
if is_context_length_error(str(e)):
1236-
_queue_responses_splunk_event(
1237-
background_tasks=context.background_tasks,
1238-
input_text=context.input_text,
1239-
response_text=str(e),
1240-
conversation_id=normalize_conversation_id(api_params.conversation),
1241-
model=api_params.model,
1242-
rh_identity_context=context.rh_identity_context,
1243-
inference_time=(
1244-
datetime.now(UTC) - context.started_at
1245-
).total_seconds(),
1246-
sourcetype="responses_error",
1247-
fire_and_forget=True,
1248-
user_agent=context.user_agent,
1249-
)
1250-
error_response = PromptTooLongResponse(model=api_params.model)
1251-
raise HTTPException(**error_response.model_dump()) from e
1252-
raise e
1253-
except APIConnectionError as e:
1254-
_queue_responses_splunk_event(
1255-
background_tasks=context.background_tasks,
1256-
input_text=context.input_text,
1257-
response_text=str(e),
1258-
conversation_id=normalize_conversation_id(api_params.conversation),
1259-
model=api_params.model,
1260-
rh_identity_context=context.rh_identity_context,
1261-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1262-
sourcetype="responses_error",
1263-
fire_and_forget=True,
1264-
user_agent=context.user_agent,
1132+
await _append_previous_response_turn(
1133+
api_params,
1134+
context,
1135+
api_response.output,
12651136
)
1266-
error_response = ServiceUnavailableResponse(
1267-
backend_name="Llama Stack",
1268-
cause=str(e),
1269-
)
1270-
raise HTTPException(**error_response.model_dump()) from e
1271-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
1272-
_queue_responses_splunk_event(
1273-
background_tasks=context.background_tasks,
1274-
input_text=context.input_text,
1275-
response_text=str(e),
1276-
conversation_id=normalize_conversation_id(api_params.conversation),
1277-
model=api_params.model,
1278-
rh_identity_context=context.rh_identity_context,
1279-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1280-
sourcetype="responses_error",
1281-
fire_and_forget=True,
1282-
user_agent=context.user_agent,
1283-
)
1284-
error_response = handle_known_apistatus_errors(e, api_params.model)
1285-
raise HTTPException(**error_response.model_dump()) from e
1137+
1138+
except (
1139+
RuntimeError,
1140+
APIConnectionError,
1141+
LLSApiStatusError,
1142+
OpenAIAPIStatusError,
1143+
) as e:
1144+
_raise_response_api_http_exception(e, api_params, context)
12861145

12871146
# Get available quotas
12881147
logger.info("Getting available quotas")
12891148
available_quotas = get_available_quotas(
12901149
quota_limiters=configuration.quota_limiters, user_id=user_id
12911150
)
1292-
# Get topic summary for new conversation
1293-
topic_summary = None
1294-
if context.generate_topic_summary:
1295-
logger.debug("Generating topic summary for new conversation")
1296-
topic_summary = await get_topic_summary(
1297-
context.input_text, context.client, api_params.model
1298-
)
1151+
topic_summary = await _maybe_get_topic_summary(api_params, context)
12991152

13001153
vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools)
13011154
turn_summary = build_turn_summary(
@@ -1312,32 +1165,20 @@ async def handle_non_streaming_response(
13121165
)
13131166
turn_summary.rag_chunks.extend(context.inline_rag_context.rag_chunks)
13141167
completed_at = datetime.now(UTC)
1315-
if context.moderation_result.decision == "passed":
1316-
_queue_responses_splunk_event(
1317-
background_tasks=context.background_tasks,
1318-
input_text=context.input_text,
1319-
response_text=output_text,
1320-
conversation_id=normalize_conversation_id(api_params.conversation),
1321-
model=api_params.model,
1322-
rh_identity_context=context.rh_identity_context,
1323-
inference_time=(completed_at - context.started_at).total_seconds(),
1324-
sourcetype="responses_completed",
1325-
input_tokens=turn_summary.token_usage.input_tokens,
1326-
output_tokens=turn_summary.token_usage.output_tokens,
1327-
)
1328-
if api_params.store:
1329-
store_query_results(
1330-
user_id=user_id,
1331-
conversation_id=normalize_conversation_id(api_params.conversation),
1332-
model=api_params.model,
1333-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1334-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1335-
summary=turn_summary,
1336-
query=context.input_text,
1337-
attachments=[],
1338-
skip_userid_check=skip_userid_check,
1339-
topic_summary=topic_summary,
1340-
)
1168+
_queue_completed_response_event(
1169+
api_params,
1170+
context,
1171+
turn_summary,
1172+
completed_at,
1173+
output_text,
1174+
)
1175+
_store_response_query_results(
1176+
api_params,
1177+
context,
1178+
turn_summary,
1179+
completed_at,
1180+
topic_summary,
1181+
)
13411182
configured_mcp_labels = {s.name for s in configuration.mcp_servers}
13421183
response_dict = api_response.model_dump(exclude_none=True)
13431184
_sanitize_response_dict(

tests/unit/app/endpoints/test_responses_splunk.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ async def test_non_streaming_success(
432432
inline_rag_context=RAGContext(),
433433
background_tasks=mock_background_tasks,
434434
rh_identity_context=("org1", "sys1"),
435+
user_agent="test-agent/1.0",
435436
)
436437
await handle_non_streaming_response(
437438
original_request=request,
@@ -445,6 +446,7 @@ async def test_non_streaming_success(
445446
assert call_kwargs["sourcetype"] == "responses_completed"
446447
assert call_kwargs["input_tokens"] == 100
447448
assert call_kwargs["output_tokens"] == 50
449+
assert call_kwargs["user_agent"] == "test-agent/1.0"
448450

449451
# -- Streaming paths ----------------------------------------------------
450452

@@ -660,6 +662,7 @@ async def mock_stream() -> Any:
660662
inline_rag_context=RAGContext(),
661663
background_tasks=mock_background_tasks,
662664
rh_identity_context=("org1", "sys1"),
665+
user_agent="test-agent/1.0",
663666
)
664667
response = await handle_streaming_response(
665668
original_request=request,
@@ -678,6 +681,7 @@ async def mock_stream() -> Any:
678681
assert call_kwargs["sourcetype"] == "responses_completed"
679682
assert call_kwargs["input_tokens"] == 100
680683
assert call_kwargs["output_tokens"] == 50
684+
assert call_kwargs["user_agent"] == "test-agent/1.0"
681685

682686
# -- Splunk disabled (no BackgroundTasks) --------------------------------
683687

0 commit comments

Comments
 (0)