Skip to content

Commit 623de5d

Browse files
peterschmidt85Andrey Cheptsov
andauthored
Add Tenstorrent Blackhole accelerators (#232)
Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
1 parent 2a0d7f5 commit 623de5d

2 files changed

Lines changed: 35 additions & 1 deletion

File tree

src/gpuhunt/_internal/constraints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def is_nvidia_superchip(gpu_name: str) -> bool:
306306
KNOWN_TENSTORRENT_ACCELERATORS: list[TenstorrentAcceleratorInfo] = [
307307
TenstorrentAcceleratorInfo(name="n150", memory=12),
308308
TenstorrentAcceleratorInfo(name="n300", memory=24),
309+
TenstorrentAcceleratorInfo(name="tt-galaxy-wh", memory=12),
310+
TenstorrentAcceleratorInfo(name="p100a", memory=28),
311+
TenstorrentAcceleratorInfo(name="p150", memory=32),
312+
TenstorrentAcceleratorInfo(name="p300", memory=32),
313+
TenstorrentAcceleratorInfo(name="p300", memory=64),
314+
TenstorrentAcceleratorInfo(name="tt-galaxy-bh", memory=32),
309315
]
310316

311317
KNOWN_ACCELERATORS: list[

src/tests/_internal/test_constraints.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import pytest
22

33
from gpuhunt import CatalogItem, QueryFilter
4-
from gpuhunt._internal.constraints import correct_gpu_memory_gib, matches
4+
from gpuhunt._internal.constraints import (
5+
correct_gpu_memory_gib,
6+
find_accelerators,
7+
get_gpu_vendor,
8+
matches,
9+
)
510
from gpuhunt._internal.models import AcceleratorVendor
611

712

@@ -222,3 +227,26 @@ def test_matches_cpu_instances_with_zero_gpu_count_and_gpu_memory(self):
222227
)
223228
def test_correct_gpu_memory(gpu_name: str, memory_mib: float, expected_memory_gib: int) -> None:
224229
assert correct_gpu_memory_gib(gpu_name, memory_mib) == expected_memory_gib
230+
231+
232+
@pytest.mark.parametrize(
233+
("gpu_name", "expected_memories_gib"),
234+
[
235+
("n150", {12}),
236+
("n300", {24}),
237+
("tt-galaxy-wh", {12}),
238+
("p100a", {28}),
239+
("p150", {32}),
240+
("p300", {32, 64}),
241+
("tt-galaxy-bh", {32}),
242+
],
243+
)
244+
def test_tenstorrent_accelerators(gpu_name: str, expected_memories_gib: set[int]) -> None:
245+
accelerators = find_accelerators(
246+
names=[gpu_name.upper()],
247+
vendors=[AcceleratorVendor.TENSTORRENT],
248+
)
249+
250+
assert {accelerator.name for accelerator in accelerators} == {gpu_name}
251+
assert {accelerator.memory for accelerator in accelerators} == expected_memories_gib
252+
assert get_gpu_vendor(gpu_name.upper()) == AcceleratorVendor.TENSTORRENT

0 commit comments

Comments
 (0)