diff --git a/helm/templates/deployment-router.yaml b/helm/templates/deployment-router.yaml index 1648af943..3dd7d89bd 100644 --- a/helm/templates/deployment-router.yaml +++ b/helm/templates/deployment-router.yaml @@ -74,7 +74,7 @@ spec: name: {{ .Values.routerSpec.vllmApiKey.secretName }} key: {{ .Values.routerSpec.vllmApiKey.secretKey }} {{- else if and .Values.servingEngineSpec.enableEngine }} - {{- if kindIs "string" .Values.servingEngineSpec.vllmApiKey }} + {{- if or (kindIs "string" .Values.servingEngineSpec.vllmApiKey) (kindIs "slice" .Values.servingEngineSpec.vllmApiKey) }} - name: VLLM_API_KEY valueFrom: secretKeyRef: diff --git a/helm/templates/deployment-vllm-multi.yaml b/helm/templates/deployment-vllm-multi.yaml index e91fd5455..e75f1f4ef 100644 --- a/helm/templates/deployment-vllm-multi.yaml +++ b/helm/templates/deployment-vllm-multi.yaml @@ -251,7 +251,7 @@ spec: {{- end }} {{- $vllmApiKey := $.Values.servingEngineSpec.vllmApiKey }} {{- if $vllmApiKey }} - {{- if kindIs "string" $vllmApiKey }} + {{- if or (kindIs "string" $vllmApiKey) (kindIs "slice" $vllmApiKey) }} - name: VLLM_API_KEY valueFrom: secretKeyRef: diff --git a/helm/templates/secrets.yaml b/helm/templates/secrets.yaml index 458e542cb..2ce8f185c 100644 --- a/helm/templates/secrets.yaml +++ b/helm/templates/secrets.yaml @@ -7,9 +7,13 @@ metadata: type: Opaque data: {{- $vllmApiKey := $.Values.servingEngineSpec.vllmApiKey }} - {{- if and $vllmApiKey (kindIs "string" $vllmApiKey) }} + {{- if $vllmApiKey }} + {{- if kindIs "slice" $vllmApiKey }} + vllmApiKey: {{ join "," $vllmApiKey | b64enc | quote }} + {{- else if kindIs "string" $vllmApiKey }} vllmApiKey: {{ $vllmApiKey | b64enc | quote }} {{- end }} + {{- end }} {{- range $modelSpec := .Values.servingEngineSpec.modelSpec }} {{- with $ -}} diff --git a/helm/values.yaml b/helm/values.yaml index 0a18a910c..b98836659 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -13,8 +13,17 @@ servingEngineSpec: # -- Extra service ports for models extraPorts: [] - # -- API key for securing the vLLM models. Can be either a string that will be stored in a generated secret or an object referencing an existing secret. - vllmApiKey: # @schema type:[string, object] + # -- API key for securing the vLLM models. Accepts a single string, a + # -- comma-separated string, a list of strings (all stored in a generated + # -- secret joined as comma-separated), or an object referencing an existing + # -- Kubernetes secret. Multiple keys enable multi-tenant access so different + # -- teams can authenticate with independent tokens. + # -- Examples: + # -- vllmApiKey: "single-key" + # -- vllmApiKey: "key1,key2,key3" + # -- vllmApiKey: ["key1", "key2", "key3"] + # -- vllmApiKey: {secretName: my-secret, secretKey: api-key} + vllmApiKey: # @schema type:[string, array, object] # -- Name of the existing Kubernetes secret that contains the vLLM API key secretName: "" # -- Key within the secret that contains the vLLM API key diff --git a/src/tests/test_multi_api_key_auth.py b/src/tests/test_multi_api_key_auth.py new file mode 100644 index 000000000..459d15733 --- /dev/null +++ b/src/tests/test_multi_api_key_auth.py @@ -0,0 +1,130 @@ +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException + +from vllm_router.auth import _parse_api_keys, get_allowed_api_keys, verify_api_key + +# --------------------------------------------------------------------------- +# _parse_api_keys +# --------------------------------------------------------------------------- + + +def test_parse_single_key(): + assert _parse_api_keys("key1") == {"key1"} + + +def test_parse_comma_separated_keys(): + assert _parse_api_keys("key1,key2,key3") == {"key1", "key2", "key3"} + + +def test_parse_strips_whitespace(): + assert _parse_api_keys("key1, key2 , key3") == {"key1", "key2", "key3"} + + +def test_parse_ignores_empty_segments(): + assert _parse_api_keys("key1,,key2") == {"key1", "key2"} + assert _parse_api_keys(",key1,") == {"key1"} + + +def test_parse_empty_string_returns_empty_set(): + assert _parse_api_keys("") == frozenset() + + +def test_parse_whitespace_only_returns_empty_set(): + assert _parse_api_keys(" , ") == frozenset() + + +# --------------------------------------------------------------------------- +# get_allowed_api_keys — reads from environment +# --------------------------------------------------------------------------- + + +def test_get_allowed_api_keys_no_env_var(monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + assert get_allowed_api_keys() == frozenset() + + +def test_get_allowed_api_keys_single(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "secret") + assert get_allowed_api_keys() == {"secret"} + + +def test_get_allowed_api_keys_multiple(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "key1,key2,key3") + assert get_allowed_api_keys() == {"key1", "key2", "key3"} + + +def test_get_allowed_api_keys_trims_spaces(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", " key1 , key2 ") + assert get_allowed_api_keys() == {"key1", "key2"} + + +# --------------------------------------------------------------------------- +# verify_api_key dependency +# --------------------------------------------------------------------------- + + +def _make_request(auth_header: str | None = None) -> MagicMock: + request = MagicMock() + headers = {} + if auth_header is not None: + headers["Authorization"] = auth_header + request.headers = headers + return request + + +@pytest.mark.anyio +async def test_verify_no_keys_configured_allows_all(monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + request = _make_request() + await verify_api_key(request) # must not raise + + +@pytest.mark.anyio +async def test_verify_valid_single_key(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "secret") + request = _make_request("Bearer secret") + await verify_api_key(request) # must not raise + + +@pytest.mark.anyio +async def test_verify_valid_key_among_multiple(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "key1,key2,key3") + for key in ("key1", "key2", "key3"): + request = _make_request(f"Bearer {key}") + await verify_api_key(request) # must not raise + + +@pytest.mark.anyio +async def test_verify_invalid_key_raises_401(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "key1,key2") + request = _make_request("Bearer wrong-key") + with pytest.raises(HTTPException) as exc_info: + await verify_api_key(request) + assert exc_info.value.status_code == 401 + + +@pytest.mark.anyio +async def test_verify_missing_auth_header_raises_401(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "secret") + request = _make_request() # no Authorization header + with pytest.raises(HTTPException) as exc_info: + await verify_api_key(request) + assert exc_info.value.status_code == 401 + + +@pytest.mark.anyio +async def test_verify_non_bearer_scheme_raises_401(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", "secret") + request = _make_request("Basic secret") + with pytest.raises(HTTPException) as exc_info: + await verify_api_key(request) + assert exc_info.value.status_code == 401 + + +@pytest.mark.anyio +async def test_verify_extra_whitespace_in_env_key(monkeypatch): + monkeypatch.setenv("VLLM_API_KEY", " key1 , key2 ") + request = _make_request("Bearer key1") + await verify_api_key(request) # must not raise diff --git a/src/tests/test_stale_metrics.py b/src/tests/test_stale_metrics.py new file mode 100644 index 000000000..d7786863b --- /dev/null +++ b/src/tests/test_stale_metrics.py @@ -0,0 +1,126 @@ +import pytest +from prometheus_client import REGISTRY, generate_latest + +from vllm_router.routers.metrics_router import _LABEL_GAUGES, _clear_label_gauges +from vllm_router.service_discovery import EndpointInfo +from vllm_router.services.metrics_service import ( + current_qps, + gpu_prefix_cache_hit_rate, + gpu_prefix_cache_hits_total, + gpu_prefix_cache_queries_total, + healthy_pods_total, + num_requests_running, +) + +# --------------------------------------------------------------------------- +# EndpointInfo.healthy field +# --------------------------------------------------------------------------- + + +def test_endpoint_info_healthy_defaults_to_true(): + ep = EndpointInfo( + url="http://ep1:8000", + model_names=["llama"], + Id="id1", + added_timestamp=0, + model_label="default", + sleep=False, + ) + assert ep.healthy is True + + +def test_endpoint_info_healthy_can_be_set_false(): + ep = EndpointInfo( + url="http://ep1:8000", + model_names=["llama"], + Id="id1", + added_timestamp=0, + model_label="default", + sleep=False, + healthy=False, + ) + assert ep.healthy is False + + +# --------------------------------------------------------------------------- +# _clear_label_gauges behaviour +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_gauges(): + """Clear every label-based gauge before and after each test.""" + _clear_label_gauges() + yield + _clear_label_gauges() + + +def test_cleared_gauge_removes_stale_labels(): + """Stale server labels must disappear after _clear_label_gauges().""" + # Simulate two active endpoints + healthy_pods_total.labels(server="http://ep1:8000").set(1) + healthy_pods_total.labels(server="http://ep2:8000").set(1) + + output = generate_latest(REGISTRY).decode() + assert "ep1" in output + assert "ep2" in output + + # Clear all labels (simulates start of /metrics handler) + _clear_label_gauges() + + # Re-populate only ep1 (ep2 was removed from service discovery) + healthy_pods_total.labels(server="http://ep1:8000").set(1) + + output = generate_latest(REGISTRY).decode() + assert "ep1" in output + assert "ep2" not in output + + +def test_clear_removes_all_labels_when_all_endpoints_gone(): + """When every endpoint is removed, no server labels remain.""" + current_qps.labels(server="http://ep1:8000").set(42) + num_requests_running.labels(server="http://ep1:8000").set(3) + + _clear_label_gauges() + + output = generate_latest(REGISTRY).decode() + assert "ep1" not in output + + +def test_clear_does_not_affect_unlabeled_gauges(): + """System gauges (CPU, memory, disk) have no labels and are unaffected.""" + from vllm_router.routers.metrics_router import router_cpu_usage_percent + + router_cpu_usage_percent.set(42.0) + healthy_pods_total.labels(server="http://ep1:8000").set(1) + + _clear_label_gauges() + + output = generate_latest(REGISTRY).decode() + assert "router_cpu_usage_percent" in output + assert "42.0" in output + + +def test_label_gauges_list_contains_all_expected_gauges(): + """Ensure every gauge we export with a server label is in _LABEL_GAUGES.""" + expected = { + current_qps, + gpu_prefix_cache_hit_rate, + gpu_prefix_cache_hits_total, + gpu_prefix_cache_queries_total, + healthy_pods_total, + num_requests_running, + } + assert expected.issubset(set(_LABEL_GAUGES)) + + +def test_repopulate_after_clear_shows_correct_values(): + """Values set after clear must reflect in Prometheus output.""" + _clear_label_gauges() + + healthy_pods_total.labels(server="http://new-ep:8000").set(1) + current_qps.labels(server="http://new-ep:8000").set(99.5) + + output = generate_latest(REGISTRY).decode() + assert "new-ep" in output + assert "99.5" in output diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 4bb8823e2..6e3e341d4 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -17,9 +17,10 @@ import sentry_sdk import uvicorn -from fastapi import FastAPI +from fastapi import Depends, FastAPI from vllm_router.aiohttp_client import AiohttpClientWrapper +from vllm_router.auth import verify_api_key from vllm_router.dynamic_config import ( DynamicRouterConfig, get_dynamic_config_watcher, @@ -45,9 +46,7 @@ from vllm_router.services.batch_service import initialize_batch_processor from vllm_router.services.callbacks_service.callbacks import configure_custom_callbacks from vllm_router.services.files_service import initialize_storage -from vllm_router.services.request_service.rewriter import ( - get_request_rewriter, -) +from vllm_router.services.request_service.rewriter import get_request_rewriter from vllm_router.stats.engine_stats import ( get_engine_stats_scraper, initialize_engine_stats_scraper, @@ -80,9 +79,7 @@ initialize_semantic_cache, is_semantic_cache_enabled, ) - from vllm_router.experimental.semantic_cache_integration import ( - semantic_cache_size, - ) + from vllm_router.experimental.semantic_cache_integration import semantic_cache_size semantic_cache_available = True except ImportError: @@ -365,7 +362,7 @@ def initialize_all(app: FastAPI, args): app = FastAPI(lifespan=lifespan) -app.include_router(main_router) +app.include_router(main_router, dependencies=[Depends(verify_api_key)]) app.include_router(files_router) app.include_router(batches_router) app.include_router(metrics_router) diff --git a/src/vllm_router/auth.py b/src/vllm_router/auth.py new file mode 100644 index 000000000..859c41a10 --- /dev/null +++ b/src/vllm_router/auth.py @@ -0,0 +1,66 @@ +# Copyright 2024-2025 The vLLM Production Stack Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import FrozenSet + +from fastapi import HTTPException, Request + +from vllm_router.log import init_logger + +logger = init_logger(__name__) + + +def _parse_api_keys(raw: str) -> FrozenSet[str]: + """ + Parse a comma-separated API key string into a frozenset of non-empty keys. + + Leading/trailing whitespace around each key is stripped so that + ``"key1, key2 , key3"`` is treated identically to ``"key1,key2,key3"``. + """ + return frozenset(k.strip() for k in raw.split(",") if k.strip()) + + +def get_allowed_api_keys() -> FrozenSet[str]: + """ + Return the set of valid API keys sourced from the ``VLLM_API_KEY`` + environment variable. Returns an empty frozenset when the variable is + unset or empty, which disables authentication entirely. + """ + raw = os.getenv("VLLM_API_KEY", "") + return _parse_api_keys(raw) + + +async def verify_api_key(request: Request) -> None: + """ + FastAPI dependency that enforces Bearer-token authentication. + + When ``VLLM_API_KEY`` is set the incoming ``Authorization`` header must + carry one of the configured keys. Requests without a valid token receive + a 401 response. When ``VLLM_API_KEY`` is not configured this dependency + is a no-op and all requests are allowed through. + """ + allowed_keys = get_allowed_api_keys() + if not allowed_keys: + return + + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Missing or malformed Authorization header. Expected: Bearer ", + ) + + token = auth_header[len("Bearer ") :].strip() + if token not in allowed_keys: + raise HTTPException(status_code=401, detail="Invalid API key") diff --git a/src/vllm_router/routers/metrics_router.py b/src/vllm_router/routers/metrics_router.py index 276023949..48c5fd9ed 100644 --- a/src/vllm_router/routers/metrics_router.py +++ b/src/vllm_router/routers/metrics_router.py @@ -18,6 +18,7 @@ from fastapi import APIRouter, Response from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest +from vllm_router.log import init_logger from vllm_router.service_discovery import get_service_discovery from vllm_router.services.metrics_service import ( avg_decoding_length, @@ -36,6 +37,8 @@ from vllm_router.stats.engine_stats import get_engine_stats_scraper from vllm_router.stats.request_stats import get_request_stats_monitor +logger = init_logger(__name__) + metrics_router = APIRouter() # Define Gauges for system resource usage @@ -53,6 +56,37 @@ ) +# All label-based gauges that must be cleared on each scrape to prevent +# stale series when endpoints are removed from service discovery. +_LABEL_GAUGES = [ + current_qps, + avg_decoding_length, + num_prefill_requests, + num_decoding_requests, + num_requests_running, + avg_latency, + avg_itl, + num_requests_swapped, + gpu_prefix_cache_hit_rate, + gpu_prefix_cache_hits_total, + gpu_prefix_cache_queries_total, + healthy_pods_total, +] + + +def _clear_label_gauges() -> None: + """ + Clear all label-based gauges to avoid stale series for removed endpoints. + + When an endpoint is removed from service discovery, its ``server=...`` + label is never overwritten, so Prometheus keeps exporting the last known + value indefinitely. Calling ``.clear()`` on each gauge removes all + label combinations so only *currently active* endpoints are re-added. + """ + for gauge in _LABEL_GAUGES: + gauge.clear() + + # --- Prometheus Metrics Endpoint --- @metrics_router.get("/metrics") async def metrics(): @@ -73,6 +107,10 @@ async def metrics(): the appropriate content type. """ + # Clear all label-based gauges to prevent stale series for removed + # endpoints. Unlabeled system gauges (CPU, memory, disk) are unaffected. + _clear_label_gauges() + # Collect CPU utilization (short interval) cpu_percent = psutil.cpu_percent(interval=0.1) router_cpu_usage_percent.set(cpu_percent) @@ -115,9 +153,7 @@ async def metrics(): # Service discovery health status endpoints = get_service_discovery().get_endpoint_info() for ep in endpoints: - healthy_pods_total.labels(server=ep.url).set( - 1 if getattr(ep, "healthy", True) else 0 - ) + healthy_pods_total.labels(server=ep.url).set(1 if ep.healthy else 0) # Return all metrics in Prometheus format return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST) diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 7a8f5274a..bbc0f3fe8 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -114,6 +114,9 @@ class EndpointInfo: # Endpoint's sleep status sleep: bool + # Endpoint health status (from service discovery health checks) + healthy: bool = True + # Pod name pod_name: Optional[str] = None