Skip to content

Commit 252d8cd

Browse files
committed
fix: use --index-url for NVIDIA torch install, pass --cuda-version to fast-deps path
1 parent 88a2021 commit 252d8cd

4 files changed

Lines changed: 118 additions & 49 deletions

File tree

comfy_cli/command/install.py

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -65,43 +65,19 @@ def pip_install_comfyui_dependencies(
6565

6666
# install torch for NVIDIA
6767
if gpu == GPU_OPTION.NVIDIA:
68-
base_command = [
69-
python,
70-
"-m",
71-
"pip",
72-
"install",
73-
"torch",
74-
"torchvision",
75-
"torchaudio",
76-
]
77-
78-
if plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_9:
79-
base_command += [
80-
"--extra-index-url",
81-
"https://download.pytorch.org/whl/cu129",
82-
]
83-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_6:
84-
base_command += [
85-
"--extra-index-url",
86-
"https://download.pytorch.org/whl/cu126",
87-
]
88-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_4:
89-
base_command += [
90-
"--extra-index-url",
91-
"https://download.pytorch.org/whl/cu124",
92-
]
93-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v12_1:
94-
base_command += [
95-
"--extra-index-url",
96-
"https://download.pytorch.org/whl/cu121",
97-
]
98-
elif plat == constants.OS.WINDOWS and cuda_version == constants.CUDAVersion.v11_8:
99-
base_command += [
100-
"--extra-index-url",
101-
"https://download.pytorch.org/whl/cu118",
102-
]
68+
cuda_tag = f"cu{cuda_version.value.replace('.', '')}"
69+
pip_url = ["--index-url", f"https://download.pytorch.org/whl/{cuda_tag}"]
10370
result = subprocess.run(
104-
base_command,
71+
[
72+
python,
73+
"-m",
74+
"pip",
75+
"install",
76+
"torch",
77+
"torchvision",
78+
"torchaudio",
79+
]
80+
+ pip_url,
10581
check=False,
10682
)
10783
# Update installation to use upstream torch xpu. ipex is no longer needed for Intel Arc GPUs
@@ -305,7 +281,13 @@ def execute(
305281

306282
if fast_deps:
307283
DependencyCompiler.Install_Build_Deps(executable=python)
308-
depComp = DependencyCompiler(cwd=repo_dir, executable=python, gpu=gpu)
284+
depComp = DependencyCompiler(
285+
cwd=repo_dir,
286+
executable=python,
287+
gpu=gpu,
288+
cuda_version=cuda_version.value,
289+
rocm_version=rocm_version.value,
290+
)
309291
depComp.compile_deps()
310292
depComp.install_deps()
311293

comfy_cli/uv.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ def __init__(
365365
reqFilesCore: list[PathLike] | None = None,
366366
reqFilesExt: list[PathLike] | None = None,
367367
extraSpecs: list[str] | None = None,
368+
cuda_version: str | None = None,
369+
rocm_version: str | None = None,
368370
):
369371
"""Compiler/installer of Python dependencies based on uv
370372
@@ -386,18 +388,20 @@ def __init__(
386388
self.reqFiles = [Path(reqFile) for reqFile in reqFilesExt] if reqFilesExt is not None else None
387389
self.extraSpecs = [] if extraSpecs is None else extraSpecs
388390

389-
self.gpuUrl = (
390-
DependencyCompiler.nvidiaPytorchUrl if self.gpu == GPU_OPTION.NVIDIA else
391-
DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else
392-
DependencyCompiler.cpuPytorchUrl if self.gpu == GPU_OPTION.CPU else
393-
None
394-
) # fmt: skip
395-
self.torchBackend = (
396-
DependencyCompiler.nvidiaTorchBackend if self.gpu == GPU_OPTION.NVIDIA else
397-
DependencyCompiler.rocmTorchBackend if self.gpu == GPU_OPTION.AMD else
398-
DependencyCompiler.cpuTorchBackend if self.gpu == GPU_OPTION.CPU else
399-
None
400-
) # fmt: skip
391+
if self.gpu == GPU_OPTION.NVIDIA:
392+
tag = f"cu{cuda_version.replace('.', '')}" if cuda_version else DependencyCompiler.nvidiaTorchBackend
393+
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
394+
self.torchBackend = tag
395+
elif self.gpu == GPU_OPTION.AMD:
396+
tag = f"rocm{rocm_version}" if rocm_version else DependencyCompiler.rocmTorchBackend
397+
self.gpuUrl = f"https://download.pytorch.org/whl/{tag}"
398+
self.torchBackend = tag
399+
elif self.gpu == GPU_OPTION.CPU:
400+
self.gpuUrl = DependencyCompiler.cpuPytorchUrl
401+
self.torchBackend = DependencyCompiler.cpuTorchBackend
402+
else:
403+
self.gpuUrl = None
404+
self.torchBackend = None
401405
self.out: Path = self.outDir / outName
402406
self.override = self.outDir / "override.txt"
403407

tests/comfy_cli/test_install_python_resolution.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,70 @@ def test_amd_linux_rocm_version_controls_url(self, tmp_path, rocm_version, expec
159159

160160
cmd = _get_torch_install_cmd(mock_run.call_args_list)
161161
assert expected_url in cmd
162+
163+
def test_nvidia_uses_index_url(self, tmp_path):
164+
repo_dir = str(tmp_path)
165+
(tmp_path / "requirements.txt").write_text("some-package\n")
166+
167+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
168+
install.pip_install_comfyui_dependencies(
169+
repo_dir,
170+
gpu=GPU_OPTION.NVIDIA,
171+
plat=constants.OS.WINDOWS,
172+
cuda_version=constants.CUDAVersion.v12_6,
173+
skip_torch_or_directml=False,
174+
skip_requirement=False,
175+
python="/usr/bin/python",
176+
)
177+
178+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
179+
assert "--index-url" in cmd
180+
assert "--extra-index-url" not in cmd
181+
assert "https://download.pytorch.org/whl/cu126" in cmd
182+
183+
@pytest.mark.parametrize(
184+
"cuda_version,expected_url",
185+
[
186+
(constants.CUDAVersion.v12_9, "https://download.pytorch.org/whl/cu129"),
187+
(constants.CUDAVersion.v12_6, "https://download.pytorch.org/whl/cu126"),
188+
(constants.CUDAVersion.v12_4, "https://download.pytorch.org/whl/cu124"),
189+
(constants.CUDAVersion.v12_1, "https://download.pytorch.org/whl/cu121"),
190+
(constants.CUDAVersion.v11_8, "https://download.pytorch.org/whl/cu118"),
191+
],
192+
)
193+
def test_nvidia_cuda_version_controls_url(self, tmp_path, cuda_version, expected_url):
194+
repo_dir = str(tmp_path)
195+
(tmp_path / "requirements.txt").write_text("some-package\n")
196+
197+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
198+
install.pip_install_comfyui_dependencies(
199+
repo_dir,
200+
gpu=GPU_OPTION.NVIDIA,
201+
plat=constants.OS.WINDOWS,
202+
cuda_version=cuda_version,
203+
skip_torch_or_directml=False,
204+
skip_requirement=False,
205+
python="/usr/bin/python",
206+
)
207+
208+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
209+
assert expected_url in cmd
210+
211+
def test_nvidia_linux_uses_index_url(self, tmp_path):
212+
repo_dir = str(tmp_path)
213+
(tmp_path / "requirements.txt").write_text("some-package\n")
214+
215+
with patch("comfy_cli.command.install.subprocess.run", return_value=MagicMock(returncode=0)) as mock_run:
216+
install.pip_install_comfyui_dependencies(
217+
repo_dir,
218+
gpu=GPU_OPTION.NVIDIA,
219+
plat=constants.OS.LINUX,
220+
cuda_version=constants.CUDAVersion.v12_6,
221+
skip_torch_or_directml=False,
222+
skip_requirement=False,
223+
python="/usr/bin/python",
224+
)
225+
226+
cmd = _get_torch_install_cmd(mock_run.call_args_list)
227+
assert "--index-url" in cmd
228+
assert "https://download.pytorch.org/whl/cu126" in cmd

tests/uv/test_uv.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,19 @@ def test_override_file_has_no_extra_index_url():
161161
content = depComp.override.read_text()
162162
assert "--extra-index-url" not in content
163163
assert "torch" in content
164+
165+
166+
def test_nvidia_custom_cuda_version():
167+
depComp = DependencyCompiler(
168+
cwd=temp, gpu=GPU_OPTION.NVIDIA, outDir=temp, reqFilesCore=[], reqFilesExt=[], cuda_version="11.8"
169+
)
170+
assert depComp.torchBackend == "cu118"
171+
assert depComp.gpuUrl == "https://download.pytorch.org/whl/cu118"
172+
173+
174+
def test_amd_custom_rocm_version():
175+
depComp = DependencyCompiler(
176+
cwd=temp, gpu=GPU_OPTION.AMD, outDir=temp, reqFilesCore=[], reqFilesExt=[], rocm_version="7.1"
177+
)
178+
assert depComp.torchBackend == "rocm7.1"
179+
assert depComp.gpuUrl == "https://download.pytorch.org/whl/rocm7.1"

0 commit comments

Comments
 (0)