Skip to content

Commit 2cb9840

Browse files
committed
feat: add RLS API inference metrics
Signed-off-by: Major Hayden <major@redhat.com>
1 parent ca125c4 commit 2cb9840

4 files changed

Lines changed: 110 additions & 4 deletions

File tree

src/app/endpoints/rlsapi_v1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,9 @@ def _record_inference_failure( # pylint: disable=too-many-arguments,too-many-po
455455
"""
456456
inference_time = time.monotonic() - start_time
457457
recording.record_llm_failure(provider, model, endpoint_path)
458+
recording.record_llm_inference_duration(
459+
provider, model, endpoint_path, "failure", inference_time
460+
)
458461
_queue_splunk_event(
459462
background_tasks,
460463
infer_request,
@@ -669,13 +672,14 @@ async def infer_endpoint( # pylint: disable=R0914
669672
"""
670673
# Authentication enforced by get_auth_dependency(), authorization by @authorize decorator.
671674
check_configuration_loaded(configuration)
672-
673675
# Quota enforcement: resolve subject and check availability before any work.
674676
# No-op when quota_subject is not configured or no quota limiters exist.
675677
quota_id = _resolve_quota_subject(request, auth)
676678
if quota_id is not None:
677679
check_tokens_available(configuration.quota_limiters, quota_id)
678680

681+
endpoint_path = "/v1/infer"
682+
679683
request_id = get_suid()
680684

681685
logger.info("Processing rlsapi v1 /infer request %s", request_id)
@@ -685,8 +689,6 @@ async def infer_endpoint( # pylint: disable=R0914
685689
"Request %s: Combined input source length: %d", request_id, len(input_source)
686690
)
687691

688-
endpoint_path = "/v1/infer"
689-
690692
# Run shield moderation on user input before inference.
691693
# Uses all configured shields; no-op when no shields are registered.
692694
# Runs before model/tool discovery so blocked requests short-circuit
@@ -721,6 +723,9 @@ async def infer_endpoint( # pylint: disable=R0914
721723
response_text = extract_text_from_response_items(response.output)
722724
token_usage = extract_token_usage(response.usage, model_id, endpoint_path)
723725
inference_time = time.monotonic() - start_time
726+
recording.record_llm_inference_duration(
727+
provider, model, endpoint_path, "success", inference_time
728+
)
724729
except _INFER_HANDLED_EXCEPTIONS as error:
725730
if response is not None:
726731
extract_token_usage(response.usage, model_id, endpoint_path) # type: ignore[arg-type]

src/metrics/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
"""Metrics module for Lightspeed Core Stack."""
22

3+
from typing import Final
4+
35
from prometheus_client import (
46
Counter,
57
Gauge,
68
Histogram,
79
)
810

11+
LLM_INFERENCE_DURATION_BUCKETS: Final[tuple[float, ...]] = (
12+
0.1,
13+
0.5,
14+
1.0,
15+
2.5,
16+
5.0,
17+
10.0,
18+
20.0,
19+
30.0,
20+
60.0,
21+
120.0,
22+
float("inf"),
23+
)
24+
925
# Counter to track REST API calls
1026
# This will be used to count how many times each API endpoint is called
1127
# and the status code of the response
@@ -55,3 +71,11 @@
5571
"LLM tokens received",
5672
["provider", "model", "endpoint"],
5773
)
74+
75+
# Histogram to measure the latency of direct LLM inference backend calls.
76+
llm_inference_duration_seconds = Histogram(
77+
"ls_llm_inference_duration_seconds",
78+
"LLM inference call duration",
79+
["provider", "model", "endpoint", "result"],
80+
buckets=LLM_INFERENCE_DURATION_BUCKETS,
81+
)

src/metrics/recording.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,23 @@ def record_llm_token_usage(
109109
)
110110
except (AttributeError, TypeError, ValueError):
111111
logger.warning("Failed to update token metrics", exc_info=True)
112+
113+
114+
def record_llm_inference_duration(
115+
provider: str, model: str, endpoint_path: str, result: str, duration: float
116+
) -> None:
117+
"""Record the latency of a direct LLM inference backend call.
118+
119+
Args:
120+
provider: LLM provider identifier.
121+
model: LLM model identifier without the provider prefix.
122+
endpoint_path: API endpoint path for metric labeling.
123+
result: Bounded result label, such as ``success`` or ``failure``.
124+
duration: Inference call duration in seconds.
125+
"""
126+
try:
127+
metrics.llm_inference_duration_seconds.labels(
128+
provider, model, endpoint_path, result
129+
).observe(duration)
130+
except (AttributeError, TypeError, ValueError):
131+
logger.warning("Failed to update LLM inference duration metric", exc_info=True)

tests/unit/metrics/test_recording.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
11
"""Unit tests for Prometheus metric recording helpers."""
22

3-
from pytest_mock import MockerFixture
3+
from collections.abc import Callable
4+
from dataclasses import dataclass
5+
6+
import pytest
7+
from pytest_mock import MockerFixture, MockType
48

59
from metrics import recording
610

711

12+
@dataclass(frozen=True)
13+
class HistogramRecorderCase:
14+
"""Expected behavior for a histogram-style metric recorder."""
15+
16+
metric_path: str
17+
recorder: Callable[..., None]
18+
args: tuple[object, ...]
19+
labels: tuple[object, ...]
20+
duration: float
21+
warning_message: str
22+
23+
824
def test_measure_response_duration_records_timer(mocker: MockerFixture) -> None:
925
"""Test that response duration measurement uses the path label timer."""
1026
mock_timer = mocker.MagicMock()
@@ -159,3 +175,44 @@ def test_record_llm_token_usage_logs_metric_errors(mocker: MockerFixture) -> Non
159175
mock_logger.warning.assert_called_once_with(
160176
"Failed to update token metrics", exc_info=True
161177
)
178+
179+
180+
@pytest.fixture(name="recording_logger")
181+
def recording_logger_fixture(mocker: MockerFixture) -> MockType:
182+
"""Patch the metric recording logger for failure assertions."""
183+
return mocker.patch("metrics.recording.logger")
184+
185+
186+
@pytest.mark.parametrize(
187+
"case",
188+
[
189+
HistogramRecorderCase(
190+
metric_path="metrics.recording.metrics.llm_inference_duration_seconds",
191+
recorder=recording.record_llm_inference_duration,
192+
args=("vertexai", "gemini", "/v1/responses", "success", 1.5),
193+
labels=("vertexai", "gemini", "/v1/responses", "success"),
194+
duration=1.5,
195+
warning_message="Failed to update LLM inference duration metric",
196+
),
197+
],
198+
)
199+
def test_histogram_recorders_observe_metrics_and_log_errors(
200+
mocker: MockerFixture,
201+
recording_logger: MockType,
202+
case: HistogramRecorderCase,
203+
) -> None:
204+
"""Test new histogram helpers with shared success and failure coverage."""
205+
mock_metric = mocker.patch(case.metric_path)
206+
207+
case.recorder(*case.args)
208+
209+
mock_metric.labels.assert_called_once_with(*case.labels)
210+
mock_metric.labels.return_value.observe.assert_called_once_with(case.duration)
211+
212+
mock_metric.reset_mock()
213+
mock_metric.labels.return_value.observe.side_effect = TypeError("bad")
214+
case.recorder(*case.args)
215+
216+
recording_logger.warning.assert_called_once_with(
217+
case.warning_message, exc_info=True
218+
)

0 commit comments

Comments
 (0)