Skip to content

Commit 0c79c65

Browse files
committed
refactor(responses): reuse shared response helpers
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 6782625 commit 0c79c65

3 files changed

Lines changed: 101 additions & 221 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 65 additions & 221 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ async def _maybe_get_topic_summary(
381381
Returns:
382382
Generated topic summary, or None when topic summaries are disabled.
383383
"""
384-
if not context.generate_topic_summary:
384+
if (
385+
not context.generate_topic_summary
386+
or context.moderation_result.decision != "passed"
387+
):
385388
return None
386389
logger.debug("Generating topic summary for new conversation")
387390
return await get_topic_summary(context.input_text, context.client, api_params.model)
@@ -654,23 +657,11 @@ async def handle_streaming_response(
654657
turn_summary.id = context.moderation_result.moderation_id
655658
turn_summary.llm_response = context.moderation_result.message
656659
generator = shield_violation_generator(api_params, context)
657-
if api_params.store:
658-
await append_turn_items_to_conversation(
659-
client=context.client,
660-
conversation_id=api_params.conversation,
661-
user_input=api_params.input,
662-
llm_output=[context.moderation_result.refusal_response],
663-
)
664-
_queue_responses_splunk_event(
665-
background_tasks=context.background_tasks,
666-
input_text=context.input_text,
667-
response_text=context.moderation_result.message,
668-
conversation_id=normalize_conversation_id(api_params.conversation),
669-
model=api_params.model,
670-
rh_identity_context=context.rh_identity_context,
671-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
672-
sourcetype="responses_shield_blocked",
673-
user_agent=context.user_agent,
660+
await _persist_blocked_response_turn(api_params, context)
661+
_queue_blocked_response_event(
662+
api_params,
663+
context,
664+
context.moderation_result.message,
674665
)
675666
else:
676667
try:
@@ -684,58 +675,13 @@ async def handle_streaming_response(
684675
context=context,
685676
turn_summary=turn_summary,
686677
)
687-
except RuntimeError as e: # library mode wraps 413 into runtime error
688-
if is_context_length_error(str(e)):
689-
_queue_responses_splunk_event(
690-
background_tasks=context.background_tasks,
691-
input_text=context.input_text,
692-
response_text=str(e),
693-
conversation_id=normalize_conversation_id(api_params.conversation),
694-
model=api_params.model,
695-
rh_identity_context=context.rh_identity_context,
696-
inference_time=(
697-
datetime.now(UTC) - context.started_at
698-
).total_seconds(),
699-
sourcetype="responses_error",
700-
fire_and_forget=True,
701-
user_agent=context.user_agent,
702-
)
703-
error_response = PromptTooLongResponse(model=api_params.model)
704-
raise HTTPException(**error_response.model_dump()) from e
705-
raise e
706-
except APIConnectionError as e:
707-
_queue_responses_splunk_event(
708-
background_tasks=context.background_tasks,
709-
input_text=context.input_text,
710-
response_text=str(e),
711-
conversation_id=normalize_conversation_id(api_params.conversation),
712-
model=api_params.model,
713-
rh_identity_context=context.rh_identity_context,
714-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
715-
sourcetype="responses_error",
716-
fire_and_forget=True,
717-
user_agent=context.user_agent,
718-
)
719-
error_response = ServiceUnavailableResponse(
720-
backend_name="Llama Stack",
721-
cause=str(e),
722-
)
723-
raise HTTPException(**error_response.model_dump()) from e
724-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
725-
_queue_responses_splunk_event(
726-
background_tasks=context.background_tasks,
727-
input_text=context.input_text,
728-
response_text=str(e),
729-
conversation_id=normalize_conversation_id(api_params.conversation),
730-
model=api_params.model,
731-
rh_identity_context=context.rh_identity_context,
732-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
733-
sourcetype="responses_error",
734-
fire_and_forget=True,
735-
user_agent=context.user_agent,
736-
)
737-
error_response = handle_known_apistatus_errors(e, api_params.model)
738-
raise HTTPException(**error_response.model_dump()) from e
678+
except (
679+
RuntimeError,
680+
APIConnectionError,
681+
LLSApiStatusError,
682+
OpenAIAPIStatusError,
683+
) as e:
684+
_raise_response_api_http_exception(e, api_params, context)
739685

740686
return StreamingResponse(
741687
generate_response(
@@ -1088,11 +1034,10 @@ async def response_generator(
10881034
)
10891035

10901036
# Explicitly append the turn to conversation if context passed by previous response
1091-
if api_params.store and api_params.previous_response_id and latest_response_object:
1092-
await append_turn_items_to_conversation(
1093-
context.client,
1094-
api_params.conversation,
1095-
api_params.input,
1037+
if latest_response_object:
1038+
await _append_previous_response_turn(
1039+
api_params,
1040+
context,
10961041
latest_response_object.output,
10971042
)
10981043

@@ -1118,45 +1063,25 @@ async def generate_response(
11181063
Yields:
11191064
SSE-formatted strings from the generator
11201065
"""
1121-
user_id, _, skip_userid_check, _ = context.auth
11221066
async for event in generator:
11231067
yield event
11241068

1125-
# Get topic summary for new conversation
1126-
topic_summary = None
1127-
if context.generate_topic_summary:
1128-
logger.debug("Generating topic summary for new conversation")
1129-
topic_summary = await get_topic_summary(
1130-
context.input_text, context.client, api_params.model
1131-
)
1132-
1069+
topic_summary = await _maybe_get_topic_summary(api_params, context)
11331070
completed_at = datetime.now(UTC)
1134-
if api_params.store:
1135-
store_query_results(
1136-
user_id=user_id,
1137-
conversation_id=normalize_conversation_id(api_params.conversation),
1138-
model=api_params.model,
1139-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1140-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1141-
summary=turn_summary,
1142-
query=context.input_text,
1143-
attachments=[],
1144-
skip_userid_check=skip_userid_check,
1145-
topic_summary=topic_summary,
1146-
)
1147-
if context.moderation_result.decision == "passed":
1148-
_queue_responses_splunk_event(
1149-
background_tasks=context.background_tasks,
1150-
input_text=context.input_text,
1151-
response_text=turn_summary.llm_response,
1152-
conversation_id=normalize_conversation_id(api_params.conversation),
1153-
model=api_params.model,
1154-
rh_identity_context=context.rh_identity_context,
1155-
inference_time=(completed_at - context.started_at).total_seconds(),
1156-
sourcetype="responses_completed",
1157-
input_tokens=turn_summary.token_usage.input_tokens,
1158-
output_tokens=turn_summary.token_usage.output_tokens,
1159-
)
1071+
_store_response_query_results(
1072+
api_params,
1073+
context,
1074+
turn_summary,
1075+
completed_at,
1076+
topic_summary,
1077+
)
1078+
_queue_completed_response_event(
1079+
api_params,
1080+
context,
1081+
turn_summary,
1082+
completed_at,
1083+
turn_summary.llm_response,
1084+
)
11601085

11611086

11621087
async def handle_non_streaming_response(
@@ -1173,7 +1098,7 @@ async def handle_non_streaming_response(
11731098
Returns:
11741099
ResponsesResponse with the completed response
11751100
"""
1176-
user_id, _, skip_userid_check, _ = context.auth
1101+
user_id = context.auth[0]
11771102

11781103
# Fork: Get response object (blocked vs normal)
11791104
if context.moderation_result.decision == "blocked":
@@ -1186,24 +1111,8 @@ async def handle_non_streaming_response(
11861111
usage=get_zero_usage(),
11871112
**api_params.echoed_params(configuration.rag_id_mapping),
11881113
)
1189-
if api_params.store:
1190-
await append_turn_items_to_conversation(
1191-
client=context.client,
1192-
conversation_id=api_params.conversation,
1193-
user_input=api_params.input,
1194-
llm_output=[context.moderation_result.refusal_response],
1195-
)
1196-
_queue_responses_splunk_event(
1197-
background_tasks=context.background_tasks,
1198-
input_text=context.input_text,
1199-
response_text=output_text,
1200-
conversation_id=normalize_conversation_id(api_params.conversation),
1201-
model=api_params.model,
1202-
rh_identity_context=context.rh_identity_context,
1203-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1204-
sourcetype="responses_shield_blocked",
1205-
user_agent=context.user_agent,
1206-
)
1114+
await _persist_blocked_response_turn(api_params, context)
1115+
_queue_blocked_response_event(api_params, context, output_text)
12071116
else:
12081117
try:
12091118
api_response = cast(
@@ -1223,79 +1132,26 @@ async def handle_non_streaming_response(
12231132
)
12241133
output_text = extract_text_from_response_items(api_response.output)
12251134
# Explicitly append the turn to conversation if context passed by previous response
1226-
if api_params.store and api_params.previous_response_id:
1227-
await append_turn_items_to_conversation(
1228-
context.client,
1229-
api_params.conversation,
1230-
api_params.input,
1231-
api_response.output,
1232-
)
1233-
1234-
except RuntimeError as e:
1235-
if is_context_length_error(str(e)):
1236-
_queue_responses_splunk_event(
1237-
background_tasks=context.background_tasks,
1238-
input_text=context.input_text,
1239-
response_text=str(e),
1240-
conversation_id=normalize_conversation_id(api_params.conversation),
1241-
model=api_params.model,
1242-
rh_identity_context=context.rh_identity_context,
1243-
inference_time=(
1244-
datetime.now(UTC) - context.started_at
1245-
).total_seconds(),
1246-
sourcetype="responses_error",
1247-
fire_and_forget=True,
1248-
user_agent=context.user_agent,
1249-
)
1250-
error_response = PromptTooLongResponse(model=api_params.model)
1251-
raise HTTPException(**error_response.model_dump()) from e
1252-
raise e
1253-
except APIConnectionError as e:
1254-
_queue_responses_splunk_event(
1255-
background_tasks=context.background_tasks,
1256-
input_text=context.input_text,
1257-
response_text=str(e),
1258-
conversation_id=normalize_conversation_id(api_params.conversation),
1259-
model=api_params.model,
1260-
rh_identity_context=context.rh_identity_context,
1261-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1262-
sourcetype="responses_error",
1263-
fire_and_forget=True,
1264-
user_agent=context.user_agent,
1135+
await _append_previous_response_turn(
1136+
api_params,
1137+
context,
1138+
api_response.output,
12651139
)
1266-
error_response = ServiceUnavailableResponse(
1267-
backend_name="Llama Stack",
1268-
cause=str(e),
1269-
)
1270-
raise HTTPException(**error_response.model_dump()) from e
1271-
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
1272-
_queue_responses_splunk_event(
1273-
background_tasks=context.background_tasks,
1274-
input_text=context.input_text,
1275-
response_text=str(e),
1276-
conversation_id=normalize_conversation_id(api_params.conversation),
1277-
model=api_params.model,
1278-
rh_identity_context=context.rh_identity_context,
1279-
inference_time=(datetime.now(UTC) - context.started_at).total_seconds(),
1280-
sourcetype="responses_error",
1281-
fire_and_forget=True,
1282-
user_agent=context.user_agent,
1283-
)
1284-
error_response = handle_known_apistatus_errors(e, api_params.model)
1285-
raise HTTPException(**error_response.model_dump()) from e
1140+
1141+
except (
1142+
RuntimeError,
1143+
APIConnectionError,
1144+
LLSApiStatusError,
1145+
OpenAIAPIStatusError,
1146+
) as e:
1147+
_raise_response_api_http_exception(e, api_params, context)
12861148

12871149
# Get available quotas
12881150
logger.info("Getting available quotas")
12891151
available_quotas = get_available_quotas(
12901152
quota_limiters=configuration.quota_limiters, user_id=user_id
12911153
)
1292-
# Get topic summary for new conversation
1293-
topic_summary = None
1294-
if context.generate_topic_summary:
1295-
logger.debug("Generating topic summary for new conversation")
1296-
topic_summary = await get_topic_summary(
1297-
context.input_text, context.client, api_params.model
1298-
)
1154+
topic_summary = await _maybe_get_topic_summary(api_params, context)
12991155

13001156
vector_store_ids = extract_vector_store_ids_from_tools(api_params.tools)
13011157
turn_summary = build_turn_summary(
@@ -1312,32 +1168,20 @@ async def handle_non_streaming_response(
13121168
)
13131169
turn_summary.rag_chunks.extend(context.inline_rag_context.rag_chunks)
13141170
completed_at = datetime.now(UTC)
1315-
if context.moderation_result.decision == "passed":
1316-
_queue_responses_splunk_event(
1317-
background_tasks=context.background_tasks,
1318-
input_text=context.input_text,
1319-
response_text=output_text,
1320-
conversation_id=normalize_conversation_id(api_params.conversation),
1321-
model=api_params.model,
1322-
rh_identity_context=context.rh_identity_context,
1323-
inference_time=(completed_at - context.started_at).total_seconds(),
1324-
sourcetype="responses_completed",
1325-
input_tokens=turn_summary.token_usage.input_tokens,
1326-
output_tokens=turn_summary.token_usage.output_tokens,
1327-
)
1328-
if api_params.store:
1329-
store_query_results(
1330-
user_id=user_id,
1331-
conversation_id=normalize_conversation_id(api_params.conversation),
1332-
model=api_params.model,
1333-
started_at=context.started_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1334-
completed_at=completed_at.strftime("%Y-%m-%dT%H:%M:%SZ"),
1335-
summary=turn_summary,
1336-
query=context.input_text,
1337-
attachments=[],
1338-
skip_userid_check=skip_userid_check,
1339-
topic_summary=topic_summary,
1340-
)
1171+
_store_response_query_results(
1172+
api_params,
1173+
context,
1174+
turn_summary,
1175+
completed_at,
1176+
topic_summary,
1177+
)
1178+
_queue_completed_response_event(
1179+
api_params,
1180+
context,
1181+
turn_summary,
1182+
completed_at,
1183+
output_text,
1184+
)
13411185
configured_mcp_labels = {s.name for s in configuration.mcp_servers}
13421186
response_dict = api_response.model_dump(exclude_none=True)
13431187
_sanitize_response_dict(

0 commit comments

Comments
 (0)