Skip to content

Commit 04dcf8e

Browse files
committed
test: add unit tests for AMD/ROCm torch install commands
1 parent 903e811 commit 04dcf8e

1 file changed

Lines changed: 65 additions & 0 deletions

File tree

tests/comfy_cli/test_install_python_resolution.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import sys
22
from unittest.mock import MagicMock, patch
33

4+
import pytest
5+
6+
from comfy_cli import constants
47
from comfy_cli.command import install
8+
from comfy_cli.constants import GPU_OPTION
59

610

711
class TestPipInstallComfyuiDependencies:
@@ -94,3 +98,64 @@ def test_fast_deps_passes_python_to_dependency_compiler(self, tmp_path):
9498
MockCompiler.Install_Build_Deps.assert_called_once_with(executable="/resolved/python")
9599
MockCompiler.assert_called_once()
96100
assert MockCompiler.call_args[1]["executable"] == "/resolved/python"
101+
102+
103+
def _get_torch_install_cmd(calls):
104+
"""Find the subprocess.run call that installs torch packages."""
105+
for c in calls:
106+
cmd = c[0][0]
107+
if "torch" in cmd and "requirements.txt" not in cmd:
108+
return cmd
109+
return None
110+
111+
112+
class TestTorchInstallCommands:
113+
def test_amd_linux_uses_index_url(self, tmp_path):
114+
repo_dir = str(tmp_path)
115+
(tmp_path / "requirements.txt").write_text("some-package\n")
116+
117+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
118+
install.pip_install_comfyui_dependencies(
119+
repo_dir,
120+
gpu=GPU_OPTION.AMD,
121+
plat=constants.OS.LINUX,
122+
cuda_version=constants.CUDAVersion.v12_6,
123+
skip_torch_or_directml=False,
124+
skip_requirement=False,
125+
python="/usr/bin/python",
126+
rocm_version=constants.ROCmVersion.v6_3,
127+
)
128+
129+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
130+
assert "--index-url" in cmd
131+
assert "--extra-index-url" not in cmd
132+
assert "https://download.pytorch.org/whl/rocm6.3" in cmd
133+
134+
@pytest.mark.parametrize(
135+
"rocm_version,expected_url",
136+
[
137+
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
138+
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
139+
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
140+
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
141+
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
142+
],
143+
)
144+
def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expected_url):
145+
repo_dir = str(tmp_path)
146+
(tmp_path / "requirements.txt").write_text("some-package\n")
147+
148+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
149+
install.pip_install_comfyui_dependencies(
150+
repo_dir,
151+
gpu=GPU_OPTION.AMD,
152+
plat=constants.OS.LINUX,
153+
cuda_version=constants.CUDAVersion.v12_6,
154+
skip_torch_or_directml=False,
155+
skip_requirement=False,
156+
python="/usr/bin/python",
157+
rocm_version=rocm_version,
158+
)
159+
160+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
161+
assert expected_url in cmd

0 commit comments

Comments
 (0)