Skip to content

Commit 2b79abf

Browse files
committed
feat: add Responses API inference metrics
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 43c8f4c commit 2b79abf

8 files changed

Lines changed: 291 additions & 66 deletions

File tree

docs/demos/lcore/weak_points_for_ai/ex3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
):
1010
self._task_state = TaskState.auth_required
1111
self._task_status_message = event.status.message
12-
elif (
13-
event.status.state == TaskState.input_required
14-
and self._task_state not in (TaskState.failed, TaskState.auth_required)
12+
elif event.status.state == TaskState.input_required and self._task_state not in (
13+
TaskState.failed,
14+
TaskState.auth_required,
1515
):
1616
self._task_state = TaskState.input_required
1717
self._task_status_message = event.status.message

docs/demos/lcore/weak_points_for_ai/ex9.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Pydantic model utilization
22

3+
34
class ShieldModerationBlocked(BaseModel):
45
"""Shield moderation blocked the content; refusal details are present."""
56

docs/demos/lcore/weak_points_for_ai/exA.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Dynamic dispatch: functional style
22

3+
34
@singledispatch
45
def function(arg: Any) -> None:
56
print("Original function with argument", arg, "that has type", type(arg))
@@ -26,4 +27,3 @@ def _(arg: None) -> None:
2627
function(("foo", "bar", "baz"))
2728
function(1.4142)
2829
function(None)
29-

docs/demos/lcore/weak_points_for_ai/exB.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Pydantic model utilization
22

3+
34
class TranscriptMetadata(BaseModel):
45
"""Metadata for a transcript entry."""
56

@@ -31,4 +32,3 @@ def create_transcript_metadata(
3132
conversation_id=conversation_id,
3233
timestamp=datetime.now(UTC).isoformat(),
3334
)
34-

src/app/endpoints/responses.py

Lines changed: 174 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import json
7+
import time
78
from collections.abc import AsyncIterator
89
from datetime import UTC, datetime
910
from typing import Annotated, Any, Final, Optional, cast
@@ -38,6 +39,7 @@
3839
from configuration import configuration
3940
from constants import ENDPOINT_PATH_RESPONSES, SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER
4041
from log import get_logger
42+
from metrics import recording
4143
from models.api.responses import (
4244
UNAUTHORIZED_OPENAPI_EXAMPLES_WITH_MCP_OAUTH,
4345
ConflictResponse,
@@ -404,6 +406,34 @@ async def responses_endpoint_handler(
404406
)
405407

406408

409+
def _record_response_inference_result(
    model_id: str,
    endpoint_path: str,
    result: str,
    duration: float,
    record_failure: bool = False,
) -> None:
    """Emit inference metrics for a single Responses API call.

    The composite model identifier is split into its provider and model
    parts, which are then used as labels. The inference duration is always
    recorded; the LLM failure counter is only touched on request.

    Args:
        model_id: Composite model identifier in ``provider/model`` format.
        endpoint_path: API endpoint path used as a metric label.
        result: Outcome label for the duration metric, such as ``success``
            or ``failure``.
        duration: Time spent on the inference call, in seconds.
        record_failure: When True, the LLM failure counter is incremented
            before the duration is recorded.
    """
    provider_name, model_name = extract_provider_and_model_from_model_id(model_id)

    # The failure counter is opt-in so that callers decide independently of
    # the ``result`` label whether this call counts as an LLM failure.
    if record_failure:
        recording.record_llm_failure(provider_name, model_name, endpoint_path)

    recording.record_llm_inference_duration(
        provider_name,
        model_name,
        endpoint_path,
        result,
        duration,
    )
435+
436+
407437
async def handle_streaming_response(
408438
original_request: ResponsesRequest,
409439
api_params: ResponsesApiParams,
@@ -444,6 +474,7 @@ async def handle_streaming_response(
444474
user_agent=context.user_agent,
445475
)
446476
else:
477+
inference_start_time = time.monotonic()
447478
try:
448479
response = await context.client.responses.create(
449480
**api_params.model_dump(exclude_none=True)
@@ -454,9 +485,17 @@ async def handle_streaming_response(
454485
api_params=api_params,
455486
context=context,
456487
turn_summary=turn_summary,
488+
inference_start_time=inference_start_time,
457489
)
458490
except RuntimeError as e: # library mode wraps 413 into runtime error
459491
if is_context_length_error(str(e)):
492+
_record_response_inference_result(
493+
api_params.model,
494+
context.endpoint_path,
495+
"failure",
496+
time.monotonic() - inference_start_time,
497+
record_failure=True,
498+
)
460499
_queue_responses_splunk_event(
461500
background_tasks=context.background_tasks,
462501
input_text=context.input_text,
@@ -475,6 +514,13 @@ async def handle_streaming_response(
475514
raise HTTPException(**error_response.model_dump()) from e
476515
raise e
477516
except APIConnectionError as e:
517+
_record_response_inference_result(
518+
api_params.model,
519+
context.endpoint_path,
520+
"failure",
521+
time.monotonic() - inference_start_time,
522+
record_failure=True,
523+
)
478524
_queue_responses_splunk_event(
479525
background_tasks=context.background_tasks,
480526
input_text=context.input_text,
@@ -493,6 +539,13 @@ async def handle_streaming_response(
493539
)
494540
raise HTTPException(**error_response.model_dump()) from e
495541
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
542+
_record_response_inference_result(
543+
api_params.model,
544+
context.endpoint_path,
545+
"failure",
546+
time.monotonic() - inference_start_time,
547+
record_failure=True,
548+
)
496549
_queue_responses_splunk_event(
497550
background_tasks=context.background_tasks,
498551
input_text=context.input_text,
@@ -759,6 +812,7 @@ async def response_generator(
759812
api_params: ResponsesApiParams,
760813
context: ResponsesContext,
761814
turn_summary: TurnSummary,
815+
inference_start_time: float,
762816
) -> AsyncIterator[str]:
763817
"""Generate SSE-formatted streaming response with LCORE-enriched events.
764818
@@ -768,6 +822,7 @@ async def response_generator(
768822
api_params: ResponsesApiParams
769823
context: Responses context
770824
turn_summary: TurnSummary to populate during streaming
825+
inference_start_time: Monotonic timestamp taken before the inference call.
771826
Yields:
772827
SSE-formatted strings for streaming events, ending with [DONE]
773828
"""
@@ -778,76 +833,102 @@ async def response_generator(
778833
configured_mcp_labels = {s.name for s in configuration.mcp_servers}
779834
# Track output indices of server-deployed MCP calls to filter their events
780835
server_mcp_output_indices: set[int] = set()
836+
inference_metric_recorded = False
781837

782-
async for chunk in stream:
783-
logger.debug("Processing streaming chunk, type: %s", chunk.type)
838+
try:
839+
async for chunk in stream:
840+
logger.debug("Processing streaming chunk, type: %s", chunk.type)
784841

785-
# Filter out streaming events for server-deployed MCP tools.
786-
# These are handled internally by LCS and should not be forwarded
787-
# to clients that don't understand the mcp_call item type.
788-
if _should_filter_mcp_chunk(
789-
chunk, configured_mcp_labels, server_mcp_output_indices
790-
):
791-
continue
842+
# Filter out streaming events for server-deployed MCP tools.
843+
# These are handled internally by LCS and should not be forwarded
844+
# to clients that don't understand the mcp_call item type.
845+
if _should_filter_mcp_chunk(
846+
chunk, configured_mcp_labels, server_mcp_output_indices
847+
):
848+
continue
792849

793-
chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True)
850+
chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True)
794851

795-
# Create own sequence number for chunks to maintain order
796-
chunk_dict["sequence_number"] = sequence_number
797-
sequence_number += 1
852+
# Create own sequence number for chunks to maintain order
853+
chunk_dict["sequence_number"] = sequence_number
854+
sequence_number += 1
798855

799-
if "response" in chunk_dict:
800-
chunk_dict["response"]["conversation"] = normalize_conversation_id(
801-
api_params.conversation
802-
)
803-
_sanitize_response_dict(
804-
chunk_dict["response"],
805-
configured_mcp_labels,
806-
original_request,
807-
)
808-
tools = chunk_dict["response"].get("tools")
809-
if tools is not None:
810-
chunk_dict["response"]["tools"] = (
811-
translate_vector_store_ids_to_user_facing(
812-
tools,
813-
configuration.rag_id_mapping,
856+
if "response" in chunk_dict:
857+
chunk_dict["response"]["conversation"] = normalize_conversation_id(
858+
api_params.conversation
859+
)
860+
_sanitize_response_dict(
861+
chunk_dict["response"],
862+
configured_mcp_labels,
863+
original_request,
864+
)
865+
tools = chunk_dict["response"].get("tools")
866+
if tools is not None:
867+
chunk_dict["response"]["tools"] = (
868+
translate_vector_store_ids_to_user_facing(
869+
tools,
870+
configuration.rag_id_mapping,
871+
)
814872
)
873+
# Intermediate response - no quota consumption and text yet
874+
if chunk.type == "response.in_progress":
875+
chunk_dict["response"]["available_quotas"] = {}
876+
chunk_dict["response"]["output_text"] = ""
877+
878+
# Handle completion, incomplete, and failed events
879+
if chunk.type in (
880+
"response.completed",
881+
"response.incomplete",
882+
"response.failed",
883+
):
884+
latest_response_object = cast(
885+
OpenAIResponseObject, cast(Any, chunk).response
815886
)
816-
# Intermediate response - no quota consumption and text yet
817-
if chunk.type == "response.in_progress":
818-
chunk_dict["response"]["available_quotas"] = {}
819-
chunk_dict["response"]["output_text"] = ""
820-
821-
# Handle completion, incomplete, and failed events - only quota handling here
822-
if chunk.type in (
823-
"response.completed",
824-
"response.incomplete",
825-
"response.failed",
826-
):
827-
latest_response_object = cast(
828-
OpenAIResponseObject, cast(Any, chunk).response
829-
)
830887

831-
# Extract and consume tokens if any were used
832-
turn_summary.token_usage = extract_token_usage(
833-
latest_response_object.usage, api_params.model, context.endpoint_path
834-
)
835-
consume_query_tokens(
836-
user_id=context.auth[0],
837-
model_id=api_params.model,
838-
token_usage=turn_summary.token_usage,
839-
)
888+
# Extract and consume tokens if any were used
889+
turn_summary.token_usage = extract_token_usage(
890+
latest_response_object.usage,
891+
api_params.model,
892+
context.endpoint_path,
893+
)
894+
consume_query_tokens(
895+
user_id=context.auth[0],
896+
model_id=api_params.model,
897+
token_usage=turn_summary.token_usage,
898+
)
840899

841-
# Get available quotas after token consumption
842-
chunk_dict["response"]["available_quotas"] = get_available_quotas(
843-
quota_limiters=configuration.quota_limiters, user_id=context.auth[0]
844-
)
845-
turn_summary.llm_response = extract_text_from_response_items(
846-
latest_response_object.output
900+
# Get available quotas after token consumption
901+
chunk_dict["response"]["available_quotas"] = get_available_quotas(
902+
quota_limiters=configuration.quota_limiters,
903+
user_id=context.auth[0],
904+
)
905+
turn_summary.llm_response = extract_text_from_response_items(
906+
latest_response_object.output
907+
)
908+
chunk_dict["response"]["output_text"] = turn_summary.llm_response
909+
910+
# Record inference duration metric for terminal events
911+
result = "failure" if chunk.type == "response.failed" else "success"
912+
_record_response_inference_result(
913+
api_params.model,
914+
context.endpoint_path,
915+
result,
916+
time.monotonic() - inference_start_time,
917+
record_failure=(result == "failure"),
918+
)
919+
inference_metric_recorded = True
920+
921+
yield f"event: {chunk.type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n"
922+
except Exception:
923+
if not inference_metric_recorded:
924+
_record_response_inference_result(
925+
api_params.model,
926+
context.endpoint_path,
927+
"failure",
928+
time.monotonic() - inference_start_time,
929+
record_failure=True,
847930
)
848-
chunk_dict["response"]["output_text"] = turn_summary.llm_response
849-
850-
yield f"event: {chunk.type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n"
931+
raise
851932

852933
# Extract response metadata from final response object
853934
if latest_response_object:
@@ -976,13 +1057,22 @@ async def handle_non_streaming_response(
9761057
user_agent=context.user_agent,
9771058
)
9781059
else:
1060+
inference_start_time = time.monotonic()
1061+
inference_metric_recorded = False
9791062
try:
9801063
api_response = cast(
9811064
OpenAIResponseObject,
9821065
await context.client.responses.create(
9831066
**api_params.model_dump(exclude_none=True)
9841067
),
9851068
)
1069+
_record_response_inference_result(
1070+
api_params.model,
1071+
context.endpoint_path,
1072+
"success",
1073+
time.monotonic() - inference_start_time,
1074+
)
1075+
inference_metric_recorded = True
9861076
token_usage = extract_token_usage(
9871077
api_response.usage, api_params.model, context.endpoint_path
9881078
)
@@ -1004,6 +1094,14 @@ async def handle_non_streaming_response(
10041094

10051095
except RuntimeError as e:
10061096
if is_context_length_error(str(e)):
1097+
if not inference_metric_recorded:
1098+
_record_response_inference_result(
1099+
api_params.model,
1100+
context.endpoint_path,
1101+
"failure",
1102+
time.monotonic() - inference_start_time,
1103+
record_failure=True,
1104+
)
10071105
_queue_responses_splunk_event(
10081106
background_tasks=context.background_tasks,
10091107
input_text=context.input_text,
@@ -1022,6 +1120,14 @@ async def handle_non_streaming_response(
10221120
raise HTTPException(**error_response.model_dump()) from e
10231121
raise e
10241122
except APIConnectionError as e:
1123+
if not inference_metric_recorded:
1124+
_record_response_inference_result(
1125+
api_params.model,
1126+
context.endpoint_path,
1127+
"failure",
1128+
time.monotonic() - inference_start_time,
1129+
record_failure=True,
1130+
)
10251131
_queue_responses_splunk_event(
10261132
background_tasks=context.background_tasks,
10271133
input_text=context.input_text,
@@ -1040,6 +1146,14 @@ async def handle_non_streaming_response(
10401146
)
10411147
raise HTTPException(**error_response.model_dump()) from e
10421148
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
1149+
if not inference_metric_recorded:
1150+
_record_response_inference_result(
1151+
api_params.model,
1152+
context.endpoint_path,
1153+
"failure",
1154+
time.monotonic() - inference_start_time,
1155+
record_failure=True,
1156+
)
10431157
_queue_responses_splunk_event(
10441158
background_tasks=context.background_tasks,
10451159
input_text=context.input_text,

0 commit comments

Comments
 (0)