Skip to content

Commit b7af86f

Browse files
committed
feat: add Responses API inference metrics
Signed-off-by: Major Hayden <major@redhat.com>
1 parent 35216ab commit b7af86f

4 files changed

Lines changed: 279 additions & 61 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 169 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import json
7+
import time
78
from collections.abc import AsyncIterator
89
from datetime import UTC, datetime
910
from typing import Annotated, Any, Final, Optional, cast
@@ -38,6 +39,7 @@
3839
from configuration import configuration
3940
from constants import SUBSTITUTED_INSTRUCTIONS_PLACEHOLDER
4041
from log import get_logger
42+
from metrics import recording
4143
from models.common.responses.responses_api_params import ResponsesApiParams
4244
from models.common.responses.responses_context import ResponsesContext
4345
from models.config import Action
@@ -402,6 +404,34 @@ async def responses_endpoint_handler(
402404
)
403405

404406

407+
def _record_response_inference_result(
    model_id: str,
    endpoint_path: str,
    result: str,
    duration: float,
    record_failure: bool = False,
) -> None:
    """Emit inference metrics for a single Responses API call.

    The composite model identifier is split into its provider and model
    parts, which are used together with the endpoint path as metric labels.
    The inference duration is always recorded; the failure counter is only
    incremented on request.

    Args:
        model_id: Composite model identifier in ``provider/model`` format.
        endpoint_path: API endpoint path for metric labeling.
        result: Result label such as ``success`` or ``failure``.
        duration: Inference call duration in seconds.
        record_failure: When True, also increment the LLM failure counter.
    """
    provider_name, model_name = extract_provider_and_model_from_model_id(model_id)
    # Both metrics share the same leading label triple.
    shared_labels = (provider_name, model_name, endpoint_path)

    # The failure counter is bumped before the duration is observed.
    if record_failure:
        recording.record_llm_failure(*shared_labels)

    recording.record_llm_inference_duration(*shared_labels, result, duration)
433+
434+
405435
async def handle_streaming_response(
406436
original_request: ResponsesRequest,
407437
api_params: ResponsesApiParams,
@@ -442,6 +472,7 @@ async def handle_streaming_response(
442472
user_agent=context.user_agent,
443473
)
444474
else:
475+
inference_start_time = time.monotonic()
445476
try:
446477
response = await context.client.responses.create(
447478
**api_params.model_dump(exclude_none=True)
@@ -452,9 +483,17 @@ async def handle_streaming_response(
452483
api_params=api_params,
453484
context=context,
454485
turn_summary=turn_summary,
486+
inference_start_time=inference_start_time,
455487
)
456488
except RuntimeError as e: # library mode wraps 413 into runtime error
457489
if is_context_length_error(str(e)):
490+
_record_response_inference_result(
491+
api_params.model,
492+
context.endpoint_path,
493+
"failure",
494+
time.monotonic() - inference_start_time,
495+
record_failure=True,
496+
)
458497
_queue_responses_splunk_event(
459498
background_tasks=context.background_tasks,
460499
input_text=context.input_text,
@@ -473,6 +512,13 @@ async def handle_streaming_response(
473512
raise HTTPException(**error_response.model_dump()) from e
474513
raise e
475514
except APIConnectionError as e:
515+
_record_response_inference_result(
516+
api_params.model,
517+
context.endpoint_path,
518+
"failure",
519+
time.monotonic() - inference_start_time,
520+
record_failure=True,
521+
)
476522
_queue_responses_splunk_event(
477523
background_tasks=context.background_tasks,
478524
input_text=context.input_text,
@@ -491,6 +537,13 @@ async def handle_streaming_response(
491537
)
492538
raise HTTPException(**error_response.model_dump()) from e
493539
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
540+
_record_response_inference_result(
541+
api_params.model,
542+
context.endpoint_path,
543+
"failure",
544+
time.monotonic() - inference_start_time,
545+
record_failure=True,
546+
)
494547
_queue_responses_splunk_event(
495548
background_tasks=context.background_tasks,
496549
input_text=context.input_text,
@@ -757,6 +810,7 @@ async def response_generator(
757810
api_params: ResponsesApiParams,
758811
context: ResponsesContext,
759812
turn_summary: TurnSummary,
813+
inference_start_time: float,
760814
) -> AsyncIterator[str]:
761815
"""Generate SSE-formatted streaming response with LCORE-enriched events.
762816
@@ -766,6 +820,7 @@ async def response_generator(
766820
api_params: ResponsesApiParams
767821
context: Responses context
768822
turn_summary: TurnSummary to populate during streaming
823+
inference_start_time: Monotonic timestamp taken before the inference call.
769824
Yields:
770825
SSE-formatted strings for streaming events, ending with [DONE]
771826
"""
@@ -776,76 +831,102 @@ async def response_generator(
776831
configured_mcp_labels = {s.name for s in configuration.mcp_servers}
777832
# Track output indices of server-deployed MCP calls to filter their events
778833
server_mcp_output_indices: set[int] = set()
834+
inference_metric_recorded = False
779835

780-
async for chunk in stream:
781-
logger.debug("Processing streaming chunk, type: %s", chunk.type)
836+
try:
837+
async for chunk in stream:
838+
logger.debug("Processing streaming chunk, type: %s", chunk.type)
782839

783-
# Filter out streaming events for server-deployed MCP tools.
784-
# These are handled internally by LCS and should not be forwarded
785-
# to clients that don't understand the mcp_call item type.
786-
if _should_filter_mcp_chunk(
787-
chunk, configured_mcp_labels, server_mcp_output_indices
788-
):
789-
continue
840+
# Filter out streaming events for server-deployed MCP tools.
841+
# These are handled internally by LCS and should not be forwarded
842+
# to clients that don't understand the mcp_call item type.
843+
if _should_filter_mcp_chunk(
844+
chunk, configured_mcp_labels, server_mcp_output_indices
845+
):
846+
continue
790847

791-
chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True)
848+
chunk_dict = chunk.model_dump(exclude_none=True, by_alias=True)
792849

793-
# Create own sequence number for chunks to maintain order
794-
chunk_dict["sequence_number"] = sequence_number
795-
sequence_number += 1
850+
# Create own sequence number for chunks to maintain order
851+
chunk_dict["sequence_number"] = sequence_number
852+
sequence_number += 1
796853

797-
if "response" in chunk_dict:
798-
chunk_dict["response"]["conversation"] = normalize_conversation_id(
799-
api_params.conversation
800-
)
801-
_sanitize_response_dict(
802-
chunk_dict["response"],
803-
configured_mcp_labels,
804-
original_request,
805-
)
806-
tools = chunk_dict["response"].get("tools")
807-
if tools is not None:
808-
chunk_dict["response"]["tools"] = (
809-
translate_vector_store_ids_to_user_facing(
810-
tools,
811-
configuration.rag_id_mapping,
854+
if "response" in chunk_dict:
855+
chunk_dict["response"]["conversation"] = normalize_conversation_id(
856+
api_params.conversation
857+
)
858+
_sanitize_response_dict(
859+
chunk_dict["response"],
860+
configured_mcp_labels,
861+
original_request,
862+
)
863+
tools = chunk_dict["response"].get("tools")
864+
if tools is not None:
865+
chunk_dict["response"]["tools"] = (
866+
translate_vector_store_ids_to_user_facing(
867+
tools,
868+
configuration.rag_id_mapping,
869+
)
812870
)
871+
# Intermediate response - no quota consumption and text yet
872+
if chunk.type == "response.in_progress":
873+
chunk_dict["response"]["available_quotas"] = {}
874+
chunk_dict["response"]["output_text"] = ""
875+
876+
# Handle completion, incomplete, and failed events
877+
if chunk.type in (
878+
"response.completed",
879+
"response.incomplete",
880+
"response.failed",
881+
):
882+
latest_response_object = cast(
883+
OpenAIResponseObject, cast(Any, chunk).response
813884
)
814-
# Intermediate response - no quota consumption and text yet
815-
if chunk.type == "response.in_progress":
816-
chunk_dict["response"]["available_quotas"] = {}
817-
chunk_dict["response"]["output_text"] = ""
818-
819-
# Handle completion, incomplete, and failed events - only quota handling here
820-
if chunk.type in (
821-
"response.completed",
822-
"response.incomplete",
823-
"response.failed",
824-
):
825-
latest_response_object = cast(
826-
OpenAIResponseObject, cast(Any, chunk).response
827-
)
828885

829-
# Extract and consume tokens if any were used
830-
turn_summary.token_usage = extract_token_usage(
831-
latest_response_object.usage, api_params.model, context.endpoint_path
832-
)
833-
consume_query_tokens(
834-
user_id=context.auth[0],
835-
model_id=api_params.model,
836-
token_usage=turn_summary.token_usage,
837-
)
886+
# Extract and consume tokens if any were used
887+
turn_summary.token_usage = extract_token_usage(
888+
latest_response_object.usage,
889+
api_params.model,
890+
context.endpoint_path,
891+
)
892+
consume_query_tokens(
893+
user_id=context.auth[0],
894+
model_id=api_params.model,
895+
token_usage=turn_summary.token_usage,
896+
)
838897

839-
# Get available quotas after token consumption
840-
chunk_dict["response"]["available_quotas"] = get_available_quotas(
841-
quota_limiters=configuration.quota_limiters, user_id=context.auth[0]
842-
)
843-
turn_summary.llm_response = extract_text_from_response_items(
844-
latest_response_object.output
898+
# Get available quotas after token consumption
899+
chunk_dict["response"]["available_quotas"] = get_available_quotas(
900+
quota_limiters=configuration.quota_limiters,
901+
user_id=context.auth[0],
902+
)
903+
turn_summary.llm_response = extract_text_from_response_items(
904+
latest_response_object.output
905+
)
906+
chunk_dict["response"]["output_text"] = turn_summary.llm_response
907+
908+
# Record inference duration metric for terminal events
909+
result = "failure" if chunk.type == "response.failed" else "success"
910+
_record_response_inference_result(
911+
api_params.model,
912+
context.endpoint_path,
913+
result,
914+
time.monotonic() - inference_start_time,
915+
record_failure=(result == "failure"),
916+
)
917+
inference_metric_recorded = True
918+
919+
yield f"event: {chunk.type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n"
920+
except Exception:
921+
if not inference_metric_recorded:
922+
_record_response_inference_result(
923+
api_params.model,
924+
context.endpoint_path,
925+
"failure",
926+
time.monotonic() - inference_start_time,
927+
record_failure=True,
845928
)
846-
chunk_dict["response"]["output_text"] = turn_summary.llm_response
847-
848-
yield f"event: {chunk.type or 'error'}\ndata: {json.dumps(chunk_dict)}\n\n"
929+
raise
849930

850931
# Extract response metadata from final response object
851932
if latest_response_object:
@@ -974,13 +1055,20 @@ async def handle_non_streaming_response(
9741055
user_agent=context.user_agent,
9751056
)
9761057
else:
1058+
inference_start_time = time.monotonic()
9771059
try:
9781060
api_response = cast(
9791061
OpenAIResponseObject,
9801062
await context.client.responses.create(
9811063
**api_params.model_dump(exclude_none=True)
9821064
),
9831065
)
1066+
_record_response_inference_result(
1067+
api_params.model,
1068+
context.endpoint_path,
1069+
"success",
1070+
time.monotonic() - inference_start_time,
1071+
)
9841072
token_usage = extract_token_usage(
9851073
api_response.usage, api_params.model, context.endpoint_path
9861074
)
@@ -1002,6 +1090,13 @@ async def handle_non_streaming_response(
10021090

10031091
except RuntimeError as e:
10041092
if is_context_length_error(str(e)):
1093+
_record_response_inference_result(
1094+
api_params.model,
1095+
context.endpoint_path,
1096+
"failure",
1097+
time.monotonic() - inference_start_time,
1098+
record_failure=True,
1099+
)
10051100
_queue_responses_splunk_event(
10061101
background_tasks=context.background_tasks,
10071102
input_text=context.input_text,
@@ -1020,6 +1115,13 @@ async def handle_non_streaming_response(
10201115
raise HTTPException(**error_response.model_dump()) from e
10211116
raise e
10221117
except APIConnectionError as e:
1118+
_record_response_inference_result(
1119+
api_params.model,
1120+
context.endpoint_path,
1121+
"failure",
1122+
time.monotonic() - inference_start_time,
1123+
record_failure=True,
1124+
)
10231125
_queue_responses_splunk_event(
10241126
background_tasks=context.background_tasks,
10251127
input_text=context.input_text,
@@ -1038,6 +1140,13 @@ async def handle_non_streaming_response(
10381140
)
10391141
raise HTTPException(**error_response.model_dump()) from e
10401142
except (LLSApiStatusError, OpenAIAPIStatusError) as e:
1143+
_record_response_inference_result(
1144+
api_params.model,
1145+
context.endpoint_path,
1146+
"failure",
1147+
time.monotonic() - inference_start_time,
1148+
record_failure=True,
1149+
)
10411150
_queue_responses_splunk_event(
10421151
background_tasks=context.background_tasks,
10431152
input_text=context.input_text,

src/metrics/recording.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from collections.abc import Iterator
99
from contextlib import contextmanager
10+
from typing import Final
1011

1112
import metrics
1213
from log import get_logger
@@ -111,6 +112,32 @@ def record_llm_token_usage(
111112
logger.warning("Failed to update token metrics", exc_info=True)
112113

113114

115+
# Bounded set of values allowed for the ``result`` label of the inference
# duration metric.
LLM_INFERENCE_RESULT_SUCCESS: Final[str] = "success"
LLM_INFERENCE_RESULT_FAILURE: Final[str] = "failure"
ALLOWED_LLM_INFERENCE_RESULTS: Final[frozenset[str]] = frozenset(
    {LLM_INFERENCE_RESULT_SUCCESS, LLM_INFERENCE_RESULT_FAILURE}
)


def normalize_llm_inference_result(result: str) -> str:
    """Clamp an inference result string to the bounded label set.

    Unknown or unexpected values are mapped to ``failure`` so that the
    Prometheus label cardinality stays bounded. Non-string input (including
    unhashable objects such as lists, which would otherwise raise
    ``TypeError`` on the set lookup) is treated as unexpected and also maps
    to ``failure``, so this function never raises.

    Args:
        result: Raw result label from the caller.

    Returns:
        A value guaranteed to be in ``ALLOWED_LLM_INFERENCE_RESULTS``.
    """
    # Check the type first: membership tests on a frozenset hash the operand,
    # which fails for unhashable values.
    if isinstance(result, str) and result in ALLOWED_LLM_INFERENCE_RESULTS:
        return result
    return LLM_INFERENCE_RESULT_FAILURE
139+
140+
114141
def record_llm_inference_duration(
115142
provider: str, model: str, endpoint_path: str, result: str, duration: float
116143
) -> None:
@@ -123,9 +150,10 @@ def record_llm_inference_duration(
123150
result: Bounded result label, such as ``success`` or ``failure``.
124151
duration: Inference call duration in seconds.
125152
"""
153+
bounded_result = normalize_llm_inference_result(result)
126154
try:
127155
metrics.llm_inference_duration_seconds.labels(
128-
provider, model, endpoint_path, result
156+
provider, model, endpoint_path, bounded_result
129157
).observe(duration)
130158
except (AttributeError, TypeError, ValueError):
131159
logger.warning("Failed to update LLM inference duration metric", exc_info=True)

0 commit comments

Comments
 (0)