Skip to content

Commit 9b0a11e

Browse files
committed
Fix missing jrd update
1 parent f975c2d commit 9b0a11e

2 files changed

Lines changed: 91 additions & 4 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ async def _process_running_job_pulling_state(
470470
context.job_submission.age,
471471
)
472472
shim_state = await common_utils.run_async(
473-
_get_shim_pulling_state,
473+
_sync_shim_pulling_state,
474474
server_ssh_private_keys,
475475
job_provisioning_data,
476476
None,
@@ -481,11 +481,12 @@ async def _process_running_job_pulling_state(
481481
return
482482

483483
if shim_state == _ShimPullingState.READY:
484+
job_runtime_data = get_job_runtime_data(context.job_model)
484485
runner_availability = await common_utils.run_async(
485486
_get_runner_availability,
486487
server_ssh_private_keys,
487488
job_provisioning_data,
488-
None,
489+
job_runtime_data,
489490
)
490491
if runner_availability == _RunnerAvailability.UNAVAILABLE:
491492
_reset_disconnected_at(session, context.job_model)
@@ -507,7 +508,7 @@ async def _process_running_job_pulling_state(
507508
_submit_job_to_runner,
508509
server_ssh_private_keys,
509510
job_provisioning_data,
510-
None,
511+
job_runtime_data,
511512
session=session,
512513
run=context.run,
513514
job_model=context.job_model,
@@ -856,7 +857,7 @@ def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
856857

857858

858859
@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT])
859-
def _get_shim_pulling_state(
860+
def _sync_shim_pulling_state(
860861
ports: Dict[int, int],
861862
job_model: JobModel,
862863
) -> Union[Literal[False], _ShimPullingState]:

src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from dstack._internal.server import settings as server_settings
4040
from dstack._internal.server.background.scheduled_tasks.running_jobs import (
4141
_patch_base_image_for_aws_efa,
42+
_RunnerAvailability,
4243
process_running_jobs,
4344
)
4445
from dstack._internal.server.models import JobModel
@@ -579,6 +580,91 @@ async def test_pulling_shim_runner_not_ready(
579580
assert job is not None
580581
assert job.status == JobStatus.PULLING
581582

583+
@pytest.mark.asyncio
584+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
585+
async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls(
586+
self,
587+
test_db,
588+
session: AsyncSession,
589+
ssh_tunnel_mock: Mock,
590+
shim_client_mock: Mock,
591+
):
592+
project = await create_project(session=session)
593+
user = await create_user(session=session)
594+
repo = await create_repo(session=session, project_id=project.id)
595+
run = await create_run(
596+
session=session,
597+
project=project,
598+
repo=repo,
599+
user=user,
600+
)
601+
instance = await create_instance(
602+
session=session,
603+
project=project,
604+
status=InstanceStatus.BUSY,
605+
)
606+
job = await create_job(
607+
session=session,
608+
run=run,
609+
status=JobStatus.PULLING,
610+
job_provisioning_data=get_job_provisioning_data(dockerized=True),
611+
job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None),
612+
instance=instance,
613+
instance_assigned=True,
614+
)
615+
shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING
616+
shim_client_mock.get_task.return_value.ports = [
617+
PortMapping(container=10022, host=32771),
618+
PortMapping(container=10999, host=32772),
619+
]
620+
621+
expected_ports = {
622+
10022: 32771,
623+
10999: 32772,
624+
}
625+
626+
def assert_runner_availability(_, __, job_runtime_data):
627+
assert job_runtime_data is not None
628+
assert job_runtime_data.ports == expected_ports
629+
return _RunnerAvailability.AVAILABLE
630+
631+
def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs):
632+
assert job_runtime_data is not None
633+
assert job_runtime_data.ports == expected_ports
634+
return True
635+
636+
with (
637+
patch(
638+
"dstack._internal.server.background.scheduled_tasks.running_jobs._get_runner_availability",
639+
side_effect=assert_runner_availability,
640+
) as get_runner_availability_mock,
641+
patch(
642+
"dstack._internal.server.background.scheduled_tasks.running_jobs._submit_job_to_runner",
643+
side_effect=assert_submit_job_to_runner,
644+
) as submit_job_to_runner_mock,
645+
patch(
646+
"dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives",
647+
new_callable=AsyncMock,
648+
return_value=[],
649+
),
650+
patch(
651+
"dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code",
652+
new_callable=AsyncMock,
653+
return_value=b"",
654+
),
655+
):
656+
await process_running_jobs()
657+
658+
ssh_tunnel_mock.assert_called_once()
659+
get_runner_availability_mock.assert_called_once()
660+
submit_job_to_runner_mock.assert_called_once()
661+
662+
await session.refresh(job)
663+
assert job is not None
664+
assert job.status == JobStatus.PULLING
665+
jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data)
666+
assert jrd.ports == expected_ports
667+
582668
@pytest.mark.asyncio
583669
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
584670
async def test_pulling_shim_failed(self, test_db, session: AsyncSession):

0 commit comments

Comments
 (0)