@@ -110,7 +110,17 @@ def _get_torch_install_cmd(calls):
110110
111111
112112class TestTorchInstallCommands :
113- def test_amd_linux_uses_index_url (self , tmp_path ):
113+ @pytest .mark .parametrize (
114+ "rocm_version,expected_url" ,
115+ [
116+ (constants .ROCmVersion .v7_1 , "https://download.pytorch.org/whl/rocm7.1" ),
117+ (constants .ROCmVersion .v7_0 , "https://download.pytorch.org/whl/rocm7.0" ),
118+ (constants .ROCmVersion .v6_3 , "https://download.pytorch.org/whl/rocm6.3" ),
119+ (constants .ROCmVersion .v6_2 , "https://download.pytorch.org/whl/rocm6.2" ),
120+ (constants .ROCmVersion .v6_1 , "https://download.pytorch.org/whl/rocm6.1" ),
121+ ],
122+ )
123+ def test_amd_uses_index_url_with_rocm_version (self , tmp_path , rocm_version , expected_url ):
114124 repo_dir = str (tmp_path )
115125 (tmp_path / "requirements.txt" ).write_text ("some-package\n " )
116126
@@ -123,39 +133,59 @@ def test_amd_linux_uses_index_url(self, tmp_path):
123133 skip_torch_or_directml = False ,
124134 skip_requirement = False ,
125135 python = "/usr/bin/python" ,
126- rocm_version = constants . ROCmVersion . v6_3 ,
136+ rocm_version = rocm_version ,
127137 )
128138
129139 cmd = _get_torch_install_cmd (mock_run .call_args_list )
130140 assert "--index-url" in cmd
131141 assert "--extra-index-url" not in cmd
132- assert "https://download.pytorch.org/whl/rocm6.3" in cmd
142+ assert expected_url in cmd
133143
134144 @pytest .mark .parametrize (
135- "rocm_version ,expected_url" ,
145+ "cuda_version ,expected_url" ,
136146 [
137- (constants .ROCmVersion . v7_1 , "https://download.pytorch.org/whl/rocm7.1 " ),
138- (constants .ROCmVersion . v7_0 , "https://download.pytorch.org/whl/rocm7.0 " ),
139- (constants .ROCmVersion . v6_3 , "https://download.pytorch.org/whl/rocm6.3 " ),
140- (constants .ROCmVersion . v6_2 , "https://download.pytorch.org/whl/rocm6.2 " ),
141- (constants .ROCmVersion . v6_1 , "https://download.pytorch.org/whl/rocm6.1 " ),
147+ (constants .CUDAVersion . v12_9 , "https://download.pytorch.org/whl/cu129 " ),
148+ (constants .CUDAVersion . v12_6 , "https://download.pytorch.org/whl/cu126 " ),
149+ (constants .CUDAVersion . v12_4 , "https://download.pytorch.org/whl/cu124 " ),
150+ (constants .CUDAVersion . v12_1 , "https://download.pytorch.org/whl/cu121 " ),
151+ (constants .CUDAVersion . v11_8 , "https://download.pytorch.org/whl/cu118 " ),
142152 ],
143153 )
144- def test_amd_linux_rocm_version_controls_url (self , tmp_path , rocm_version , expected_url ):
154+ def test_nvidia_uses_index_url_with_cuda_version (self , tmp_path , cuda_version , expected_url ):
145155 repo_dir = str (tmp_path )
146156 (tmp_path / "requirements.txt" ).write_text ("some-package\n " )
147157
148158 with patch ("comfy_cli.command.install.subprocess.run" , return_value = MagicMock (returncode = 0 )) as mock_run :
149159 install .pip_install_comfyui_dependencies (
150160 repo_dir ,
151- gpu = GPU_OPTION .AMD ,
161+ gpu = GPU_OPTION .NVIDIA ,
162+ plat = constants .OS .WINDOWS ,
163+ cuda_version = cuda_version ,
164+ skip_torch_or_directml = False ,
165+ skip_requirement = False ,
166+ python = "/usr/bin/python" ,
167+ )
168+
169+ cmd = _get_torch_install_cmd (mock_run .call_args_list )
170+ assert "--index-url" in cmd
171+ assert "--extra-index-url" not in cmd
172+ assert expected_url in cmd
173+
174+ def test_nvidia_linux_uses_index_url (self , tmp_path ):
175+ repo_dir = str (tmp_path )
176+ (tmp_path / "requirements.txt" ).write_text ("some-package\n " )
177+
178+ with patch ("comfy_cli.command.install.subprocess.run" , return_value = MagicMock (returncode = 0 )) as mock_run :
179+ install .pip_install_comfyui_dependencies (
180+ repo_dir ,
181+ gpu = GPU_OPTION .NVIDIA ,
152182 plat = constants .OS .LINUX ,
153183 cuda_version = constants .CUDAVersion .v12_6 ,
154184 skip_torch_or_directml = False ,
155185 skip_requirement = False ,
156186 python = "/usr/bin/python" ,
157- rocm_version = rocm_version ,
158187 )
159188
160189 cmd = _get_torch_install_cmd (mock_run .call_args_list )
161- assert expected_url in cmd
190+ assert "--index-url" in cmd
191+ assert "https://download.pytorch.org/whl/cu126" in cmd
0 commit comments