Skip to content

Commit 8bab84d

Browse files
committed
fix: use --torch-backend instead of --extra-index-url for dependency compilation
The old --extra-index-url approach caused uv to resolve any package from the PyTorch wheel index when available there, not just torch-ecosystem packages. Replace it with --torch-backend which routes only torch/ torchaudio/torchvision/torchsde to the PyTorch index. Pin uv>=0.6.0 as the minimum version supporting --torch-backend.
1 parent ccac968 commit 8bab84d

4 files changed

Lines changed: 112 additions & 17 deletions

File tree

comfy_cli/uv.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@ class DependencyCompiler:
6767
rocmPytorchUrl = "https://download.pytorch.org/whl/rocm6.1"
6868
nvidiaPytorchUrl = "https://download.pytorch.org/whl/cu126"
6969

70+
rocmTorchBackend = "rocm6.1"
71+
nvidiaTorchBackend = "cu126"
72+
7073
overrideGpu = dedent(
7174
"""
7275
# ensure usage of {gpu} version of pytorch
73-
--extra-index-url {gpuUrl}
7476
torch
7577
torchaudio
7678
torchsde
@@ -118,6 +120,7 @@ def Compile(
118120
out: PathLike | None = None,
119121
override: PathLike | None = None,
120122
resolve_strategy: str | None = None,
123+
torch_backend: str | None = None,
121124
) -> subprocess.CompletedProcess[Any]:
122125
cmd = [
123126
str(executable),
@@ -136,8 +139,11 @@ def Compile(
136139
if emit_index_url:
137140
cmd.append("--emit-index-url")
138141

142+
if torch_backend is not None:
143+
cmd.extend(["--torch-backend", torch_backend])
144+
139145
# ensures that eg tqdm is latest version, even though an old tqdm is on the amd url
140-
# see https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes and https://github.com/astral-sh/uv/issues/171
146+
# see https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes
141147
if index_strategy is not None:
142148
cmd.extend(["--index-strategy", "unsafe-best-match"])
143149

@@ -383,6 +389,11 @@ def __init__(
383389
DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else
384390
None
385391
) # fmt: skip
392+
self.torchBackend = (
393+
DependencyCompiler.nvidiaTorchBackend if self.gpu == GPU_OPTION.NVIDIA else
394+
DependencyCompiler.rocmTorchBackend if self.gpu == GPU_OPTION.AMD else
395+
None
396+
) # fmt: skip
386397
self.out: Path = self.outDir / outName
387398
self.override = self.outDir / "override.txt"
388399

@@ -401,8 +412,8 @@ def make_override(self):
401412
self.override.unlink(missing_ok=True)
402413

403414
with open(self.override, "w") as f:
404-
if self.gpu is not None and self.gpuUrl is not None:
405-
f.write(DependencyCompiler.overrideGpu.format(gpu=self.gpu, gpuUrl=self.gpuUrl))
415+
if self.gpu is not None:
416+
f.write(DependencyCompiler.overrideGpu.format(gpu=self.gpu))
406417
f.write("\n\n")
407418

408419
completed = DependencyCompiler.Compile(
@@ -412,6 +423,7 @@ def make_override(self):
412423
emit_index_url=False,
413424
executable=self.executable,
414425
override=self.override,
426+
torch_backend=self.torchBackend,
415427
)
416428

417429
with open(self.override, "a") as f:
@@ -442,6 +454,7 @@ def compile_core_plus_ext(self):
442454
override=self.override,
443455
out=self.out,
444456
resolve_strategy="ask",
457+
torch_backend=self.torchBackend,
445458
)
446459

447460
break

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ dependencies = [
4848
"tomlkit",
4949
"typer>=0.12.5",
5050
"typing-extensions>=4.7",
51-
"uv",
51+
"uv>=0.6.0",
5252
"websocket-client",
5353
]
5454

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,27 @@
11
# This file was autogenerated by uv via the following command:
2-
# uv pip compile /home/tel/git/comfy-cli/tests/uv/mock_requirements/core_reqs.txt /home/tel/git/comfy-cli/tests/uv/mock_requirements/x_reqs.txt /home/tel/git/comfy-cli/tests/uv/mock_requirements/y_reqs.txt --emit-index-annotation --emit-index-url --index-strategy unsafe-best-match -o /home/tel/git/comfy-cli/tests/temp/test_uv/requirements.compiled --override /home/tel/git/comfy-cli/tests/temp/test_uv/override.txt
2+
# uv pip compile tests/uv/mock_requirements/core_reqs.txt tests/uv/mock_requirements/x_reqs.txt tests/uv/mock_requirements/y_reqs.txt --emit-index-annotation --emit-index-url --torch-backend rocm6.1 --index-strategy unsafe-best-match -o tests/temp/test_uv/requirements.compiled --override tests/temp/test_uv/override.txt
33
--index-url https://pypi.org/simple
4-
--extra-index-url https://download.pytorch.org/whl/rocm6.1
54

65
mpmath==1.3.0
76
# via
8-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/y_reqs.txt
7+
# -r tests/uv/mock_requirements/y_reqs.txt
98
# sympy
10-
# from https://download.pytorch.org/whl/rocm6.1
9+
# from https://pypi.org/simple
1110
numpy==2.0.2
1211
# via
13-
# --override override.txt
14-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/x_reqs.txt
15-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/y_reqs.txt
12+
# -r tests/uv/mock_requirements/x_reqs.txt
13+
# -r tests/uv/mock_requirements/y_reqs.txt
1614
# from https://pypi.org/simple
1715
sympy==1.13.0
1816
# via
1917
# --override override.txt
20-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/x_reqs.txt
21-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/y_reqs.txt
18+
# -r tests/uv/mock_requirements/x_reqs.txt
19+
# -r tests/uv/mock_requirements/y_reqs.txt
2220
# from https://pypi.org/simple
2321
tqdm==4.66.4
2422
# via
2523
# --override override.txt
26-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/core_reqs.txt
27-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/x_reqs.txt
28-
# -r /home/tel/git/comfy-cli/tests/uv/mock_requirements/y_reqs.txt
24+
# -r tests/uv/mock_requirements/core_reqs.txt
25+
# -r tests/uv/mock_requirements/x_reqs.txt
26+
# -r tests/uv/mock_requirements/y_reqs.txt
2927
# from https://pypi.org/simple

tests/uv/test_uv.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import shutil
22
from pathlib import Path
3+
from unittest.mock import patch
34

45
import pytest
56

@@ -77,3 +78,86 @@ def _filter_optional(lines: list[str]) -> list[str]:
7778
knownLines, testLines = [_filter_optional(lines) for lines in (knownLines, testLines)]
7879

7980
assert knownLines == testLines
81+
82+
83+
def test_torch_backend_nvidia():
84+
depComp = DependencyCompiler(cwd=temp, gpu=GPU_OPTION.NVIDIA, outDir=temp, reqFilesCore=[], reqFilesExt=[])
85+
assert depComp.torchBackend == "cu126"
86+
assert depComp.gpuUrl == DependencyCompiler.nvidiaPytorchUrl
87+
88+
89+
def test_torch_backend_amd():
90+
depComp = DependencyCompiler(cwd=temp, gpu=GPU_OPTION.AMD, outDir=temp, reqFilesCore=[], reqFilesExt=[])
91+
assert depComp.torchBackend == "rocm6.1"
92+
assert depComp.gpuUrl == DependencyCompiler.rocmPytorchUrl
93+
94+
95+
def test_torch_backend_cpu():
96+
depComp = DependencyCompiler(cwd=temp, gpu=GPU_OPTION.CPU, outDir=temp, reqFilesCore=[], reqFilesExt=[])
97+
assert depComp.torchBackend is None
98+
assert depComp.gpuUrl is None
99+
100+
101+
def test_torch_backend_none():
102+
with patch.object(DependencyCompiler, "Resolve_Gpu", return_value=None):
103+
depComp = DependencyCompiler(cwd=temp, gpu=None, outDir=temp, reqFilesCore=[], reqFilesExt=[])
104+
assert depComp.torchBackend is None
105+
assert depComp.gpuUrl is None
106+
107+
108+
def test_compile_passes_torch_backend():
109+
"""Verify that Compile() includes --torch-backend in the command when provided."""
110+
with patch("comfy_cli.uv._run") as mock_run:
111+
mock_run.return_value = type("R", (), {"stdout": "", "stderr": "", "returncode": 0})()
112+
DependencyCompiler.Compile(
113+
cwd=temp,
114+
reqFiles=[mockReqsDir / "core_reqs.txt"],
115+
torch_backend="cu126",
116+
)
117+
cmd = mock_run.call_args[0][0]
118+
idx = cmd.index("--torch-backend")
119+
assert cmd[idx + 1] == "cu126"
120+
121+
122+
def test_compile_omits_torch_backend_when_none():
123+
"""Verify that Compile() does not include --torch-backend when torch_backend is None."""
124+
with patch("comfy_cli.uv._run") as mock_run:
125+
mock_run.return_value = type("R", (), {"stdout": "", "stderr": "", "returncode": 0})()
126+
DependencyCompiler.Compile(
127+
cwd=temp,
128+
reqFiles=[mockReqsDir / "core_reqs.txt"],
129+
torch_backend=None,
130+
)
131+
cmd = mock_run.call_args[0][0]
132+
assert "--torch-backend" not in cmd
133+
134+
135+
def test_compiled_output_has_no_extra_index_url(mock_prompt_select):
136+
"""The compiled output must not contain --extra-index-url (torch-backend handles routing)."""
137+
depComp = DependencyCompiler(
138+
cwd=temp,
139+
gpu=GPU_OPTION.AMD,
140+
outDir=temp,
141+
reqFilesCore=[mockReqsDir / "core_reqs.txt"],
142+
reqFilesExt=[mockReqsDir / "x_reqs.txt", mockReqsDir / "y_reqs.txt"],
143+
)
144+
depComp.make_override()
145+
depComp.compile_core_plus_ext()
146+
147+
content = depComp.out.read_text()
148+
assert "--extra-index-url" not in content
149+
150+
151+
def test_override_file_has_no_extra_index_url():
152+
depComp = DependencyCompiler(
153+
cwd=temp,
154+
gpu=GPU_OPTION.AMD,
155+
outDir=temp,
156+
reqFilesCore=[mockReqsDir / "core_reqs.txt"],
157+
reqFilesExt=[],
158+
)
159+
depComp.make_override()
160+
161+
content = depComp.override.read_text()
162+
assert "--extra-index-url" not in content
163+
assert "torch" in content

0 commit comments

Comments
 (0)