Skip to content

Commit 1ab8fde

Browse files
authored
Add optional bearer auth to metrics endpoint (#2460)
1 parent ebe738a commit 1ab8fde

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

src/dstack/_internal/server/routers/prometheus.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Annotated
23

34
from fastapi import APIRouter, Depends
@@ -6,12 +7,16 @@
67

78
from dstack._internal.server import settings
89
from dstack._internal.server.db import get_session
10+
from dstack._internal.server.security.permissions import OptionalServiceAccount
911
from dstack._internal.server.services import prometheus
1012
from dstack._internal.server.utils.routers import error_not_found
1113

14+
_auth = OptionalServiceAccount(os.getenv("DSTACK_PROMETHEUS_AUTH_TOKEN"))
15+
1216
router = APIRouter(
1317
tags=["prometheus"],
1418
default_response_class=PlainTextResponse,
19+
dependencies=[Depends(_auth)],
1520
)
1621

1722

src/dstack/_internal/server/security/permissions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple
1+
from typing import Annotated, Optional, Tuple
22

33
from fastapi import Depends, HTTPException, Security
44
from fastapi.security import HTTPBearer
@@ -99,6 +99,24 @@ async def __call__(
9999
return await get_project_member(session, project_name, token.credentials)
100100

101101

102+
class OptionalServiceAccount:
103+
def __init__(self, token: Optional[str]) -> None:
104+
self._token = token
105+
106+
async def __call__(
107+
self,
108+
token: Annotated[
109+
Optional[HTTPAuthorizationCredentials], Security(HTTPBearer(auto_error=False))
110+
],
111+
) -> None:
112+
if self._token is None:
113+
return
114+
if token is None:
115+
raise error_forbidden()
116+
if token.credentials != self._token:
117+
raise error_invalid_token()
118+
119+
102120
async def get_project_member(
103121
session: AsyncSession, project_name: str, token: str
104122
) -> Tuple[UserModel, ProjectModel]:

src/tests/_internal/server/routers/test_prometheus.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
create_repo,
2929
create_run,
3030
create_user,
31+
get_auth_headers,
3132
get_instance_offer_with_availability,
3233
get_job_provisioning_data,
3334
get_job_runtime_data,
@@ -38,6 +39,7 @@
3839
@pytest.fixture
3940
def enable_metrics(monkeypatch: pytest.MonkeyPatch):
4041
monkeypatch.setattr("dstack._internal.server.settings.ENABLE_PROMETHEUS_METRICS", True)
42+
monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", None)
4143

4244

4345
FAKE_NOW = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)
@@ -289,6 +291,25 @@ async def test_returns_404_if_not_enabled(
289291
response = await client.get("/metrics")
290292
assert response.status_code == 404
291293

294+
@pytest.mark.parametrize("token", [None, "foo"])
295+
async def test_returns_403_if_not_authenticated(
296+
self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient, token: Optional[str]
297+
):
298+
monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret")
299+
if token is not None:
300+
headers = get_auth_headers(token)
301+
else:
302+
headers = None
303+
response = await client.get("/metrics", headers=headers)
304+
assert response.status_code == 403
305+
306+
async def test_returns_200_if_token_is_valid(
307+
self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient
308+
):
309+
monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret")
310+
response = await client.get("/metrics", headers=get_auth_headers("secret"))
311+
assert response.status_code == 200
312+
292313

293314
async def _create_project(session: AsyncSession, name: str, user: UserModel) -> ProjectModel:
294315
project = await create_project(session=session, owner=user, name=name)

0 commit comments

Comments
 (0)