33from unittest .mock import patch
44
55import pytest
6+ from freezegun import freeze_time
67from pydantic import parse_obj_as
78from sqlalchemy .ext .asyncio import AsyncSession
89
3031 get_job_provisioning_data ,
3132 get_run_spec ,
3233)
34+ from dstack ._internal .utils import common
3335
3436pytestmark = pytest .mark .usefixtures ("image_config_mock" )
3537
@@ -80,10 +82,28 @@ async def make_run(
8082class TestProcessRuns :
8183 @pytest .mark .asyncio
8284 @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
85+ @freeze_time (datetime .datetime (2023 , 1 , 2 , 3 , 5 , 20 , tzinfo = datetime .timezone .utc ))
8386 async def test_submitted_to_provisioning (self , test_db , session : AsyncSession ):
8487 run = await make_run (session , status = RunStatus .SUBMITTED )
8588 await create_job (session = session , run = run , status = JobStatus .PROVISIONING )
86- await process_runs .process_runs ()
89+ current_time = common .get_current_datetime ()
90+
91+ expected_duration = (
92+ current_time - run .submitted_at .replace (tzinfo = datetime .timezone .utc )
93+ ).total_seconds ()
94+
95+ with patch (
96+ "dstack._internal.server.background.tasks.process_runs.run_metrics"
97+ ) as mock_run_metrics :
98+ await process_runs .process_runs ()
99+
100+ mock_run_metrics .log_submit_to_provision_duration .assert_called_once ()
101+ args = mock_run_metrics .log_submit_to_provision_duration .call_args [0 ]
102+ assert args [1 ] == run .project .name
103+ assert args [2 ] == "service"
104+ # Assert the duration is close to our expected duration (within 0.05 second tolerance)
105+ assert args [0 ] == expected_duration
106+
87107 await session .refresh (run )
88108 assert run .status == RunStatus .PROVISIONING
89109
@@ -103,7 +123,14 @@ async def test_keep_provisioning(self, test_db, session: AsyncSession):
103123 run = await make_run (session , status = RunStatus .PROVISIONING )
104124 await create_job (session = session , run = run , status = JobStatus .PULLING )
105125
106- await process_runs .process_runs ()
126+ with patch (
127+ "dstack._internal.server.background.tasks.process_runs.run_metrics"
128+ ) as mock_run_metrics :
129+ await process_runs .process_runs ()
130+
131+ mock_run_metrics .log_submit_to_provision_duration .assert_not_called ()
132+ mock_run_metrics .increment_pending_runs .assert_not_called ()
133+
107134 await session .refresh (run )
108135 assert run .status == RunStatus .PROVISIONING
109136
@@ -161,9 +188,19 @@ async def test_retry_running_to_pending(self, test_db, session: AsyncSession):
161188 instance = instance ,
162189 job_provisioning_data = get_job_provisioning_data (),
163190 )
164- with patch ("dstack._internal.utils.common.get_current_datetime" ) as datetime_mock :
191+ with (
192+ patch ("dstack._internal.utils.common.get_current_datetime" ) as datetime_mock ,
193+ patch (
194+ "dstack._internal.server.background.tasks.process_runs.run_metrics"
195+ ) as mock_run_metrics ,
196+ ):
165197 datetime_mock .return_value = run .submitted_at + datetime .timedelta (minutes = 3 )
166198 await process_runs .process_runs ()
199+
200+ mock_run_metrics .increment_pending_runs .assert_called_once_with (
201+ run .project .name , "service"
202+ )
203+
167204 await session .refresh (run )
168205 assert run .status == RunStatus .PENDING
169206
@@ -205,12 +242,29 @@ async def test_pending_to_submitted(self, test_db, session: AsyncSession):
205242class TestProcessRunsReplicas :
206243 @pytest .mark .asyncio
207244 @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
245+ @freeze_time (datetime .datetime (2023 , 1 , 2 , 3 , 5 , 20 , tzinfo = datetime .timezone .utc ))
208246 async def test_submitted_to_provisioning_if_any (self , test_db , session : AsyncSession ):
209247 run = await make_run (session , status = RunStatus .SUBMITTED , replicas = 2 )
210248 await create_job (session = session , run = run , status = JobStatus .SUBMITTED , replica_num = 0 )
211249 await create_job (session = session , run = run , status = JobStatus .PROVISIONING , replica_num = 1 )
250+ current_time = common .get_current_datetime ()
251+
252+ expected_duration = (
253+ current_time - run .submitted_at .replace (tzinfo = datetime .timezone .utc )
254+ ).total_seconds ()
255+
256+ with patch (
257+ "dstack._internal.server.background.tasks.process_runs.run_metrics"
258+ ) as mock_run_metrics :
259+ await process_runs .process_runs ()
260+
261+ mock_run_metrics .log_submit_to_provision_duration .assert_called_once ()
262+ args = mock_run_metrics .log_submit_to_provision_duration .call_args [0 ]
263+ assert args [1 ] == run .project .name
264+ assert args [2 ] == "service"
265+ assert isinstance (args [0 ], float )
266+ assert args [0 ] == expected_duration
212267
213- await process_runs .process_runs ()
214268 await session .refresh (run )
215269 assert run .status == RunStatus .PROVISIONING
216270
@@ -251,9 +305,19 @@ async def test_all_no_capacity_to_pending(self, test_db, session: AsyncSession):
251305 instance = await create_instance (session , project = run .project , spot = True ),
252306 job_provisioning_data = get_job_provisioning_data (),
253307 )
254- with patch ("dstack._internal.utils.common.get_current_datetime" ) as datetime_mock :
308+ with (
309+ patch ("dstack._internal.utils.common.get_current_datetime" ) as datetime_mock ,
310+ patch (
311+ "dstack._internal.server.background.tasks.process_runs.run_metrics"
312+ ) as mock_run_metrics ,
313+ ):
255314 datetime_mock .return_value = run .submitted_at + datetime .timedelta (minutes = 3 )
256315 await process_runs .process_runs ()
316+
317+ mock_run_metrics .increment_pending_runs .assert_called_once_with (
318+ run .project .name , "service"
319+ )
320+
257321 await session .refresh (run )
258322 assert run .status == RunStatus .PENDING
259323
0 commit comments