Skip to content

Commit 5e35cbc

Browse files
authored
Merge pull request #1236 from samdoran/rlsapi-metrics
Properly increment metrics for /v1/infer
2 parents d3ba681 + efb62c2 commit 5e35cbc

4 files changed

Lines changed: 55 additions & 11 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@
3333
from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo
3434
from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse
3535
from observability import InferenceEventData, build_inference_event, send_splunk_event
36-
from utils.query import handle_known_apistatus_errors
36+
from utils.query import (
37+
extract_provider_and_model_from_model_id,
38+
handle_known_apistatus_errors,
39+
)
3740
from utils.responses import (
3841
extract_text_from_response_items,
42+
extract_token_usage,
3943
get_mcp_tools,
4044
)
4145
from utils.suid import get_suid
@@ -191,6 +195,7 @@ async def retrieve_simple_response(
191195
store=False,
192196
)
193197
response = cast(OpenAIResponseObject, response)
198+
extract_token_usage(response.usage, model_id)
194199

195200
return extract_text_from_response_items(response.output)
196201

@@ -242,6 +247,8 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
242247
request_id: str,
243248
error: Exception,
244249
start_time: float,
250+
model: str,
251+
provider: str,
245252
) -> float:
246253
"""Record metrics and queue Splunk event for an inference failure.
247254
@@ -257,7 +264,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
257264
The total inference time in seconds.
258265
"""
259266
inference_time = time.monotonic() - start_time
260-
metrics.llm_calls_failures_total.inc()
267+
metrics.llm_calls_failures_total.labels(provider, model).inc()
261268
_queue_splunk_event(
262269
background_tasks,
263270
infer_request,
@@ -272,7 +279,7 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
272279

273280
@router.post("/infer", responses=infer_responses)
274281
@authorize(Action.RLSAPI_V1_INFER)
275-
async def infer_endpoint(
282+
async def infer_endpoint( # pylint: disable=R0914
276283
infer_request: RlsapiV1InferRequest,
277284
request: Request,
278285
background_tasks: BackgroundTasks,
@@ -307,6 +314,7 @@ async def infer_endpoint(
307314
input_source = infer_request.get_input_source()
308315
instructions = _build_instructions(infer_request.context.systeminfo)
309316
model_id = _get_default_model_id()
317+
provider, model = extract_provider_and_model_from_model_id(model_id)
310318
mcp_tools = await get_mcp_tools(request_headers=request.headers)
311319
logger.debug(
312320
"Request %s: Combined input source length: %d", request_id, len(input_source)
@@ -321,19 +329,40 @@ async def infer_endpoint(
321329
except RuntimeError as e:
322330
if "context_length" in str(e).lower():
323331
_record_inference_failure(
324-
background_tasks, infer_request, request, request_id, e, start_time
332+
background_tasks,
333+
infer_request,
334+
request,
335+
request_id,
336+
e,
337+
start_time,
338+
model,
339+
provider,
325340
)
326341
logger.error("Prompt too long for request %s: %s", request_id, e)
327342
error_response = PromptTooLongResponse(model=model_id)
328343
raise HTTPException(**error_response.model_dump()) from e
329344
_record_inference_failure(
330-
background_tasks, infer_request, request, request_id, e, start_time
345+
background_tasks,
346+
infer_request,
347+
request,
348+
request_id,
349+
e,
350+
start_time,
351+
model,
352+
provider,
331353
)
332354
logger.error("Unexpected RuntimeError for request %s: %s", request_id, e)
333355
raise
334356
except APIConnectionError as e:
335357
_record_inference_failure(
336-
background_tasks, infer_request, request, request_id, e, start_time
358+
background_tasks,
359+
infer_request,
360+
request,
361+
request_id,
362+
e,
363+
start_time,
364+
model,
365+
provider,
337366
)
338367
logger.error(
339368
"Unable to connect to Llama Stack for request %s: %s", request_id, e
@@ -345,7 +374,14 @@ async def infer_endpoint(
345374
raise HTTPException(**error_response.model_dump()) from e
346375
except RateLimitError as e:
347376
_record_inference_failure(
348-
background_tasks, infer_request, request, request_id, e, start_time
377+
background_tasks,
378+
infer_request,
379+
request,
380+
request_id,
381+
e,
382+
start_time,
383+
model,
384+
provider,
349385
)
350386
logger.error("Rate limit exceeded for request %s: %s", request_id, e)
351387
error_response = QuotaExceededResponse(
@@ -355,7 +391,14 @@ async def infer_endpoint(
355391
raise HTTPException(**error_response.model_dump()) from e
356392
except (APIStatusError, OpenAIAPIStatusError) as e:
357393
_record_inference_failure(
358-
background_tasks, infer_request, request, request_id, e, start_time
394+
background_tasks,
395+
infer_request,
396+
request,
397+
request_id,
398+
e,
399+
start_time,
400+
model,
401+
provider,
359402
)
360403
logger.exception("API error for request %s: %s", request_id, e)
361404
error_response = handle_known_apistatus_errors(e, model_id)

src/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
)
3434

3535
# Metric that counts how many LLM calls failed
36-
llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
36+
llm_calls_failures_total = Counter(
37+
"ls_llm_calls_failures_total", "LLM calls failures", ["provider", "model"]
38+
)
3739

3840
# Metric that counts how many LLM calls had validation errors
3941
llm_calls_validation_errors_total = Counter(

src/utils/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def extract_provider_and_model_from_model_id(model_id: str) -> tuple[str, str]:
484484
model_id: The model ID to extract from.
485485
486486
Returns:
487-
tuple[str, str]: The model and provider.
487+
tuple[str, str]: The provider and model.
488488
"""
489489
split = model_id.split("/", 1)
490490
if len(split) == 2:

tests/unit/app/endpoints/test_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
4242
assert "# TYPE ls_provider_model_configuration gauge" in response_body
4343
assert "# TYPE ls_llm_calls_total counter" in response_body
4444
assert "# TYPE ls_llm_calls_failures_total counter" in response_body
45-
assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
4645
assert "# TYPE ls_llm_validation_errors_total counter" in response_body
4746
assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
4847
assert "# TYPE ls_llm_token_sent_total counter" in response_body

0 commit comments

Comments
 (0)