Skip to content

Commit 40ee802

Browse files
authored
Health metrics (Part 2) (#2796)
1 parent e47c927 commit 40ee802

File tree

9 files changed

+192
-12
lines changed

9 files changed

+192
-12
lines changed

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
group_jobs_by_replica_latest,
2828
)
2929
from dstack._internal.server.services.locking import get_locker
30+
from dstack._internal.server.services.prometheus.client_metrics import run_metrics
3031
from dstack._internal.server.services.runs import (
3132
fmt,
3233
process_terminating_run,
@@ -329,6 +330,24 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
329330
run_model.status.name,
330331
new_status.name,
331332
)
333+
if run_model.status == RunStatus.SUBMITTED and new_status == RunStatus.PROVISIONING:
334+
current_time = common.get_current_datetime()
335+
submit_to_provision_duration = (
336+
current_time - run_model.submitted_at.replace(tzinfo=datetime.timezone.utc)
337+
).total_seconds()
338+
logger.info(
339+
"%s: run took %.2f seconds from submision to provisioning.",
340+
fmt(run_model),
341+
submit_to_provision_duration,
342+
)
343+
project_name = run_model.project.name
344+
run_metrics.log_submit_to_provision_duration(
345+
submit_to_provision_duration, project_name, run_spec.configuration.type
346+
)
347+
348+
if new_status == RunStatus.PENDING:
349+
run_metrics.increment_pending_runs(run_model.project.name, run_spec.configuration.type)
350+
332351
run_model.status = new_status
333352
run_model.termination_reason = termination_reason
334353
# While a run goes to pending without provisioning, resubmission_attempt increases.

src/dstack/_internal/server/routers/prometheus.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import os
22
from typing import Annotated
33

4+
import prometheus_client
45
from fastapi import APIRouter, Depends
56
from fastapi.responses import PlainTextResponse
6-
from prometheus_client import generate_latest
77
from sqlalchemy.ext.asyncio import AsyncSession
88

99
from dstack._internal.server import settings
1010
from dstack._internal.server.db import get_session
1111
from dstack._internal.server.security.permissions import OptionalServiceAccount
12-
from dstack._internal.server.services import prometheus
12+
from dstack._internal.server.services.prometheus import custom_metrics
1313
from dstack._internal.server.utils.routers import error_not_found
1414

1515
_auth = OptionalServiceAccount(os.getenv("DSTACK_PROMETHEUS_AUTH_TOKEN"))
@@ -27,6 +27,6 @@ async def get_prometheus_metrics(
2727
) -> str:
2828
if not settings.ENABLE_PROMETHEUS_METRICS:
2929
raise error_not_found()
30-
custom_metrics = await prometheus.get_metrics(session=session)
31-
prometheus_metrics = generate_latest()
32-
return custom_metrics + prometheus_metrics.decode()
30+
custom_metrics_ = await custom_metrics.get_metrics(session=session)
31+
client_metrics = prometheus_client.generate_latest().decode()
32+
return custom_metrics_ + client_metrics

src/dstack/_internal/server/services/prometheus/__init__.py

Whitespace-only changes.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from prometheus_client import Counter, Histogram
2+
3+
4+
class RunMetrics:
    """Wrapper class for run-related Prometheus metrics.

    Holds two collectors, both labelled by ``project_name`` and ``run_type``:

    * a histogram of the time between run submission and provisioning;
    * a counter that is bumped every time ``increment_pending_runs`` is called.
    """

    # Histogram upper bounds in seconds. The steps are dense below ten minutes
    # and coarser afterwards, so percentile estimates stay precise for typical
    # provisioning times. The trailing +Inf bucket is listed explicitly.
    _SUBMIT_TO_PROVISION_BUCKETS = (
        15,
        30,
        45,
        60,
        90,
        120,
        180,
        240,
        300,
        360,
        420,
        480,
        540,
        600,
        900,
        1200,
        1800,
        float("inf"),
    )

    # Label names shared by every collector owned by this class.
    _LABEL_NAMES = ("project_name", "run_type")

    def __init__(self) -> None:
        """Create the collectors (prometheus_client registers them on construction)."""
        self._submit_to_provision_duration = Histogram(
            "dstack_submit_to_provision_duration_seconds",
            "Time from run submission to first job provisioning, in seconds",
            buckets=self._SUBMIT_TO_PROVISION_BUCKETS,
            labelnames=self._LABEL_NAMES,
        )

        # A counter only ever grows, so the help text describes the number of
        # transitions into the pending state, not the number of runs that are
        # pending right now.
        self._pending_runs_total = Counter(
            "dstack_pending_runs_total",
            "Number of times a run has entered the pending state",
            labelnames=self._LABEL_NAMES,
        )

    def log_submit_to_provision_duration(
        self, duration_seconds: float, project_name: str, run_type: str
    ) -> None:
        """Record one submit-to-provision duration.

        Args:
            duration_seconds: Elapsed time to observe, in seconds.
            project_name: Value for the ``project_name`` label.
            run_type: Value for the ``run_type`` label.
        """
        self._submit_to_provision_duration.labels(
            project_name=project_name, run_type=run_type
        ).observe(duration_seconds)

    def increment_pending_runs(self, project_name: str, run_type: str) -> None:
        """Increment the pending-runs counter by one for the given labels.

        Args:
            project_name: Value for the ``project_name`` label.
            run_type: Value for the ``run_type`` label.
        """
        self._pending_runs_total.labels(project_name=project_name, run_type=run_type).inc()
50+
51+
52+
# Module-level RunMetrics instance; other modules import this name directly.
run_metrics = RunMetrics()

src/dstack/_internal/server/services/prometheus.py renamed to src/dstack/_internal/server/services/prometheus/custom_metrics.py

File renamed without changes.

src/tests/_internal/server/background/tasks/test_process_runs.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import patch
44

55
import pytest
6+
from freezegun import freeze_time
67
from pydantic import parse_obj_as
78
from sqlalchemy.ext.asyncio import AsyncSession
89

@@ -30,6 +31,7 @@
3031
get_job_provisioning_data,
3132
get_run_spec,
3233
)
34+
from dstack._internal.utils import common
3335

3436
pytestmark = pytest.mark.usefixtures("image_config_mock")
3537

@@ -80,10 +82,28 @@ async def make_run(
8082
class TestProcessRuns:
8183
@pytest.mark.asyncio
8284
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
85+
@freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc))
8386
async def test_submitted_to_provisioning(self, test_db, session: AsyncSession):
8487
run = await make_run(session, status=RunStatus.SUBMITTED)
8588
await create_job(session=session, run=run, status=JobStatus.PROVISIONING)
86-
await process_runs.process_runs()
89+
current_time = common.get_current_datetime()
90+
91+
expected_duration = (
92+
current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc)
93+
).total_seconds()
94+
95+
with patch(
96+
"dstack._internal.server.background.tasks.process_runs.run_metrics"
97+
) as mock_run_metrics:
98+
await process_runs.process_runs()
99+
100+
mock_run_metrics.log_submit_to_provision_duration.assert_called_once()
101+
args = mock_run_metrics.log_submit_to_provision_duration.call_args[0]
102+
assert args[1] == run.project.name
103+
assert args[2] == "service"
104+
# The clock is frozen by freeze_time, so the duration must equal the expected value exactly
105+
assert args[0] == expected_duration
106+
87107
await session.refresh(run)
88108
assert run.status == RunStatus.PROVISIONING
89109

@@ -103,7 +123,14 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession):
103123
run = await make_run(session, status=RunStatus.PROVISIONING)
104124
await create_job(session=session, run=run, status=JobStatus.PULLING)
105125

106-
await process_runs.process_runs()
126+
with patch(
127+
"dstack._internal.server.background.tasks.process_runs.run_metrics"
128+
) as mock_run_metrics:
129+
await process_runs.process_runs()
130+
131+
mock_run_metrics.log_submit_to_provision_duration.assert_not_called()
132+
mock_run_metrics.increment_pending_runs.assert_not_called()
133+
107134
await session.refresh(run)
108135
assert run.status == RunStatus.PROVISIONING
109136

@@ -161,9 +188,19 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession):
161188
instance=instance,
162189
job_provisioning_data=get_job_provisioning_data(),
163190
)
164-
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
191+
with (
192+
patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock,
193+
patch(
194+
"dstack._internal.server.background.tasks.process_runs.run_metrics"
195+
) as mock_run_metrics,
196+
):
165197
datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3)
166198
await process_runs.process_runs()
199+
200+
mock_run_metrics.increment_pending_runs.assert_called_once_with(
201+
run.project.name, "service"
202+
)
203+
167204
await session.refresh(run)
168205
assert run.status == RunStatus.PENDING
169206

@@ -205,12 +242,29 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession):
205242
class TestProcessRunsReplicas:
206243
@pytest.mark.asyncio
207244
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
245+
@freeze_time(datetime.datetime(2023, 1, 2, 3, 5, 20, tzinfo=datetime.timezone.utc))
208246
async def test_submitted_to_provisioning_if_any(self, test_db, session: AsyncSession):
209247
run = await make_run(session, status=RunStatus.SUBMITTED, replicas=2)
210248
await create_job(session=session, run=run, status=JobStatus.SUBMITTED, replica_num=0)
211249
await create_job(session=session, run=run, status=JobStatus.PROVISIONING, replica_num=1)
250+
current_time = common.get_current_datetime()
251+
252+
expected_duration = (
253+
current_time - run.submitted_at.replace(tzinfo=datetime.timezone.utc)
254+
).total_seconds()
255+
256+
with patch(
257+
"dstack._internal.server.background.tasks.process_runs.run_metrics"
258+
) as mock_run_metrics:
259+
await process_runs.process_runs()
260+
261+
mock_run_metrics.log_submit_to_provision_duration.assert_called_once()
262+
args = mock_run_metrics.log_submit_to_provision_duration.call_args[0]
263+
assert args[1] == run.project.name
264+
assert args[2] == "service"
265+
assert isinstance(args[0], float)
266+
assert args[0] == expected_duration
212267

213-
await process_runs.process_runs()
214268
await session.refresh(run)
215269
assert run.status == RunStatus.PROVISIONING
216270

@@ -251,9 +305,19 @@ async def test_all_no_capacity_to_pending(self, test_db, session: AsyncSession):
251305
instance=await create_instance(session, project=run.project, spot=True),
252306
job_provisioning_data=get_job_provisioning_data(),
253307
)
254-
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
308+
with (
309+
patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock,
310+
patch(
311+
"dstack._internal.server.background.tasks.process_runs.run_metrics"
312+
) as mock_run_metrics,
313+
):
255314
datetime_mock.return_value = run.submitted_at + datetime.timedelta(minutes=3)
256315
await process_runs.process_runs()
316+
317+
mock_run_metrics.increment_pending_runs.assert_called_once_with(
318+
run.project.name, "service"
319+
)
320+
257321
await session.refresh(run)
258322
assert run.status == RunStatus.PENDING
259323

src/tests/_internal/server/routers/test_prometheus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def enable_metrics(monkeypatch: pytest.MonkeyPatch):
100100
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
101101
@pytest.mark.usefixtures("image_config_mock", "test_db", "enable_metrics")
102102
class TestGetPrometheusMetrics:
103-
@patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS)
103+
@patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS)
104104
async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient):
105105
user = await create_user(session=session, name="test-user", global_role=GlobalRole.USER)
106106
offer = get_instance_offer_with_availability(
@@ -335,7 +335,7 @@ async def test_returns_metrics(self, session: AsyncSession, client: AsyncClient)
335335
)
336336
assert response.text.strip() == expected
337337

338-
@patch("dstack._internal.server.routers.prometheus.generate_latest", lambda: BASE_HTTP_METRICS)
338+
@patch("prometheus_client.generate_latest", lambda: BASE_HTTP_METRICS)
339339
async def test_returns_empty_response_if_no_runs(self, client: AsyncClient):
340340
response = await client.get("/metrics")
341341
assert response.status_code == 200

src/tests/_internal/server/services/prometheus/__init__.py

Whitespace-only changes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from unittest.mock import MagicMock
2+
3+
from dstack._internal.server.services.prometheus.client_metrics import run_metrics
4+
5+
6+
class TestRunMetrics:
    """Unit tests for the shared ``run_metrics`` wrapper object."""

    def test_log_submit_to_provision_duration(self, monkeypatch):
        """The duration is observed on the child selected by the two labels."""
        histogram_stub = MagicMock()
        # MagicMock creates ``labels.return_value`` lazily; grab it as the child.
        labelled_child = histogram_stub.labels.return_value
        monkeypatch.setattr(run_metrics, "_submit_to_provision_duration", histogram_stub)

        run_metrics.log_submit_to_provision_duration(120.5, "test-project", "dev")

        histogram_stub.labels.assert_called_once_with(
            project_name="test-project", run_type="dev"
        )
        labelled_child.observe.assert_called_once_with(120.5)

    def test_increment_pending_runs(self, monkeypatch):
        """The counter child selected by the two labels is incremented once."""
        counter_stub = MagicMock()
        labelled_child = counter_stub.labels.return_value
        monkeypatch.setattr(run_metrics, "_pending_runs_total", counter_stub)

        run_metrics.increment_pending_runs("test-project", "train")

        counter_stub.labels.assert_called_once_with(
            project_name="test-project", run_type="train"
        )
        labelled_child.inc.assert_called_once()

    def test_multiple_calls_to_log_submit_to_provision_duration(self):
        """Smoke test: repeated observations with varying labels do not raise."""
        samples = (
            (60.0, "project1", "dev"),
            (120.0, "project1", "prod"),
            (30.0, "project2", "dev"),
        )
        for seconds, project, kind in samples:
            run_metrics.log_submit_to_provision_duration(seconds, project, kind)

    def test_multiple_calls_to_increment_pending_runs(self):
        """Smoke test: repeated increments, including a repeated label pair, do not raise."""
        label_pairs = (
            ("project1", "dev"),
            ("project1", "prod"),
            ("project2", "dev"),
            ("project1", "dev"),
        )
        for project, kind in label_pairs:
            run_metrics.increment_pending_runs(project, kind)

0 commit comments

Comments
 (0)