Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 19 additions & 37 deletions comfy_cli/command/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,43 +65,19 @@ def pip_install_comfyui_dependencies(

# install torch for NVIDIA
if gpu == GPU_OPTION.NVIDIA:
base_command = [
python,
"-m",
"pip",
"install",
"torch",
"torchvision",
"torchaudio",
]

if plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_9:
base_command += [
"--extra-index-url",
"https://download.pytorch.org/whl/cu129",
]
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_6:
base_command += [
"--extra-index-url",
"https://download.pytorch.org/whl/cu126",
]
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_4:
base_command += [
"--extra-index-url",
"https://download.pytorch.org/whl/cu124",
]
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_1:
base_command += [
"--extra-index-url",
"https://download.pytorch.org/whl/cu121",
]
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v11_8:
base_command += [
"--extra-index-url",
"https://download.pytorch.org/whl/cu118",
]
cuda_tag = f"cu{cuda_version.value.replace('.', '')}"
pip_url = ["--index-url", f"https://download.pytorch.org/whl/{cuda_tag}"]
result = subprocess.run(
base_command,
[
python,
"-m",
"pip",
"install",
"torch",
"torchvision",
"torchaudio",
]
+ pip_url,
check=False,
)
# Update installation to use upstream torch xpu. ipex is no longer needed for Intel Arc GPUs
Expand Down Expand Up @@ -305,7 +281,13 @@ def execute(

if fast_deps:
DependencyCompiler.Install_Build_Deps(executable=python)
depComp = DependencyCompiler(cwd=repo_dir, executable=python, gpu=gpu)
depComp = DependencyCompiler(
cwd=repo_dir,
executable=python,
gpu=gpu,
cuda_version=cuda_version.value,
rocm_version=rocm_version.value,
)
depComp.compile_deps()
depComp.install_deps()

Expand Down
28 changes: 16 additions & 12 deletions comfy_cli/uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def __init__(
reqFilesCore: list[PathLike] | None = None,
reqFilesExt: list[PathLike] | None = None,
extraSpecs: list[str] | None = None,
cuda_version: str | None = None,
rocm_version: str | None = None,
):
"""Compiler/installer of Python dependencies based on uv

Expand All @@ -386,18 +388,20 @@ def __init__(
self.reqFiles = [Path(reqFile) for reqFile in reqFilesExt] if reqFilesExt is not None else None
self.extraSpecs = [] if extraSpecs is None else extraSpecs

self.gpuUrl = (
DependencyCompiler.nvidiaPytorchUrl if self.gpu == GPU_OPTION.NVIDIA else
DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else
DependencyCompiler.cpuPytorchUrl if self.gpu == GPU_OPTION.CPU else
None
) # fmt: skip
self.torchBackend = (
DependencyCompiler.nvidiaTorchBackend if self.gpu == GPU_OPTION.NVIDIA else
DependencyCompiler.rocmTorchBackend if self.gpu == GPU_OPTION.AMD else
DependencyCompiler.cpuTorchBackend if self.gpu == GPU_OPTION.CPU else
None
) # fmt: skip
if self.gpu == GPU_OPTION.NVIDIA:
tag = f"cu{cuda_version.replace('.', '')}" if cuda_version else DependencyCompiler.nvidiaTorchBackend
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
self.torchBackend = tag
elif self.gpu == GPU_OPTION.AMD:
tag = f"rocm{rocm_version}" if rocm_version else DependencyCompiler.rocmTorchBackend
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
self.torchBackend = tag
elif self.gpu == GPU_OPTION.CPU:
self.gpuUrl = DependencyCompiler.cpuPytorchUrl
self.torchBackend = DependencyCompiler.cpuTorchBackend
else:
self.gpuUrl = None
self.torchBackend = None
self.out: Path = self.outDir / outName
self.override = self.outDir / "override.txt"

Expand Down
56 changes: 43 additions & 13 deletions tests/comfy_cli/test_install_python_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,17 @@ def _get_torch_install_cmd(calls):


class TestTorchInstallCommands:
def test_amd_linux_uses_index_url(self, tmp_path):
@pytest.mark.parametrize(
"rocm_version,expected_url",
[
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
],
)
def test_amd_uses_index_url_with_rocm_version(self, tmp_path, rocm_version, expected_url):
repo_dir = str(tmp_path)
(tmp_path / "requirements.txt").write_text("some-package\n")

Expand All @@ -123,39 +133,59 @@ def test_amd_linux_uses_index_url(self, tmp_path):
skip_torch_or_directml=False,
skip_requirement=False,
python="/usr/bin/python",
rocm_version=constants.ROCmVersion.v6_3,
rocm_version=rocm_version,
)

cmd = _get_torch_install_cmd(mock_run.call_args_list)
assert "--index-url" in cmd
assert "--extra-index-url" not in cmd
assert "https://download.pytorch.org/whl/rocm6.3" in cmd
assert expected_url in cmd

@pytest.mark.parametrize(
"rocm_version,expected_url",
"cuda_version,expected_url",
[
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
(constants.CUDAVersion.v12_9, "https://download.pytorch.org/whl/cu129"),
(constants.CUDAVersion.v12_6, "https://download.pytorch.org/whl/cu126"),
(constants.CUDAVersion.v12_4, "https://download.pytorch.org/whl/cu124"),
(constants.CUDAVersion.v12_1, "https://download.pytorch.org/whl/cu121"),
(constants.CUDAVersion.v11_8, "https://download.pytorch.org/whl/cu118"),
],
)
def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expected_url):
def test_nvidia_uses_index_url_with_cuda_version(self, tmp_path, cuda_version, expected_url):
repo_dir = str(tmp_path)
(tmp_path / "requirements.txt").write_text("some-package\n")

with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
install.pip_install_comfyui_dependencies(
repo_dir,
gpu=GPU_OPTION.AMD,
gpu=GPU_OPTION.NVIDIA,
plat=constants.OS.WINDOWS,
cuda_version=cuda_version,
skip_torch_or_directml=False,
skip_requirement=False,
python="/usr/bin/python",
)

cmd = _get_torch_install_cmd(mock_run.call_args_list)
assert "--index-url" in cmd
assert "--extra-index-url" not in cmd
assert expected_url in cmd

def test_nvidia_linux_uses_index_url(self, tmp_path):
repo_dir = str(tmp_path)
(tmp_path / "requirements.txt").write_text("some-package\n")

with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
install.pip_install_comfyui_dependencies(
repo_dir,
gpu=GPU_OPTION.NVIDIA,
plat=constants.OS.LINUX,
cuda_version=constants.CUDAVersion.v12_6,
skip_torch_or_directml=False,
skip_requirement=False,
python="/usr/bin/python",
rocm_version=rocm_version,
)

cmd = _get_torch_install_cmd(mock_run.call_args_list)
assert expected_url in cmd
assert "--index-url" in cmd
assert "https://download.pytorch.org/whl/cu126" in cmd
16 changes: 16 additions & 0 deletions tests/uv/test_uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,19 @@ def test_override_file_has_no_extra_index_url():
content = depComp.override.read_text()
assert "--extra-index-url" not in content
assert "torch" in content


def test_nvidia_custom_cuda_version():
depComp = DependencyCompiler(
cwd=temp, gpu=GPU_OPTION.NVIDIA, outDir=temp, reqFilesCore=[], reqFilesExt=[], cuda_version="11.8"
)
assert depComp.torchBackend == "cu118"
assert depComp.gpuUrl == "https://download.pytorch.org/whl/cu118"


def test_amd_custom_rocm_version():
depComp = DependencyCompiler(
cwd=temp, gpu=GPU_OPTION.AMD, outDir=temp, reqFilesCore=[], reqFilesExt=[], rocm_version="7.1"
)
assert depComp.torchBackend == "rocm7.1"
assert depComp.gpuUrl == "https://download.pytorch.org/whl/rocm7.1"
Loading