|
39 | 39 | from dstack._internal.server import settings as server_settings |
40 | 40 | from dstack._internal.server.background.scheduled_tasks.running_jobs import ( |
41 | 41 | _patch_base_image_for_aws_efa, |
| 42 | + _RunnerAvailability, |
42 | 43 | process_running_jobs, |
43 | 44 | ) |
44 | 45 | from dstack._internal.server.models import JobModel |
@@ -579,6 +580,91 @@ async def test_pulling_shim_runner_not_ready( |
579 | 580 | assert job is not None |
580 | 581 | assert job.status == JobStatus.PULLING |
581 | 582 |
|
| 583 | + @pytest.mark.asyncio |
| 584 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 585 | + async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls( |
| 586 | + self, |
| 587 | + test_db, |
| 588 | + session: AsyncSession, |
| 589 | + ssh_tunnel_mock: Mock, |
| 590 | + shim_client_mock: Mock, |
| 591 | + ): |
| 592 | + project = await create_project(session=session) |
| 593 | + user = await create_user(session=session) |
| 594 | + repo = await create_repo(session=session, project_id=project.id) |
| 595 | + run = await create_run( |
| 596 | + session=session, |
| 597 | + project=project, |
| 598 | + repo=repo, |
| 599 | + user=user, |
| 600 | + ) |
| 601 | + instance = await create_instance( |
| 602 | + session=session, |
| 603 | + project=project, |
| 604 | + status=InstanceStatus.BUSY, |
| 605 | + ) |
| 606 | + job = await create_job( |
| 607 | + session=session, |
| 608 | + run=run, |
| 609 | + status=JobStatus.PULLING, |
| 610 | + job_provisioning_data=get_job_provisioning_data(dockerized=True), |
| 611 | + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), |
| 612 | + instance=instance, |
| 613 | + instance_assigned=True, |
| 614 | + ) |
| 615 | + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING |
| 616 | + shim_client_mock.get_task.return_value.ports = [ |
| 617 | + PortMapping(container=10022, host=32771), |
| 618 | + PortMapping(container=10999, host=32772), |
| 619 | + ] |
| 620 | + |
| 621 | + expected_ports = { |
| 622 | + 10022: 32771, |
| 623 | + 10999: 32772, |
| 624 | + } |
| 625 | + |
| 626 | + def assert_runner_availability(_, __, job_runtime_data): |
| 627 | + assert job_runtime_data is not None |
| 628 | + assert job_runtime_data.ports == expected_ports |
| 629 | + return _RunnerAvailability.AVAILABLE |
| 630 | + |
| 631 | + def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): |
| 632 | + assert job_runtime_data is not None |
| 633 | + assert job_runtime_data.ports == expected_ports |
| 634 | + return True |
| 635 | + |
| 636 | + with ( |
| 637 | + patch( |
| 638 | + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_runner_availability", |
| 639 | + side_effect=assert_runner_availability, |
| 640 | + ) as get_runner_availability_mock, |
| 641 | + patch( |
| 642 | + "dstack._internal.server.background.scheduled_tasks.running_jobs._submit_job_to_runner", |
| 643 | + side_effect=assert_submit_job_to_runner, |
| 644 | + ) as submit_job_to_runner_mock, |
| 645 | + patch( |
| 646 | + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", |
| 647 | + new_callable=AsyncMock, |
| 648 | + return_value=[], |
| 649 | + ), |
| 650 | + patch( |
| 651 | + "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", |
| 652 | + new_callable=AsyncMock, |
| 653 | + return_value=b"", |
| 654 | + ), |
| 655 | + ): |
| 656 | + await process_running_jobs() |
| 657 | + |
| 658 | + ssh_tunnel_mock.assert_called_once() |
| 659 | + get_runner_availability_mock.assert_called_once() |
| 660 | + submit_job_to_runner_mock.assert_called_once() |
| 661 | + |
| 662 | + await session.refresh(job) |
| 663 | + assert job is not None |
| 664 | + assert job.status == JobStatus.PULLING |
| 665 | + jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) |
| 666 | + assert jrd.ports == expected_ports |
| 667 | + |
582 | 668 | @pytest.mark.asyncio |
583 | 669 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
584 | 670 | async def test_pulling_shim_failed(self, test_db, session: AsyncSession): |
|
0 commit comments