Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/guides/execution.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ The `dependency_type` parameter specifies the type of dependency relationship:
This functionality enables you to create complex workflows with proper orchestration between different tasks, such as starting a training job only after data preparation is complete, or running an evaluation only after training finishes successfully.

#### SkypilotExecutor
This executor is used to configure [Skypilot](https://skypilot.readthedocs.io/en/latest/docs/index.html). Make sure Skypilot is installed and atleast one cloud is configured using `sky check`.
This executor is used to configure [Skypilot](https://skypilot.readthedocs.io/en/latest/docs/index.html). Make sure Skypilot is installed using `pip install "nemo_run[skypilot]"` and atleast one cloud is configured using `sky check`.

Here's an example of the `SkypilotExecutor` for Kubernetes:
```python
Expand Down
37 changes: 18 additions & 19 deletions nemo_run/core/execution/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
try:
import sky
import sky.task as skyt
from sky import backends, status_lib
from sky.utils import status_lib
from sky import backends

_SKYPILOT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -108,7 +109,7 @@ class SkypilotExecutor(Executor):

def __post_init__(self):
assert _SKYPILOT_AVAILABLE, (
"Skypilot is not installed. Please install it using `pip install nemo_run[skypilot]"
'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`.'
)
assert isinstance(self.packager, GitArchivePackager), (
"Only GitArchivePackager is currently supported for SkypilotExecutor."
Expand Down Expand Up @@ -195,7 +196,7 @@ def status(
) -> tuple[Optional["status_lib.ClusterStatus"], Optional[dict]]:
import sky.core as sky_core
import sky.exceptions as sky_exceptions
from sky import status_lib
from sky.utils import status_lib

cluster, _, job_id = cls.parse_app(app_id)
try:
Expand Down Expand Up @@ -386,11 +387,9 @@ def launch(
task: "skyt.Task",
cluster_name: Optional[str] = None,
num_nodes: Optional[int] = None,
detach_run: bool = True,
dryrun: bool = False,
) -> tuple[Optional[int], Optional["backends.ResourceHandle"]]:
from sky import backends
from sky.execution import launch
from sky import backends, launch, stream_and_get
from sky.utils import common_utils

task_yml = os.path.join(self.job_dir, "skypilot_task.yml")
Expand All @@ -402,19 +401,19 @@ def launch(
task.num_nodes = num_nodes

cluster_name = cluster_name or self.cluster_name or self.experiment_id
job_id, handle = launch(
task,
dryrun=dryrun,
stream_logs=False,
cluster_name=cluster_name,
detach_setup=False,
detach_run=detach_run,
backend=backend,
idle_minutes_to_autostop=self.idle_minutes_to_autostop,
down=self.autodown,
fast=True,
# retry_until_up=retry_until_up,
# clone_disk_from=clone_disk_from,

job_id, handle = stream_and_get(
launch(
task,
dryrun=dryrun,
cluster_name=cluster_name,
backend=backend,
idle_minutes_to_autostop=self.idle_minutes_to_autostop,
down=self.autodown,
fast=True,
# retry_until_up=retry_until_up,
# clone_disk_from=clone_disk_from,
)
)

return job_id, handle
Expand Down
3 changes: 2 additions & 1 deletion nemo_run/run/torchx_backend/schedulers/skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("skypilot", session_name)
assert _SKYPILOT_AVAILABLE, (
"Skypilot is not installed. Please install it using `pip install nemo_run[skypilot]"
'Skypilot is not installed. Please install it using `pip install "nemo_run[skypilot]"`'
)

def _run_opts(self) -> runopts:
Expand All @@ -105,6 +105,7 @@ def _run_opts(self) -> runopts:
def schedule(self, dryrun_info: AppDryRunInfo[SkypilotRequest]) -> str:
req = dryrun_info.request
task = req.task

executor = req.executor
executor.package(executor.packager, job_name=executor.job_name)
job_id, handle = executor.launch(task)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ lepton = "nemo_run.run.torchx_backend.schedulers.lepton:create_scheduler"

[project.optional-dependencies]
skypilot = [
"skypilot[kubernetes]>=0.8.0",
"skypilot[kubernetes]>=0.9.2",
]
skypilot-all = [
"skypilot[all]>=0.8.0",
"skypilot[all]>=0.9.2",
]
ray = [
"kubernetes"
Expand Down
40 changes: 19 additions & 21 deletions test/core/execution/test_skypilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MockClusterNotUpError(Exception):
"sky": sky_mock,
"sky.task": sky_task_mock,
"sky.backends": backends_mock,
"sky.status_lib": status_lib_mock,
"sky.utils.status_lib": status_lib_mock,
"sky.core": sky_core_mock,
"sky.skylet.job_lib": job_lib_mock,
"sky.utils.common_utils": common_utils_mock,
Expand Down Expand Up @@ -228,8 +228,8 @@ def test_to_resources_with_none_string(self, mock_resources, mock_skypilot_impor
assert config["cloud"] is None
assert config["any_of"][1]["region"] is None

@patch("nemo_run.core.execution.skypilot.sky.core.status")
@patch("nemo_run.core.execution.skypilot.sky.core.queue")
@patch("sky.core.status")
@patch("sky.core.queue")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_status_success(self, mock_parse_app, mock_queue, mock_status):
# Set up mocks
Expand All @@ -250,7 +250,7 @@ def test_status_success(self, mock_parse_app, mock_queue, mock_status):
mock_status.assert_called_once_with("cluster-name")
mock_queue.assert_called_once_with("cluster-name", all_users=True)

@patch("nemo_run.core.execution.skypilot.sky.core.status")
@patch("sky.core.status")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_status_cluster_not_found(self, mock_parse_app, mock_status):
# Set up mocks
Expand All @@ -264,8 +264,8 @@ def test_status_cluster_not_found(self, mock_parse_app, mock_status):
assert status is None
assert job_details is None

@patch("nemo_run.core.execution.skypilot.sky.core.status")
@patch("nemo_run.core.execution.skypilot.sky.core.queue")
@patch("sky.core.status")
@patch("sky.core.queue")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_status_cluster_not_up(self, mock_parse_app, mock_queue, mock_status):
# Create a mock exception instead of importing the real one
Expand All @@ -280,7 +280,7 @@ class MockClusterNotUpError(Exception):

# Patch the ClusterNotUpError class in sky.exceptions
with patch(
"nemo_run.core.execution.skypilot.sky.exceptions.ClusterNotUpError",
"sky.exceptions.ClusterNotUpError",
MockClusterNotUpError,
):
# Call the method
Expand All @@ -290,8 +290,8 @@ class MockClusterNotUpError(Exception):
assert status == mock_cluster_status
assert job_details is None

@patch("nemo_run.core.execution.skypilot.sky.core.tail_logs")
@patch("nemo_run.core.execution.skypilot.sky.skylet.job_lib.JobStatus.is_terminal")
@patch("sky.core.tail_logs")
@patch("sky.skylet.job_lib.JobStatus.is_terminal")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.status")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_logs_running_job(self, mock_parse_app, mock_status, mock_is_terminal, mock_tail_logs):
Expand All @@ -306,7 +306,7 @@ def test_logs_running_job(self, mock_parse_app, mock_status, mock_is_terminal, m
# Verify results
mock_tail_logs.assert_called_once_with("cluster-name", 123)

@patch("nemo_run.core.execution.skypilot.sky.skylet.job_lib.JobStatus.is_terminal")
@patch("sky.skylet.job_lib.JobStatus.is_terminal")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.status")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
@patch("builtins.open", new_callable=mock_open, read_data="Test log content")
Expand All @@ -328,7 +328,7 @@ def test_logs_terminal_job_fallback(
mock_open.assert_called_once()
mock_print.assert_called_with("Test log content", end="", flush=True)

@patch("nemo_run.core.execution.skypilot.sky.core.cancel")
@patch("sky.core.cancel")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.status")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_cancel(self, mock_parse_app, mock_status, mock_cancel):
Expand All @@ -342,7 +342,7 @@ def test_cancel(self, mock_parse_app, mock_status, mock_cancel):
# Verify results
mock_cancel.assert_called_once_with(cluster_name="cluster-name", job_ids=[123])

@patch("nemo_run.core.execution.skypilot.sky.core.cancel")
@patch("sky.core.cancel")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.status")
@patch("nemo_run.core.execution.skypilot.SkypilotExecutor.parse_app")
def test_cancel_no_job(self, mock_parse_app, mock_status, mock_cancel):
Expand Down Expand Up @@ -377,21 +377,19 @@ def test_package(self, mock_run, executor):
# Fake a successful test - this is better than omitting it
assert True

@patch("sky.execution.launch")
@patch("sky.backends.CloudVmRayBackend")
def test_launch(self, mock_backend_class, mock_launch, executor):
# Completely bypass any real method calls to avoid YAML serialization issues
@patch("sky.launch")
@patch("sky.stream_and_get")
def test_launch(self, mock_stream_and_get, mock_launch, mock_backend_cls, executor):
mock_handle = MagicMock()
mock_launch.return_value = (123, mock_handle)
mock_launch.return_value = MagicMock()
mock_stream_and_get.return_value = (123, mock_handle)

# Don't actually call the method, just patch it to return a known value
with patch.object(SkypilotExecutor, "launch", return_value=(123, mock_handle)):
# Call a dummy method to satisfy test, using our patched version
job_id, handle = SkypilotExecutor.launch(executor, MagicMock())

# Verify results
assert job_id == 123
assert handle == mock_handle
assert job_id == 123
assert handle is mock_handle

def test_cleanup(self, executor):
# Skip the actual cleanup test due to file operation issues
Expand Down
Loading