|
1 | 1 | import sys |
2 | 2 | from unittest.mock import MagicMock, patch |
3 | 3 |
|
| 4 | +import pytest |
| 5 | + |
| 6 | +from comfy_cli import constants |
4 | 7 | from comfy_cli.command import install |
| 8 | +from comfy_cli.constants import GPU_OPTION |
5 | 9 |
|
6 | 10 |
|
7 | 11 | class TestPipInstallComfyuiDependencies: |
@@ -94,3 +98,64 @@ def test_fast_deps_passes_python_to_dependency_compiler(self, tmp_path): |
94 | 98 | MockCompiler.Install_Build_Deps.assert_called_once_with(executable="/resolved/python") |
95 | 99 | MockCompiler.assert_called_once() |
96 | 100 | assert MockCompiler.call_args[1]["executable"] == "/resolved/python" |
| 101 | + |
| 102 | + |
| 103 | +def _get_torch_install_cmd(calls): |
| 104 | + """Find the subprocess.run call that installs torch packages.""" |
| 105 | + for c in calls: |
| 106 | + cmd = c[0][0] |
| 107 | + if "torch" in cmd and "requirements.txt" not in cmd: |
| 108 | + return cmd |
| 109 | + return None |
| 110 | + |
| 111 | + |
| 112 | +class TestTorchInstallCommands: |
| 113 | + def test_amd_linux_uses_index_url(self, tmp_path): |
| 114 | + repo_dir = str(tmp_path) |
| 115 | + (tmp_path / "requirements.txt").write_text("some-package\n") |
| 116 | + |
| 117 | + with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run: |
| 118 | + install.pip_install_comfyui_dependencies( |
| 119 | + repo_dir, |
| 120 | + gpu=GPU_OPTION.AMD, |
| 121 | + plat=constants.OS.LINUX, |
| 122 | + cuda_version=constants.CUDAVersion.v12_6, |
| 123 | + skip_torch_or_directml=False, |
| 124 | + skip_requirement=False, |
| 125 | + python="/usr/bin/python", |
| 126 | + rocm_version=constants.ROCmVersion.v6_3, |
| 127 | + ) |
| 128 | + |
| 129 | + cmd = _get_torch_install_cmd(mock_run.call_args_list) |
| 130 | + assert "--index-url" in cmd |
| 131 | + assert "--extra-index-url" not in cmd |
| 132 | + assert "https://download.pytorch.org/whl/rocm6.3" in cmd |
| 133 | + |
| 134 | + @pytest.mark.parametrize( |
| 135 | + "rocm_version,expected_url", |
| 136 | + [ |
| 137 | + (constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"), |
| 138 | + (constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"), |
| 139 | + (constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"), |
| 140 | + (constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"), |
| 141 | + (constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"), |
| 142 | + ], |
| 143 | + ) |
| 144 | + def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expected_url): |
| 145 | + repo_dir = str(tmp_path) |
| 146 | + (tmp_path / "requirements.txt").write_text("some-package\n") |
| 147 | + |
| 148 | + with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run: |
| 149 | + install.pip_install_comfyui_dependencies( |
| 150 | + repo_dir, |
| 151 | + gpu=GPU_OPTION.AMD, |
| 152 | + plat=constants.OS.LINUX, |
| 153 | + cuda_version=constants.CUDAVersion.v12_6, |
| 154 | + skip_torch_or_directml=False, |
| 155 | + skip_requirement=False, |
| 156 | + python="/usr/bin/python", |
| 157 | + rocm_version=rocm_version, |
| 158 | + ) |
| 159 | + |
| 160 | + cmd = _get_torch_install_cmd(mock_run.call_args_list) |
| 161 | + assert expected_url in cmd |
0 commit comments