Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion comfy_cli/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from comfy_cli.command.launch import launch as launch_command
from comfy_cli.command.models import models as models_command
from comfy_cli.config_manager import ConfigManager
from comfy_cli.constants import GPU_OPTION, CUDAVersion
from comfy_cli.constants import GPU_OPTION, CUDAVersion, ROCmVersion
from comfy_cli.env_checker import EnvChecker
from comfy_cli.resolve_python import resolve_workspace_python
from comfy_cli.standalone import StandalonePython
Expand Down Expand Up @@ -194,6 +194,7 @@ def install(
),
] = None,
cuda_version: Annotated[CUDAVersion, typer.Option(show_default=True)] = CUDAVersion.v12_6,
rocm_version: Annotated[ROCmVersion, typer.Option(show_default=True)] = ROCmVersion.v6_3,
amd: Annotated[
bool | None,
typer.Option(
Expand Down Expand Up @@ -287,6 +288,7 @@ def install(
version=version,
gpu=None,
cuda_version=cuda_version,
rocm_version=rocm_version,
plat=platform,
skip_torch_or_directml=skip_torch_or_directml,
skip_requirement=skip_requirement,
Expand Down Expand Up @@ -342,6 +344,7 @@ def install(
gpu=gpu,
version=version,
cuda_version=cuda_version,
rocm_version=rocm_version,
plat=platform,
skip_torch_or_directml=skip_torch_or_directml,
skip_requirement=skip_requirement,
Expand Down
13 changes: 11 additions & 2 deletions comfy_cli/command/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def pip_install_comfyui_dependencies(
skip_torch_or_directml: bool,
skip_requirement: bool,
python: str = sys.executable,
rocm_version: constants.ROCmVersion = constants.ROCmVersion.v6_3,
):
os.chdir(repo_dir)

result = None
if not skip_torch_or_directml:
# install torch for AMD Linux
if gpu == GPU_OPTION.AMD and plat == constants.OS.LINUX:
pip_url = ["--extra-index-url", "https://download.pytorch.org/whl/rocm6.0"]
pip_url = ["--index-url", f"https://download.pytorch.org/whl/rocm{rocm_version.value}"]
result = subprocess.run(
[
python,
Expand Down Expand Up @@ -198,6 +199,7 @@ def execute(
manager_commit: str | None = None,
gpu: constants.GPU_OPTION = None,
cuda_version: constants.CUDAVersion = constants.CUDAVersion.v12_6,
rocm_version: constants.ROCmVersion = constants.ROCmVersion.v6_3,
plat: constants.OS = None,
skip_torch_or_directml: bool = False,
skip_requirement: bool = False,
Expand Down Expand Up @@ -255,7 +257,14 @@ def execute(

if not fast_deps:
pip_install_comfyui_dependencies(
repo_dir, gpu, plat, cuda_version, skip_torch_or_directml, skip_requirement, python=python
repo_dir,
gpu,
plat,
cuda_version,
skip_torch_or_directml,
skip_requirement,
python=python,
rocm_version=rocm_version,
)

WorkspaceManager().set_recent_workspace(repo_dir)
Expand Down
8 changes: 8 additions & 0 deletions comfy_cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class CUDAVersion(str, Enum):
v11_8 = "11.8"


class ROCmVersion(str, Enum):
v7_1 = "7.1"
v7_0 = "7.0"
v6_3 = "6.3"
v6_2 = "6.2"
v6_1 = "6.1"


class GPU_OPTION(str, Enum):
CPU = None
NVIDIA = "nvidia"
Expand Down
4 changes: 2 additions & 2 deletions comfy_cli/uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def parse_req_file(rf: PathLike, skips: list[str] | None = None):

class DependencyCompiler:
cpuPytorchUrl = "https://download.pytorch.org/whl/cpu"
rocmPytorchUrl = "https://download.pytorch.org/whl/rocm6.1"
rocmPytorchUrl = "https://download.pytorch.org/whl/rocm6.3"
nvidiaPytorchUrl = "https://download.pytorch.org/whl/cu126"

cpuTorchBackend = "cpu"
rocmTorchBackend = "rocm6.1"
rocmTorchBackend = "rocm6.3"
nvidiaTorchBackend = "cu126"

overrideGpu = dedent(
Expand Down
65 changes: 65 additions & 0 deletions tests/comfy_cli/test_install_python_resolution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import sys
from unittest.mock import MagicMock, patch

import pytest

from comfy_cli import constants
from comfy_cli.command import install
from comfy_cli.constants import GPU_OPTION


class TestPipInstallComfyuiDependencies:
Expand Down Expand Up @@ -94,3 +98,64 @@ def test_fast_deps_passes_python_to_dependency_compiler(self, tmp_path):
MockCompiler.Install_Build_Deps.assert_called_once_with(executable="/resolved/python")
MockCompiler.assert_called_once()
assert MockCompiler.call_args[1]["executable"] == "/resolved/python"


def _get_torch_install_cmd(calls):
"""Find the subprocess.run call that installs torch packages."""
for c in calls:
cmd = c[0][0]
if "torch" in cmd and "requirements.txt" not in cmd:
return cmd
return None


class TestTorchInstallCommands:
def test_amd_linux_uses_index_url(self, tmp_path):
repo_dir = str(tmp_path)
(tmp_path / "requirements.txt").write_text("some-package\n")

with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
install.pip_install_comfyui_dependencies(
repo_dir,
gpu=GPU_OPTION.AMD,
plat=constants.OS.LINUX,
cuda_version=constants.CUDAVersion.v12_6,
skip_torch_or_directml=False,
skip_requirement=False,
python="/usr/bin/python",
rocm_version=constants.ROCmVersion.v6_3,
)

cmd = _get_torch_install_cmd(mock_run.call_args_list)
assert "--index-url" in cmd
assert "--extra-index-url" not in cmd
assert "https://download.pytorch.org/whl/rocm6.3" in cmd

@pytest.mark.parametrize(
"rocm_version,expected_url",
[
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
],
)
def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expected_url):
repo_dir = str(tmp_path)
(tmp_path / "requirements.txt").write_text("some-package\n")

with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
install.pip_install_comfyui_dependencies(
repo_dir,
gpu=GPU_OPTION.AMD,
plat=constants.OS.LINUX,
cuda_version=constants.CUDAVersion.v12_6,
skip_torch_or_directml=False,
skip_requirement=False,
python="/usr/bin/python",
rocm_version=rocm_version,
)

cmd = _get_torch_install_cmd(mock_run.call_args_list)
assert expected_url in cmd
2 changes: 1 addition & 1 deletion tests/uv/test_uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_torch_backend_nvidia():

def test_torch_backend_amd():
depComp = DependencyCompiler(cwd=temp, gpu=GPU_OPTION.AMD, outDir=temp, reqFilesCore=[], reqFilesExt=[])
assert depComp.torchBackend == "rocm6.1"
assert depComp.torchBackend == "rocm6.3"
assert depComp.gpuUrl == DependencyCompiler.rocmPytorchUrl


Expand Down
Loading