Skip to content

Commit ca125c4

Browse files
authored
Merge pull request lightspeed-core#1624 from Lifto/feat/rspeed-2849-llm-metrics
RSPEED-2957: add endpoint label to LLM Prometheus metrics
2 parents 4c6177e + ee9570a commit ca125c4

13 files changed

Lines changed: 191 additions & 96 deletions

File tree

src/app/endpoints/query.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,10 @@ async def query_endpoint_handler(
170170

171171
# Moderation input is the raw user content (query + attachments) without injected RAG
172172
# context, to avoid false positives from retrieved document content.
173+
endpoint_path = "/v1/query"
173174
moderation_input = prepare_input(query_request)
174175
moderation_result = await run_shield_moderation(
175-
client, moderation_input, query_request.shield_ids
176+
client, moderation_input, endpoint_path, query_request.shield_ids
176177
)
177178

178179
# Build RAG context from Inline RAG sources
@@ -207,7 +208,9 @@ async def query_endpoint_handler(
207208
client = await update_azure_token(client)
208209

209210
# Retrieve response using Responses API
210-
turn_summary = await retrieve_response(client, responses_params, moderation_result)
211+
turn_summary = await retrieve_response(
212+
client, responses_params, moderation_result, endpoint_path
213+
)
211214

212215
if moderation_result.decision == "passed":
213216
# Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
@@ -280,6 +283,7 @@ async def retrieve_response(
280283
client: AsyncLlamaStackClient,
281284
responses_params: ResponsesApiParams,
282285
moderation_result: ShieldModerationResult,
286+
endpoint_path: str = "",
283287
) -> TurnSummary:
284288
"""
285289
Retrieve response from LLMs and agents.
@@ -332,5 +336,9 @@ async def retrieve_response(
332336
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
333337
rag_id_mapping = configuration.rag_id_mapping
334338
return build_turn_summary(
335-
response, responses_params.model, vector_store_ids, rag_id_mapping
339+
response,
340+
responses_params.model,
341+
endpoint_path,
342+
vector_store_ids,
343+
rag_id_mapping,
336344
)

src/app/endpoints/responses.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -331,9 +331,11 @@ async def responses_endpoint_handler(
331331
)
332332
attachments_text = extract_attachments_text(original_request.input)
333333

334+
endpoint_path = "/v1/responses"
334335
moderation_result = await run_shield_moderation(
335336
client,
336337
input_text + "\n\n" + attachments_text,
338+
endpoint_path,
337339
original_request.shield_ids,
338340
)
339341

@@ -388,6 +390,7 @@ async def responses_endpoint_handler(
388390
background_tasks=background_tasks,
389391
rh_identity_context=rh_identity_context,
390392
user_agent=_get_user_agent(request),
393+
endpoint_path=endpoint_path,
391394
)
392395

393396

@@ -404,6 +407,7 @@ async def handle_streaming_response(
404407
background_tasks: Optional[BackgroundTasks] = None,
405408
rh_identity_context: tuple[str, str] = ("", ""),
406409
user_agent: Optional[str] = None,
410+
endpoint_path: str = "",
407411
) -> StreamingResponse:
408412
"""Handle streaming response from Responses API.
409413
@@ -470,6 +474,7 @@ async def handle_streaming_response(
470474
turn_summary=turn_summary,
471475
inline_rag_context=inline_rag_context,
472476
filter_server_tools=filter_server_tools,
477+
endpoint_path=endpoint_path,
473478
)
474479
except RuntimeError as e: # library mode wraps 413 into runtime error
475480
if is_context_length_error(str(e)):
@@ -798,6 +803,7 @@ async def response_generator(
798803
turn_summary: TurnSummary,
799804
inline_rag_context: RAGContext,
800805
filter_server_tools: bool = False,
806+
endpoint_path: str = "",
801807
) -> AsyncIterator[str]:
802808
"""Generate SSE-formatted streaming response with LCORE-enriched events.
803809
@@ -810,6 +816,7 @@ async def response_generator(
810816
turn_summary: TurnSummary to populate during streaming
811817
inline_rag_context: Inline RAG context to be used for the response
812818
filter_server_tools: Whether to filter server-deployed MCP tool events from the stream
819+
endpoint_path: API endpoint path used for metric labeling.
813820
Yields:
814821
SSE-formatted strings for streaming events, ending with [DONE]
815822
"""
@@ -873,7 +880,7 @@ async def response_generator(
873880

874881
# Extract and consume tokens if any were used
875882
turn_summary.token_usage = extract_token_usage(
876-
latest_response_object.usage, api_params.model
883+
latest_response_object.usage, api_params.model, endpoint_path
877884
)
878885
consume_query_tokens(
879886
user_id=user_id,
@@ -1010,6 +1017,7 @@ async def handle_non_streaming_response(
10101017
background_tasks: Optional[BackgroundTasks] = None,
10111018
rh_identity_context: tuple[str, str] = ("", ""),
10121019
user_agent: Optional[str] = None,
1020+
endpoint_path: str = "",
10131021
) -> ResponsesResponse:
10141022
"""Handle non-streaming response from Responses API.
10151023
@@ -1069,7 +1077,9 @@ async def handle_non_streaming_response(
10691077
**api_params.model_dump(exclude_none=True)
10701078
),
10711079
)
1072-
token_usage = extract_token_usage(api_response.usage, api_params.model)
1080+
token_usage = extract_token_usage(
1081+
api_response.usage, api_params.model, endpoint_path
1082+
)
10731083
logger.info("Consuming tokens")
10741084
consume_query_tokens(
10751085
user_id=user_id,
@@ -1152,6 +1162,7 @@ async def handle_non_streaming_response(
11521162
turn_summary = build_turn_summary(
11531163
api_response,
11541164
api_params.model,
1165+
endpoint_path,
11551166
vector_store_ids,
11561167
configuration.rag_id_mapping,
11571168
filter_server_tools=filter_server_tools,

src/app/endpoints/rlsapi_v1.py

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,7 @@ async def retrieve_simple_response(
241241
instructions: str,
242242
tools: Optional[list[Any]] = None,
243243
model_id: Optional[str] = None,
244+
endpoint_path: str = "/v1/infer",
244245
) -> str:
245246
"""Retrieve a simple response from the LLM for a stateless query.
246247
@@ -263,7 +264,7 @@ async def retrieve_simple_response(
263264
"""
264265
resolved_model_id = model_id or await _get_default_model_id()
265266
response = await _call_llm(question, instructions, tools, resolved_model_id)
266-
extract_token_usage(response.usage, resolved_model_id)
267+
extract_token_usage(response.usage, resolved_model_id, endpoint_path)
267268
return extract_text_from_response_items(response.output)
268269

269270

@@ -366,12 +367,13 @@ def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-position
366367
background_tasks.add_task(send_splunk_event, event, sourcetype)
367368

368369

369-
async def _check_shield_moderation(
370+
async def _check_shield_moderation( # pylint: disable=too-many-arguments,too-many-positional-arguments
370371
input_text: str,
371372
request_id: str,
372373
background_tasks: BackgroundTasks,
373374
infer_request: RlsapiV1InferRequest,
374375
request: Request,
376+
endpoint_path: str,
375377
) -> Optional[RlsapiV1InferResponse]:
376378
"""Run shield moderation and return a refusal response if blocked.
377379
@@ -384,13 +386,14 @@ async def _check_shield_moderation(
384386
background_tasks: FastAPI background tasks for async Splunk event sending.
385387
infer_request: The original inference request (for Splunk event context).
386388
request: The FastAPI request object (for Splunk event context).
389+
endpoint_path: The API endpoint path for metric labeling.
387390
388391
Returns:
389392
An RlsapiV1InferResponse containing the refusal message if the input
390393
was blocked, or None if moderation passed.
391394
"""
392395
client = AsyncLlamaStackClientHolder().get_client()
393-
moderation_result = await run_shield_moderation(client, input_text)
396+
moderation_result = await run_shield_moderation(client, input_text, endpoint_path)
394397

395398
if moderation_result.decision != "blocked":
396399
return None
@@ -432,6 +435,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
432435
start_time: float,
433436
model: str,
434437
provider: str,
438+
endpoint_path: str,
435439
) -> float:
436440
"""Record metrics and queue Splunk event for an inference failure.
437441
@@ -442,12 +446,15 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
442446
request_id: Unique identifier for the request.
443447
error: The exception that caused the failure.
444448
start_time: Monotonic clock time when inference started.
449+
model: The model name.
450+
provider: The provider name.
451+
endpoint_path: The API endpoint path for metric labeling.
445452
446453
Returns:
447454
The total inference time in seconds.
448455
"""
449456
inference_time = time.monotonic() - start_time
450-
recording.record_llm_failure(provider, model)
457+
recording.record_llm_failure(provider, model, endpoint_path)
451458
_queue_splunk_event(
452459
background_tasks,
453460
infer_request,
@@ -530,6 +537,7 @@ def _build_infer_response(
530537
request_id: str,
531538
response: Optional[OpenAIResponseObject],
532539
model_id: str,
540+
endpoint_path: str,
533541
) -> RlsapiV1InferResponse:
534542
"""Build the final inference response, with optional verbose metadata.
535543
@@ -549,7 +557,11 @@ def _build_infer_response(
549557
"""
550558
if response is not None:
551559
turn_summary = build_turn_summary(
552-
response, model_id, vector_store_ids=None, rag_id_mapping=None
560+
response,
561+
model_id,
562+
endpoint_path,
563+
vector_store_ids=None,
564+
rag_id_mapping=None,
553565
)
554566
return RlsapiV1InferResponse(
555567
data=RlsapiV1InferData(
@@ -673,12 +685,19 @@ async def infer_endpoint( # pylint: disable=R0914
673685
"Request %s: Combined input source length: %d", request_id, len(input_source)
674686
)
675687

688+
endpoint_path = "/v1/infer"
689+
676690
# Run shield moderation on user input before inference.
677691
# Uses all configured shields; no-op when no shields are registered.
678692
# Runs before model/tool discovery so blocked requests short-circuit
679693
# without incurring external I/O.
680694
blocked_response = await _check_shield_moderation(
681-
input_source, request_id, background_tasks, infer_request, request
695+
input_source,
696+
request_id,
697+
background_tasks,
698+
infer_request,
699+
request,
700+
endpoint_path,
682701
)
683702
if blocked_response is not None:
684703
return blocked_response
@@ -700,11 +719,11 @@ async def infer_endpoint( # pylint: disable=R0914
700719
model_id=model_id,
701720
)
702721
response_text = extract_text_from_response_items(response.output)
703-
token_usage = extract_token_usage(response.usage, model_id)
722+
token_usage = extract_token_usage(response.usage, model_id, endpoint_path)
704723
inference_time = time.monotonic() - start_time
705724
except _INFER_HANDLED_EXCEPTIONS as error:
706725
if response is not None:
707-
extract_token_usage(response.usage, model_id) # type: ignore[arg-type]
726+
extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type]
708727
_record_inference_failure(
709728
background_tasks,
710729
infer_request,
@@ -714,6 +733,7 @@ async def infer_endpoint( # pylint: disable=R0914
714733
start_time,
715734
model,
716735
provider,
736+
endpoint_path,
717737
)
718738
mapped_error = _map_inference_error_to_http_exception(
719739
error,
@@ -755,4 +775,5 @@ async def infer_endpoint( # pylint: disable=R0914
755775
request_id,
756776
response if verbose_enabled else None,
757777
model_id,
778+
endpoint_path,
758779
)

src/app/endpoints/streaming_query.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -226,8 +226,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
226226
# Moderation input is the raw user content (query + attachments) without injected RAG
227227
# context, to avoid false positives from retrieved document content.
228228
moderation_input = prepare_input(query_request)
229+
endpoint_path = "/v1/streaming_query"
229230
moderation_result = await run_shield_moderation(
230-
client, moderation_input, query_request.shield_ids
231+
client, moderation_input, endpoint_path, query_request.shield_ids
231232
)
232233

233234
# Build RAG context from Inline RAG sources
@@ -283,11 +284,12 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
283284
provider_id, model_id = extract_provider_and_model_from_model_id(
284285
responses_params.model
285286
)
286-
recording.record_llm_call(provider_id, model_id)
287+
recording.record_llm_call(provider_id, model_id, endpoint_path)
287288

288289
generator, turn_summary = await retrieve_response_generator(
289290
responses_params=responses_params,
290291
context=context,
292+
endpoint_path=endpoint_path,
291293
)
292294

293295
# Combine inline RAG results (BYOK + Solr) with tool-based results
@@ -316,6 +318,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
316318
async def retrieve_response_generator(
317319
responses_params: ResponsesApiParams,
318320
context: ResponseGeneratorContext,
321+
endpoint_path: str,
319322
) -> tuple[AsyncIterator[str], TurnSummary]:
320323
"""
321324
Retrieve the appropriate response generator.
@@ -327,6 +330,7 @@ async def retrieve_response_generator(
327330
Args:
328331
responses_params: The Responses API parameters
329332
context: The response generator context
333+
endpoint_path: API endpoint path used for metric labeling.
330334
Returns:
331335
tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
332336
@@ -360,6 +364,7 @@ async def retrieve_response_generator(
360364
response,
361365
context,
362366
turn_summary,
367+
endpoint_path,
363368
),
364369
turn_summary,
365370
)
@@ -685,6 +690,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
685690
turn_response: AsyncIterator[OpenAIResponseObjectStream],
686691
context: ResponseGeneratorContext,
687692
turn_summary: TurnSummary,
693+
endpoint_path: str,
688694
) -> AsyncIterator[str]:
689695
"""Generate SSE formatted streaming response.
690696
@@ -696,6 +702,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
696702
turn_response: The streaming response from Llama Stack
697703
context: The response generator context
698704
turn_summary: TurnSummary to populate during streaming
705+
endpoint_path: API endpoint path used for metric labeling.
699706
700707
Yields:
701708
SSE-formatted strings for tokens, tool calls, tool results,
@@ -862,7 +869,7 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat
862869
return
863870

864871
turn_summary.token_usage = extract_token_usage(
865-
latest_response_object.usage, context.model_id
872+
latest_response_object.usage, context.model_id, endpoint_path
866873
)
867874
# Parse tool-based referenced documents from the final response object
868875
tool_rag_docs = parse_referenced_documents(

src/metrics/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,25 +29,29 @@
2929

3030
# Metric that counts how many LLM calls were made for each provider + model
3131
llm_calls_total = Counter(
32-
"ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
32+
"ls_llm_calls_total", "LLM calls counter", ["provider", "model", "endpoint"]
3333
)
3434

3535
# Metric that counts how many LLM calls failed
3636
llm_calls_failures_total = Counter(
37-
"ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
37+
"ls_llm_calls_failures_total",
38+
"LLM calls failures",
39+
["provider", "model", "endpoint"],
3840
)
3941

4042
# Metric that counts how many LLM calls had validation errors
4143
llm_calls_validation_errors_total = Counter(
42-
"ls_llm_validation_errors_total", "LLM validation errors"
44+
"ls_llm_validation_errors_total", "LLM validation errors", ["endpoint"]
4345
)
4446

4547
# Metric that counts how many tokens were sent to LLMs
4648
llm_token_sent_total = Counter(
47-
"ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
49+
"ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model", "endpoint"]
4850
)
4951

5052
# Metric that counts how many tokens were received from LLMs
5153
llm_token_received_total = Counter(
52-
"ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
54+
"ls_llm_token_received_total",
55+
"LLM tokens received",
56+
["provider", "model", "endpoint"],
5357
)

0 commit comments

Comments
 (0)