-
-
Notifications
You must be signed in to change notification settings - Fork 234
Expand file tree
/
Copy pathtest_process_metrics.py
More file actions
188 lines (180 loc) · 6.77 KB
/
Copy pathtest_process_metrics.py
File metadata and controls
188 lines (180 loc) · 6.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from datetime import datetime, timezone
from unittest.mock import patch
import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import JobStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_metrics import (
collect_metrics,
delete_metrics,
)
from dstack._internal.server.models import JobMetricsPoint
from dstack._internal.server.schemas.runner import GPUMetrics, MetricsResponse
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
create_instance,
create_job,
create_job_metrics_point,
create_project,
create_repo,
create_run,
create_user,
get_job_provisioning_data,
)
# Apply the `image_config_mock` fixture to every test in this module.
pytestmark = pytest.mark.usefixtures("image_config_mock")
class TestCollectMetrics:
    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    async def test_collects_metrics(self, test_db, session: AsyncSession):
        """`collect_metrics` should query the runner once and persist one point for the job.

        Both the SSH tunnel and the runner client are patched out, so the
        runner's `get_metrics` response is fully controlled by this test.
        """
        owner = await create_user(session=session, global_role=GlobalRole.USER)
        project = await create_project(session=session, owner=owner)
        await add_project_member(
            session=session, project=project, user=owner, project_role=ProjectRole.USER
        )
        repo = await create_repo(session=session, project_id=project.id)
        busy_instance = await create_instance(
            session=session, project=project, status=InstanceStatus.BUSY
        )
        run = await create_run(session=session, project=project, repo=repo, user=owner)
        running_job = await create_job(
            session=session,
            run=run,
            status=JobStatus.RUNNING,
            job_provisioning_data=get_job_provisioning_data(),
            instance_assigned=True,
            instance=busy_instance,
        )

        # Canned runner response: arbitrary counters plus a single GPU that
        # reports zero memory usage and zero utilization.
        canned_metrics = MetricsResponse(
            timestamp_micro=1,
            cpu_usage_micro=2,
            memory_usage_bytes=3,
            memory_working_set_bytes=4,
            gpus=[GPUMetrics(gpu_memory_usage_bytes=0, gpu_util_percent=0)],
        )

        with patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls:
            with patch(
                "dstack._internal.server.services.runner.client.RunnerClient"
            ) as runner_client_cls:
                runner_client = runner_client_cls.return_value
                runner_client.get_metrics.return_value = canned_metrics
                await collect_metrics()
                ssh_tunnel_cls.assert_called_once()
                runner_client.get_metrics.assert_called_once()

        stored_point = (await session.execute(select(JobMetricsPoint))).scalar_one()
        assert stored_point.job_id == running_job.id
class TestDeleteMetrics:
    # Every test below freezes the clock at 2023-01-02 03:05:20 UTC and sets the
    # TTL under test to 15 s, giving a cutoff of 03:05:05. Of the points created by
    # `_create_job_with_metrics_points`, the ones at 03:04:10 and 03:04:20 are
    # older than the cutoff and the one at 03:05:10 is newer, so exactly that
    # last point is expected to survive `delete_metrics`.

    @staticmethod
    async def _create_job_with_metrics_points(session: AsyncSession, job_status: JobStatus):
        """Create a job in `job_status` with two stale metrics points and one fresh one.

        Args:
            session: DB session used for all fixture rows.
            job_status: Status given to the created job.

        Returns:
            The metrics point created last (timestamp 03:05:10 UTC), i.e. the
            only point the tests expect to remain after `delete_metrics`.
        """
        user = await create_user(session=session, global_role=GlobalRole.USER)
        project = await create_project(session=session, owner=user)
        await add_project_member(
            session=session, project=project, user=user, project_role=ProjectRole.USER
        )
        repo = await create_repo(session=session, project_id=project.id)
        run = await create_run(session=session, project=project, repo=repo, user=user)
        job = await create_job(session=session, run=run, status=job_status)
        for stale_timestamp in (
            datetime(2023, 1, 2, 3, 4, 10, tzinfo=timezone.utc),
            datetime(2023, 1, 2, 3, 4, 20, tzinfo=timezone.utc),
        ):
            await create_job_metrics_point(
                session=session, job_model=job, timestamp=stale_timestamp
            )
        return await create_job_metrics_point(
            session=session,
            job_model=job,
            timestamp=datetime(2023, 1, 2, 3, 5, 10, tzinfo=timezone.utc),
        )

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @freeze_time(datetime(2023, 1, 2, 3, 5, 20, tzinfo=timezone.utc))
    async def test_deletes_old_metrics_running_job(self, test_db, session: AsyncSession):
        """Points of a running job older than the running-job TTL are deleted."""
        last_metric = await self._create_job_with_metrics_points(session, JobStatus.RUNNING)
        # The finished-job TTL is set to a different value (0), presumably so that
        # applying the wrong TTL to a running job would change the outcome.
        with patch.multiple(
            settings, SERVER_METRICS_RUNNING_TTL_SECONDS=15, SERVER_METRICS_FINISHED_TTL_SECONDS=0
        ):
            await delete_metrics()
        res = await session.execute(select(JobMetricsPoint))
        points = res.scalars().all()
        assert len(points) == 1
        assert points[0].id == last_metric.id

    @pytest.mark.asyncio
    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
    @freeze_time(datetime(2023, 1, 2, 3, 5, 20, tzinfo=timezone.utc))
    async def test_deletes_old_metrics_finished_job(self, test_db, session: AsyncSession):
        """Points of a finished (FAILED) job older than the finished-job TTL are deleted."""
        last_metric = await self._create_job_with_metrics_points(session, JobStatus.FAILED)
        # The running-job TTL is set to a different value (0), presumably so that
        # applying the wrong TTL to a finished job would change the outcome.
        with patch.multiple(
            settings, SERVER_METRICS_RUNNING_TTL_SECONDS=0, SERVER_METRICS_FINISHED_TTL_SECONDS=15
        ):
            await delete_metrics()
        res = await session.execute(select(JobMetricsPoint))
        points = res.scalars().all()
        assert len(points) == 1
        assert points[0].id == last_metric.id