|
1 | 1 | import pytest |
2 | 2 |
|
3 | 3 | from gpuhunt import CatalogItem, QueryFilter |
4 | | -from gpuhunt._internal.constraints import correct_gpu_memory_gib, matches |
| 4 | +from gpuhunt._internal.constraints import ( |
| 5 | + correct_gpu_memory_gib, |
| 6 | + find_accelerators, |
| 7 | + get_gpu_vendor, |
| 8 | + matches, |
| 9 | +) |
5 | 10 | from gpuhunt._internal.models import AcceleratorVendor |
6 | 11 |
|
7 | 12 |
|
@@ -222,3 +227,26 @@ def test_matches_cpu_instances_with_zero_gpu_count_and_gpu_memory(self): |
222 | 227 | ) |
223 | 228 | def test_correct_gpu_memory(gpu_name: str, memory_mib: float, expected_memory_gib: int) -> None: |
224 | 229 | assert correct_gpu_memory_gib(gpu_name, memory_mib) == expected_memory_gib |
| 230 | + |
| 231 | + |
| 232 | +@pytest.mark.parametrize( |
| 233 | + ("gpu_name", "expected_memories_gib"), |
| 234 | + [ |
| 235 | + ("n150", {12}), |
| 236 | + ("n300", {24}), |
| 237 | + ("tt-galaxy-wh", {12}), |
| 238 | + ("p100a", {28}), |
| 239 | + ("p150", {32}), |
| 240 | + ("p300", {32, 64}), |
| 241 | + ("tt-galaxy-bh", {32}), |
| 242 | + ], |
| 243 | +) |
| 244 | +def test_tenstorrent_accelerators(gpu_name: str, expected_memories_gib: set[int]) -> None: |
| 245 | + accelerators = find_accelerators( |
| 246 | + names=[gpu_name.upper()], |
| 247 | + vendors=[AcceleratorVendor.TENSTORRENT], |
| 248 | + ) |
| 249 | + |
| 250 | + assert {accelerator.name for accelerator in accelerators} == {gpu_name} |
| 251 | + assert {accelerator.memory for accelerator in accelerators} == expected_memories_gib |
| 252 | + assert get_gpu_vendor(gpu_name.upper()) == AcceleratorVendor.TENSTORRENT |
0 commit comments