Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion helm/templates/deployment-router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ spec:
name: {{ .Values.routerSpec.vllmApiKey.secretName }}
key: {{ .Values.routerSpec.vllmApiKey.secretKey }}
{{- else if and .Values.servingEngineSpec.enableEngine }}
{{- if kindIs "string" .Values.servingEngineSpec.vllmApiKey }}
{{- if or (kindIs "string" .Values.servingEngineSpec.vllmApiKey) (kindIs "slice" .Values.servingEngineSpec.vllmApiKey) }}
- name: VLLM_API_KEY
valueFrom:
secretKeyRef:
Expand Down
2 changes: 1 addition & 1 deletion helm/templates/deployment-vllm-multi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ spec:
{{- end }}
{{- $vllmApiKey := $.Values.servingEngineSpec.vllmApiKey }}
{{- if $vllmApiKey }}
{{- if kindIs "string" $vllmApiKey }}
{{- if or (kindIs "string" $vllmApiKey) (kindIs "slice" $vllmApiKey) }}
- name: VLLM_API_KEY
valueFrom:
secretKeyRef:
Expand Down
6 changes: 5 additions & 1 deletion helm/templates/secrets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ metadata:
type: Opaque
data:
{{- $vllmApiKey := $.Values.servingEngineSpec.vllmApiKey }}
{{- if and $vllmApiKey (kindIs "string" $vllmApiKey) }}
{{- if $vllmApiKey }}
{{- if kindIs "slice" $vllmApiKey }}
vllmApiKey: {{ join "," $vllmApiKey | b64enc | quote }}
{{- else if kindIs "string" $vllmApiKey }}
vllmApiKey: {{ $vllmApiKey | b64enc | quote }}
{{- end }}
{{- end }}

{{- range $modelSpec := .Values.servingEngineSpec.modelSpec }}
{{- with $ -}}
Expand Down
13 changes: 11 additions & 2 deletions helm/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@ servingEngineSpec:
# -- Extra service ports for models
extraPorts: []

# -- API key for securing the vLLM models. Can be either a string that will be stored in a generated secret or an object referencing an existing secret.
vllmApiKey: # @schema type:[string, object]
# -- API key for securing the vLLM models. Accepts a single string, a
# comma-separated string, or a list of strings (each stored in a generated
# secret; a list is joined with commas into one value), or an object
# referencing an existing Kubernetes secret. Multiple keys enable
# multi-tenant access so different teams can authenticate with independent
# tokens.
# Examples:
#   vllmApiKey: "single-key"
#   vllmApiKey: "key1,key2,key3"
#   vllmApiKey: ["key1", "key2", "key3"]
#   vllmApiKey: {secretName: my-secret, secretKey: api-key}
vllmApiKey: # @schema type:[string, array, object]
# -- Name of the existing Kubernetes secret that contains the vLLM API key
secretName: ""
# -- Key within the secret that contains the vLLM API key
Expand Down
130 changes: 130 additions & 0 deletions src/tests/test_multi_api_key_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException

from vllm_router.auth import _parse_api_keys, get_allowed_api_keys, verify_api_key

# ---------------------------------------------------------------------------
# _parse_api_keys
# ---------------------------------------------------------------------------


def test_parse_single_key():
    """A lone key parses to a one-element collection holding that key."""
    parsed = _parse_api_keys("key1")
    assert parsed == {"key1"}


def test_parse_comma_separated_keys():
    """Comma-separated input yields one entry per key."""
    expected = {"key1", "key2", "key3"}
    assert _parse_api_keys("key1,key2,key3") == expected


def test_parse_strips_whitespace():
    """Whitespace surrounding each key is not part of the parsed key."""
    expected = {"key1", "key2", "key3"}
    assert _parse_api_keys("key1, key2 , key3") == expected


def test_parse_ignores_empty_segments():
    """Empty segments (doubled, leading or trailing commas) produce no key."""
    cases = (
        ("key1,,key2", {"key1", "key2"}),
        (",key1,", {"key1"}),
    )
    for raw, expected in cases:
        assert _parse_api_keys(raw) == expected


def test_parse_empty_string_returns_empty_set():
    """An empty string parses to an empty collection."""
    parsed = _parse_api_keys("")
    assert parsed == frozenset()


def test_parse_whitespace_only_returns_empty_set():
    """Input made only of blanks and a comma parses to an empty collection."""
    parsed = _parse_api_keys(" , ")
    assert parsed == frozenset()


# ---------------------------------------------------------------------------
# get_allowed_api_keys — reads from environment
# ---------------------------------------------------------------------------


def test_get_allowed_api_keys_no_env_var(monkeypatch):
    """With VLLM_API_KEY unset, no keys are reported as allowed."""
    # raising=False: the variable may already be absent in the test env.
    monkeypatch.delenv("VLLM_API_KEY", raising=False)
    allowed = get_allowed_api_keys()
    assert allowed == frozenset()


def test_get_allowed_api_keys_single(monkeypatch):
    """A single value in VLLM_API_KEY is reported as the only allowed key."""
    monkeypatch.setenv("VLLM_API_KEY", "secret")
    allowed = get_allowed_api_keys()
    assert allowed == {"secret"}


def test_get_allowed_api_keys_multiple(monkeypatch):
    """A comma-separated VLLM_API_KEY is reported as several allowed keys."""
    monkeypatch.setenv("VLLM_API_KEY", "key1,key2,key3")
    expected = {"key1", "key2", "key3"}
    assert get_allowed_api_keys() == expected


def test_get_allowed_api_keys_trims_spaces(monkeypatch):
    """Blanks around keys in VLLM_API_KEY are not part of the allowed keys."""
    monkeypatch.setenv("VLLM_API_KEY", " key1 , key2 ")
    expected = {"key1", "key2"}
    assert get_allowed_api_keys() == expected


# ---------------------------------------------------------------------------
# verify_api_key dependency
# ---------------------------------------------------------------------------


def _make_request(auth_header: str | None = None) -> MagicMock:
request = MagicMock()
headers = {}
if auth_header is not None:
headers["Authorization"] = auth_header
request.headers = headers
return request


@pytest.mark.anyio
async def test_verify_no_keys_configured_allows_all(monkeypatch):
    """With VLLM_API_KEY unset, a request without credentials is accepted."""
    monkeypatch.delenv("VLLM_API_KEY", raising=False)
    unauthenticated = _make_request()
    # The test passes simply by this call returning without raising.
    await verify_api_key(unauthenticated)


@pytest.mark.anyio
async def test_verify_valid_single_key(monkeypatch):
    """A Bearer token equal to the single configured key is accepted."""
    monkeypatch.setenv("VLLM_API_KEY", "secret")
    authorised = _make_request("Bearer secret")
    # The test passes simply by this call returning without raising.
    await verify_api_key(authorised)


@pytest.mark.anyio
async def test_verify_valid_key_among_multiple(monkeypatch):
    """Each of several configured keys is accepted as a Bearer token."""
    monkeypatch.setenv("VLLM_API_KEY", "key1,key2,key3")
    for token in ("key1", "key2", "key3"):
        # Any raise here fails the test for that token.
        await verify_api_key(_make_request(f"Bearer {token}"))


@pytest.mark.anyio
async def test_verify_invalid_key_raises_401(monkeypatch):
    """A Bearer token that is not configured is rejected with HTTP 401."""
    monkeypatch.setenv("VLLM_API_KEY", "key1,key2")
    bad_request = _make_request("Bearer wrong-key")
    with pytest.raises(HTTPException) as caught:
        await verify_api_key(bad_request)
    # Checked after the ``with`` block so the assertion actually executes.
    assert caught.value.status_code == 401


@pytest.mark.anyio
async def test_verify_missing_auth_header_raises_401(monkeypatch):
    """With a key configured, a request lacking Authorization gets HTTP 401."""
    monkeypatch.setenv("VLLM_API_KEY", "secret")
    headerless = _make_request()  # headers dict is empty
    with pytest.raises(HTTPException) as caught:
        await verify_api_key(headerless)
    # Checked after the ``with`` block so the assertion actually executes.
    assert caught.value.status_code == 401


@pytest.mark.anyio
async def test_verify_non_bearer_scheme_raises_401(monkeypatch):
    """A correct key sent under the ``Basic`` scheme is rejected with 401."""
    monkeypatch.setenv("VLLM_API_KEY", "secret")
    basic_request = _make_request("Basic secret")
    with pytest.raises(HTTPException) as caught:
        await verify_api_key(basic_request)
    # Checked after the ``with`` block so the assertion actually executes.
    assert caught.value.status_code == 401


@pytest.mark.anyio
async def test_verify_extra_whitespace_in_env_key(monkeypatch):
    """A key padded with blanks in VLLM_API_KEY still matches its bare form."""
    monkeypatch.setenv("VLLM_API_KEY", " key1 , key2 ")
    authorised = _make_request("Bearer key1")
    # The test passes simply by this call returning without raising.
    await verify_api_key(authorised)
126 changes: 126 additions & 0 deletions src/tests/test_stale_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import pytest
from prometheus_client import REGISTRY, generate_latest

from vllm_router.routers.metrics_router import _LABEL_GAUGES, _clear_label_gauges
from vllm_router.service_discovery import EndpointInfo
from vllm_router.services.metrics_service import (
current_qps,
gpu_prefix_cache_hit_rate,
gpu_prefix_cache_hits_total,
gpu_prefix_cache_queries_total,
healthy_pods_total,
num_requests_running,
)

# ---------------------------------------------------------------------------
# EndpointInfo.healthy field
# ---------------------------------------------------------------------------


def test_endpoint_info_healthy_defaults_to_true():
    """``healthy`` is True when the constructor is given no value for it."""
    fields = dict(
        url="http://ep1:8000",
        model_names=["llama"],
        Id="id1",
        added_timestamp=0,
        model_label="default",
        sleep=False,
    )
    endpoint = EndpointInfo(**fields)
    assert endpoint.healthy is True


def test_endpoint_info_healthy_can_be_set_false():
    """An explicit ``healthy=False`` is kept on the constructed instance."""
    fields = dict(
        url="http://ep1:8000",
        model_names=["llama"],
        Id="id1",
        added_timestamp=0,
        model_label="default",
        sleep=False,
        healthy=False,
    )
    endpoint = EndpointInfo(**fields)
    assert endpoint.healthy is False


# ---------------------------------------------------------------------------
# _clear_label_gauges behaviour
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _reset_gauges():
    """Run ``_clear_label_gauges()`` on both sides of every test.

    Being ``autouse``, this wraps each test in the module so labelled series
    set by one test are not visible to the next.
    """
    _clear_label_gauges()  # setup: start from a cleared state
    yield
    _clear_label_gauges()  # teardown: drop whatever the test registered


def test_cleared_gauge_removes_stale_labels():
    """A server label that is not re-set after a clear is no longer exported."""
    ep1, ep2 = "http://ep1:8000", "http://ep2:8000"

    # Two endpoints are active at first.
    for server in (ep1, ep2):
        healthy_pods_total.labels(server=server).set(1)

    before = generate_latest(REGISTRY).decode()
    assert "ep1" in before
    assert "ep2" in before

    # Clear everything, then report only ep1 again (ep2 has gone away).
    _clear_label_gauges()
    healthy_pods_total.labels(server=ep1).set(1)

    after = generate_latest(REGISTRY).decode()
    assert "ep1" in after
    assert "ep2" not in after


def test_clear_removes_all_labels_when_all_endpoints_gone():
    """After a clear with nothing re-set, the old server label is not exported."""
    server = "http://ep1:8000"
    current_qps.labels(server=server).set(42)
    num_requests_running.labels(server=server).set(3)

    _clear_label_gauges()

    exported = generate_latest(REGISTRY).decode()
    assert "ep1" not in exported


def test_clear_does_not_affect_unlabeled_gauges():
    """An unlabeled gauge keeps its value across ``_clear_label_gauges()``.

    The value is checked on the gauge's own sample line rather than with a
    bare ``"42.0" in output``, which any other exported metric that happens
    to equal 42.0 would also satisfy.
    """
    from vllm_router.routers.metrics_router import router_cpu_usage_percent

    router_cpu_usage_percent.set(42.0)
    healthy_pods_total.labels(server="http://ep1:8000").set(1)

    _clear_label_gauges()

    output = generate_latest(REGISTRY).decode()
    assert "router_cpu_usage_percent" in output

    # Lines starting with "#" are HELP/TYPE metadata; the remaining lines
    # mentioning the gauge are its samples, formatted as "<name> <value>".
    cpu_samples = [
        line
        for line in output.splitlines()
        if "router_cpu_usage_percent" in line and not line.startswith("#")
    ]
    assert cpu_samples, "router_cpu_usage_percent has no sample line"
    assert any(line.endswith(" 42.0") for line in cpu_samples)


def test_label_gauges_list_contains_all_expected_gauges():
    """Every gauge imported by this module is registered in ``_LABEL_GAUGES``."""
    registered = set(_LABEL_GAUGES)
    expected = {
        current_qps,
        gpu_prefix_cache_hit_rate,
        gpu_prefix_cache_hits_total,
        gpu_prefix_cache_queries_total,
        healthy_pods_total,
        num_requests_running,
    }
    # ``<=`` on sets is the subset test.
    assert expected <= registered


def test_repopulate_after_clear_shows_correct_values():
    """Series set after a clear appear in the exported output."""
    _clear_label_gauges()

    server = "http://new-ep:8000"
    healthy_pods_total.labels(server=server).set(1)
    current_qps.labels(server=server).set(99.5)

    exported = generate_latest(REGISTRY).decode()
    assert "new-ep" in exported
    assert "99.5" in exported
13 changes: 5 additions & 8 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import sentry_sdk
import uvicorn
from fastapi import FastAPI
from fastapi import Depends, FastAPI

from vllm_router.aiohttp_client import AiohttpClientWrapper
from vllm_router.auth import verify_api_key
from vllm_router.dynamic_config import (
DynamicRouterConfig,
get_dynamic_config_watcher,
Expand All @@ -45,9 +46,7 @@
from vllm_router.services.batch_service import initialize_batch_processor
from vllm_router.services.callbacks_service.callbacks import configure_custom_callbacks
from vllm_router.services.files_service import initialize_storage
from vllm_router.services.request_service.rewriter import (
get_request_rewriter,
)
from vllm_router.services.request_service.rewriter import get_request_rewriter
from vllm_router.stats.engine_stats import (
get_engine_stats_scraper,
initialize_engine_stats_scraper,
Expand Down Expand Up @@ -80,9 +79,7 @@
initialize_semantic_cache,
is_semantic_cache_enabled,
)
from vllm_router.experimental.semantic_cache_integration import (
semantic_cache_size,
)
from vllm_router.experimental.semantic_cache_integration import semantic_cache_size

semantic_cache_available = True
except ImportError:
Expand Down Expand Up @@ -365,7 +362,7 @@ def initialize_all(app: FastAPI, args):


app = FastAPI(lifespan=lifespan)
app.include_router(main_router)
app.include_router(main_router, dependencies=[Depends(verify_api_key)])
app.include_router(files_router)
app.include_router(batches_router)
app.include_router(metrics_router)
Expand Down
Loading
Loading