Skip to content

Commit 6090345

Browse files
committed
fix: use --index-url for NVIDIA torch install, pass --cuda-version to fast-deps path
1 parent 88a2021 commit 6090345

4 files changed

Lines changed: 94 additions & 62 deletions

File tree

comfy_cli/command/install.py

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -65,43 +65,19 @@ def pip_install_comfyui_dependencies(
6565

6666
# install torch for NVIDIA
6767
if gpu == GPU_OPTION.NVIDIA:
68-
base_command = [
69-
python,
70-
"-m",
71-
"pip",
72-
"install",
73-
"torch",
74-
"torchvision",
75-
"torchaudio",
76-
]
77-
78-
if plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_9:
79-
base_command += [
80-
"--extra-index-url",
81-
"https://download.pytorch.org/whl/cu129",
82-
]
83-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_6:
84-
base_command += [
85-
"--extra-index-url",
86-
"https://download.pytorch.org/whl/cu126",
87-
]
88-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_4:
89-
base_command += [
90-
"--extra-index-url",
91-
"https://download.pytorch.org/whl/cu124",
92-
]
93-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_1:
94-
base_command += [
95-
"--extra-index-url",
96-
"https://download.pytorch.org/whl/cu121",
97-
]
98-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v11_8:
99-
base_command += [
100-
"--extra-index-url",
101-
"https://download.pytorch.org/whl/cu118",
102-
]
68+
cuda_tag = f"cu{cuda_version.value.replace('.', '')}"
69+
pip_url = ["--index-url", f"https://download.pytorch.org/whl/{cuda_tag}"]
10370
result = subprocess.run(
104-
base_command,
71+
[
72+
python,
73+
"-m",
74+
"pip",
75+
"install",
76+
"torch",
77+
"torchvision",
78+
"torchaudio",
79+
]
80+
+ pip_url,
10581
check=False,
10682
)
10783
# Update installation to use upstream torch xpu. ipex is no longer needed for Intel Arc GPUs
@@ -305,7 +281,13 @@ def execute(
305281

306282
if fast_deps:
307283
DependencyCompiler.Install_Build_Deps(executable=python)
308-
depComp = DependencyCompiler(cwd=repo_dir, executable=python, gpu=gpu)
284+
depComp = DependencyCompiler(
285+
cwd=repo_dir,
286+
executable=python,
287+
gpu=gpu,
288+
cuda_version=cuda_version.value,
289+
rocm_version=rocm_version.value,
290+
)
309291
depComp.compile_deps()
310292
depComp.install_deps()
311293

comfy_cli/uv.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ def __init__(
365365
reqFilesCore: list[PathLike] | None = None,
366366
reqFilesExt: list[PathLike] | None = None,
367367
extraSpecs: list[str] | None = None,
368+
cuda_version: str | None = None,
369+
rocm_version: str | None = None,
368370
):
369371
"""Compiler/installer of Python dependencies based on uv
370372
@@ -386,18 +388,20 @@ def __init__(
386388
self.reqFiles = [Path(reqFile) for reqFile in reqFilesExt] if reqFilesExt is not None else None
387389
self.extraSpecs = [] if extraSpecs is None else extraSpecs
388390

389-
self.gpuUrl = (
390-
DependencyCompiler.nvidiaPytorchUrl if self.gpu == GPU_OPTION.NVIDIA else
391-
DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else
392-
DependencyCompiler.cpuPytorchUrl if self.gpu == GPU_OPTION.CPU else
393-
None
394-
) # fmt: skip
395-
self.torchBackend = (
396-
DependencyCompiler.nvidiaTorchBackend if self.gpu == GPU_OPTION.NVIDIA else
397-
DependencyCompiler.rocmTorchBackend if self.gpu == GPU_OPTION.AMD else
398-
DependencyCompiler.cpuTorchBackend if self.gpu == GPU_OPTION.CPU else
399-
None
400-
) # fmt: skip
391+
if self.gpu == GPU_OPTION.NVIDIA:
392+
tag = f"cu{cuda_version.replace('.', '')}" if cuda_version else DependencyCompiler.nvidiaTorchBackend
393+
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
394+
self.torchBackend = tag
395+
elif self.gpu == GPU_OPTION.AMD:
396+
tag = f"rocm{rocm_version}" if rocm_version else DependencyCompiler.rocmTorchBackend
397+
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
398+
self.torchBackend = tag
399+
elif self.gpu == GPU_OPTION.CPU:
400+
self.gpuUrl = DependencyCompiler.cpuPytorchUrl
401+
self.torchBackend = DependencyCompiler.cpuTorchBackend
402+
else:
403+
self.gpuUrl = None
404+
self.torchBackend = None
401405
self.out: Path = self.outDir / outName
402406
self.override = self.outDir / "override.txt"
403407

tests/comfy_cli/test_install_python_resolution.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,17 @@ def _get_torch_install_cmd(calls):
110110

111111

112112
class TestTorchInstallCommands:
113-
def test_amd_linux_uses_index_url(self, tmp_path):
113+
@pytest.mark.parametrize(
114+
"rocm_version,expected_url",
115+
[
116+
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
117+
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
118+
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
119+
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
120+
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
121+
],
122+
)
123+
def test_amd_uses_index_url_with_rocm_version(self, tmp_path, rocm_version, expected_url):
114124
repo_dir = str(tmp_path)
115125
(tmp_path / "requirements.txt").write_text("some-package\n")
116126

@@ -123,39 +133,59 @@ def test_amd_linux_uses_index_url(self, tmp_path):
123133
skip_torch_or_directml=False,
124134
skip_requirement=False,
125135
python="/usr/bin/python",
126-
rocm_version=constants.ROCmVersion.v6_3,
136+
rocm_version=rocm_version,
127137
)
128138

129139
cmd = _get_torch_install_cmd(mock_run.call_args_list)
130140
assert "--index-url" in cmd
131141
assert "--extra-index-url" not in cmd
132-
assert "https://download.pytorch.org/whl/rocm6.3" in cmd
142+
assert expected_url in cmd
133143

134144
@pytest.mark.parametrize(
135-
"rocm_version,expected_url",
145+
"cuda_version,expected_url",
136146
[
137-
(constants.ROCmVersion.v7_1, "https://download.pytorch.org/whl/rocm7.1"),
138-
(constants.ROCmVersion.v7_0, "https://download.pytorch.org/whl/rocm7.0"),
139-
(constants.ROCmVersion.v6_3, "https://download.pytorch.org/whl/rocm6.3"),
140-
(constants.ROCmVersion.v6_2, "https://download.pytorch.org/whl/rocm6.2"),
141-
(constants.ROCmVersion.v6_1, "https://download.pytorch.org/whl/rocm6.1"),
147+
(constants.CUDAVersion.v12_9, "https://download.pytorch.org/whl/cu129"),
148+
(constants.CUDAVersion.v12_6, "https://download.pytorch.org/whl/cu126"),
149+
(constants.CUDAVersion.v12_4, "https://download.pytorch.org/whl/cu124"),
150+
(constants.CUDAVersion.v12_1, "https://download.pytorch.org/whl/cu121"),
151+
(constants.CUDAVersion.v11_8, "https://download.pytorch.org/whl/cu118"),
142152
],
143153
)
144-
def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expected_url):
154+
def test_nvidia_uses_index_url_with_cuda_version(self, tmp_path, cuda_version, expected_url):
145155
repo_dir = str(tmp_path)
146156
(tmp_path / "requirements.txt").write_text("some-package\n")
147157

148158
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
149159
install.pip_install_comfyui_dependencies(
150160
repo_dir,
151-
gpu=GPU_OPTION.AMD,
161+
gpu=GPU_OPTION.NVIDIA,
162+
plat=constants.OS.WINDOWS,
163+
cuda_version=cuda_version,
164+
skip_torch_or_directml=False,
165+
skip_requirement=False,
166+
python="/usr/bin/python",
167+
)
168+
169+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
170+
assert "--index-url" in cmd
171+
assert "--extra-index-url" not in cmd
172+
assert expected_url in cmd
173+
174+
def test_nvidia_linux_uses_index_url(self, tmp_path):
175+
repo_dir = str(tmp_path)
176+
(tmp_path / "requirements.txt").write_text("some-package\n")
177+
178+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
179+
install.pip_install_comfyui_dependencies(
180+
repo_dir,
181+
gpu=GPU_OPTION.NVIDIA,
152182
plat=constants.OS.LINUX,
153183
cuda_version=constants.CUDAVersion.v12_6,
154184
skip_torch_or_directml=False,
155185
skip_requirement=False,
156186
python="/usr/bin/python",
157-
rocm_version=rocm_version,
158187
)
159188

160189
cmd = _get_torch_install_cmd(mock_run.call_args_list)
161-
assert expected_url in cmd
190+
assert "--index-url" in cmd
191+
assert "https://download.pytorch.org/whl/cu126" in cmd

tests/uv/test_uv.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,19 @@ def test_override_file_has_no_extra_index_url():
161161
content = depComp.override.read_text()
162162
assert "--extra-index-url" not in content
163163
assert "torch" in content
164+
165+
166+
def test_nvidia_custom_cuda_version():
167+
depComp = DependencyCompiler(
168+
cwd=temp, gpu=GPU_OPTION.NVIDIA, outDir=temp, reqFilesCore=[], reqFilesExt=[], cuda_version="11.8"
169+
)
170+
assert depComp.torchBackend == "cu118"
171+
assert depComp.gpuUrl == "https://download.pytorch.org/whl/cu118"
172+
173+
174+
def test_amd_custom_rocm_version():
175+
depComp = DependencyCompiler(
176+
cwd=temp, gpu=GPU_OPTION.AMD, outDir=temp, reqFilesCore=[], reqFilesExt=[], rocm_version="7.1"
177+
)
178+
assert depComp.torchBackend == "rocm7.1"
179+
assert depComp.gpuUrl == "https://download.pytorch.org/whl/rocm7.1"

0 commit comments

Comments
 (0)