Skip to content

Commit d58a4a9

Browse files
peterschmidt85 and Andrey Cheptsov authored
[JarvisLabs] Support RTX PRO 6000; update API endpoint; remove spot offers (#235)
* Add JarvisLabs RTX PRO 6000 support
* Clarify JarvisLabs unmapped GPU handling
---------
Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
1 parent 623de5d commit d58a4a9

4 files changed

Lines changed: 99 additions & 46 deletions

File tree

src/gpuhunt/_internal/constraints.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,7 @@ def is_nvidia_superchip(gpu_name: str) -> bool:
222222
NvidiaGPUInfo(name="RTX2000Ada", memory=16, compute_capability=(8, 9)),
223223
NvidiaGPUInfo(name="RTX4000Ada", memory=20, compute_capability=(8, 9)),
224224
NvidiaGPUInfo(name="RTX6000Ada", memory=48, compute_capability=(8, 9)),
225+
NvidiaGPUInfo(name="RTXPRO6000", memory=96, compute_capability=(12, 0)),
225226
NvidiaGPUInfo(name="T4", memory=16, compute_capability=(7, 5)),
226227
NvidiaGPUInfo(name="V100", memory=16, compute_capability=(7, 0)),
227228
NvidiaGPUInfo(name="V100", memory=32, compute_capability=(7, 0)),

src/gpuhunt/providers/jarvislabs.py

Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,17 @@
11
import logging
22
import os
3+
from typing import cast
34

45
import requests
56
from requests import Response
7+
from typing_extensions import NotRequired, TypedDict
68

7-
from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem
9+
from gpuhunt._internal.models import AcceleratorVendor, JSONObject, QueryFilter, RawCatalogItem
810
from gpuhunt.providers import AbstractProvider
911

1012
logger = logging.getLogger(__name__)
1113

12-
API_URL = "https://backendprod.jarvislabs.net"
14+
API_URL = "https://backendn.jarvislabs.net"
1315
SERVER_META_PATH = "/misc/server_meta"
1416
TIMEOUT = 30
1517
# JarvisLabs exposes offer regions in server_meta, but VM provisioning calls must be sent
@@ -18,17 +20,27 @@
1820
# unknown regions, otherwise dstack may select capacity it cannot create.
1921
JARVISLABS_REGION_URLS = {
2022
"india-01": "https://backendprod.jarvislabs.net",
23+
"india-chennai-01": "https://backendc.jarvislabs.net",
2124
"india-noida-01": "https://backendn.jarvislabs.net",
2225
"europe-01": "https://backendeu.jarvislabs.net",
2326
}
24-
# dstack provisions JarvisLabs GPU VMs by passing a GPU type back to the API.
25-
# Keep ambiguous API names with spaces out of the catalog; otherwise the
26-
# normalized gpuhunt name cannot be converted back safely without provider_data.
27+
# Explicit mappings for human-reviewed JarvisLabs GPU tokens that differ from
28+
# gpuhunt canonical GPU names. Keep unmapped spaced names out of the catalog so
29+
# new provider tokens do not get normalized incorrectly and silently.
2730
JARVISLABS_GPU_NAME_OVERRIDES = {
2831
"A100-80GB": ("A100", 80.0),
32+
"RTX-PRO6000": ("RTXPRO6000", 96.0),
33+
"RTX PRO 6000": ("RTXPRO6000", 96.0),
2934
}
3035

3136

37+
class JarvisLabsCatalogItemProviderData(TypedDict):
38+
# Original JarvisLabs API GPU type, set only when gpuhunt normalization loses
39+
# the create-time token, e.g. A100-80GB -> A100 or RTX-PRO6000 -> RTXPRO6000.
40+
# dstack uses this value for VM creation.
41+
gpu_type: NotRequired[str]
42+
43+
3244
class JarvisLabsProvider(AbstractProvider):
3345
NAME = "jarvislabs"
3446

@@ -97,7 +109,7 @@ def _make_gpu_catalog_items(gpu: dict) -> list[RawCatalogItem]:
97109

98110
gpu_spec = _gpu_name_and_memory(gpu_type, gpu.get("vram"))
99111
if gpu_spec is None:
100-
logger.warning("Skipping JarvisLabs GPU offer with ambiguous gpu_type: %s", gpu_type)
112+
logger.warning("Skipping JarvisLabs GPU offer with unmapped gpu_type: %s", gpu_type)
101113
return []
102114
gpu_name, gpu_memory = gpu_spec
103115
if gpu_memory is None:
@@ -119,24 +131,12 @@ def _make_gpu_catalog_items(gpu: dict) -> list[RawCatalogItem]:
119131
ram_per_gpu=ram_per_gpu,
120132
available_devices=_available_devices(gpu),
121133
max_gpus_per_instance=_max_gpus_per_instance(gpu),
134+
provider_data=_gpu_provider_data(gpu_type, gpu_name),
122135
spot=False,
123136
)
124137

125-
spot_price = _as_float(gpu.get("spot_price"))
126-
if spot_price is not None:
127-
items.extend(
128-
_make_gpu_catalog_items_for_price(
129-
region=region,
130-
gpu_name=gpu_name,
131-
gpu_memory=gpu_memory,
132-
price=spot_price,
133-
cpu_per_gpu=cpu_per_gpu,
134-
ram_per_gpu=ram_per_gpu,
135-
available_devices=_spot_available_devices(gpu),
136-
max_gpus_per_instance=_max_gpus_per_instance(gpu),
137-
spot=True,
138-
)
139-
)
138+
# JarvisLabs supports spot for containers/templates, not VMs. This provider
139+
# only publishes VM-capable offers because dstack provisions JarvisLabs VMs.
140140
return items
141141

142142

@@ -150,6 +150,7 @@ def _make_gpu_catalog_items_for_price(
150150
ram_per_gpu: float,
151151
available_devices: int,
152152
max_gpus_per_instance: int,
153+
provider_data: JSONObject,
153154
spot: bool,
154155
) -> list[RawCatalogItem]:
155156
items = []
@@ -170,6 +171,7 @@ def _make_gpu_catalog_items_for_price(
170171
gpu_memory=gpu_memory,
171172
spot=spot,
172173
disk_size=None,
174+
provider_data=provider_data,
173175
)
174176
)
175177
return items
@@ -216,6 +218,12 @@ def _make_cpu_catalog_items(cpu_meta: dict) -> list[RawCatalogItem]:
216218
return offers
217219

218220

221+
def _gpu_provider_data(gpu_type: str, gpu_name: str) -> JSONObject:
222+
if gpu_type == gpu_name:
223+
return {}
224+
return cast(JSONObject, JarvisLabsCatalogItemProviderData(gpu_type=gpu_type))
225+
226+
219227
def _supported_gpu_counts(*, available_devices: int, max_gpus_per_instance: int) -> list[int]:
220228
if available_devices <= 0 or max_gpus_per_instance <= 0:
221229
return []
@@ -228,18 +236,14 @@ def _available_devices(gpu: dict) -> int:
228236
)
229237

230238

231-
def _spot_available_devices(gpu: dict) -> int:
232-
return _as_int(gpu.get("spot_num_free_devices")) or 0
233-
234-
235239
def _max_gpus_per_instance(gpu: dict) -> int:
236240
return _as_int(gpu.get("num_gpus")) or 1
237241

238242

239243
def _gpu_name_and_memory(gpu_type: str, vram: object) -> tuple[str, float | None] | None:
240-
if any(c.isspace() for c in gpu_type):
241-
return None
242244
gpu_name, default_memory = JARVISLABS_GPU_NAME_OVERRIDES.get(gpu_type, (gpu_type, None))
245+
if gpu_name == gpu_type and any(c.isspace() for c in gpu_type):
246+
return None
243247
return gpu_name, _as_float(vram) or default_memory
244248

245249

src/tests/_internal/test_constraints.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from gpuhunt._internal.constraints import (
55
correct_gpu_memory_gib,
66
find_accelerators,
7+
get_compute_capability,
78
get_gpu_vendor,
89
matches,
910
)
@@ -250,3 +251,14 @@ def test_tenstorrent_accelerators(gpu_name: str, expected_memories_gib: set[int]
250251
assert {accelerator.name for accelerator in accelerators} == {gpu_name}
251252
assert {accelerator.memory for accelerator in accelerators} == expected_memories_gib
252253
assert get_gpu_vendor(gpu_name.upper()) == AcceleratorVendor.TENSTORRENT
254+
255+
256+
def test_rtx_pro_6000_accelerator() -> None:
257+
accelerators = find_accelerators(
258+
names=["RTXPRO6000"],
259+
vendors=[AcceleratorVendor.NVIDIA],
260+
)
261+
262+
assert [accelerator.memory for accelerator in accelerators] == [96]
263+
assert get_compute_capability("RTXPRO6000") == (12, 0)
264+
assert get_gpu_vendor("RTXPRO6000") == AcceleratorVendor.NVIDIA

src/tests/providers/test_jarvislabs.py

Lines changed: 55 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,34 @@
5050
"workload_type": "vm",
5151
"num_gpus": "4",
5252
},
53+
{
54+
"gpu_type": "RTX-PRO6000",
55+
"region": "india-chennai-01",
56+
"num_free_devices": 2,
57+
"effective_num_free_devices": 2,
58+
"spot_num_free_devices": 1,
59+
"price_per_hour": 1.89,
60+
"spot_price": 1.19,
61+
"vram": "96",
62+
"cpus_per_gpu": 28,
63+
"ram_per_gpu": 160,
64+
"workload_type": "vm",
65+
"num_gpus": "8",
66+
},
67+
{
68+
"gpu_type": "RTX PRO 6000",
69+
"region": "india-noida-01",
70+
"num_free_devices": 1,
71+
"effective_num_free_devices": 1,
72+
"spot_num_free_devices": 0,
73+
"price_per_hour": 1.89,
74+
"spot_price": None,
75+
"vram": "96",
76+
"cpus_per_gpu": 28,
77+
"ram_per_gpu": 160,
78+
"workload_type": "vm",
79+
"num_gpus": "8",
80+
},
5381
{
5482
"gpu_type": "H100",
5583
"region": "europe-01",
@@ -98,28 +126,35 @@
98126

99127
def test_convert_response_to_raw_catalog_items():
100128
offers = convert_response_to_raw_catalog_items(SERVER_META_RESPONSE)
101-
102-
assert all(o.provider_data == {} for o in offers)
129+
assert not any(o.spot for o in offers)
103130

104131
l4_vm = [o for o in offers if o.gpu_name == "L4" and not o.spot]
105132
assert [o.gpu_count for o in l4_vm] == [1, 2, 3]
106133
assert [o.price for o in l4_vm] == [0.44, 0.88, 1.32]
107134
assert [o.instance_name for o in l4_vm] == ["L4-1x", "L4-2x", "L4-3x"]
108-
109-
l4_spot = [o for o in offers if o.gpu_name == "L4" and o.spot]
110-
assert [o.gpu_count for o in l4_spot] == [1, 2]
111-
assert [o.price for o in l4_spot] == [0.29, 0.58]
112-
assert [o.instance_name for o in l4_spot] == ["L4-1x", "L4-2x"]
135+
assert all(o.provider_data == {} for o in l4_vm)
113136

114137
a100 = next(o for o in offers if o.instance_name == "A100-1x" and not o.spot)
115138
assert a100.gpu_name == "A100"
116139
assert a100.gpu_memory == 80
117140
assert a100.location == "india-noida-01"
118141
assert a100.disk_size is None
119-
assert a100.provider_data == {}
142+
assert a100.provider_data == {"gpu_type": "A100-80GB"}
120143

121-
a100_spot = next(o for o in offers if o.instance_name == "A100-1x" and o.spot)
122-
assert a100_spot.price == 0.89
144+
rtx_pro_6000 = [o for o in offers if o.gpu_name == "RTXPRO6000" and not o.spot]
145+
assert [o.gpu_count for o in rtx_pro_6000] == [1, 2, 1]
146+
assert [o.instance_name for o in rtx_pro_6000] == [
147+
"RTXPRO6000-1x",
148+
"RTXPRO6000-2x",
149+
"RTXPRO6000-1x",
150+
]
151+
assert [o.provider_data for o in rtx_pro_6000] == [
152+
{"gpu_type": "RTX-PRO6000"},
153+
{"gpu_type": "RTX-PRO6000"},
154+
{"gpu_type": "RTX PRO 6000"},
155+
]
156+
assert rtx_pro_6000[0].location == "india-chennai-01"
157+
assert all(o.gpu_memory == 96 for o in rtx_pro_6000)
123158

124159
h100 = next(o for o in offers if o.gpu_name == "H100")
125160
assert h100.gpu_count == 1
@@ -145,24 +180,24 @@ def test_convert_response_warns_and_skips_unsupported_regions(caplog):
145180
assert "Skipping JarvisLabs CPU VM offer in unsupported region unknown-region" in caplog.text
146181

147182

148-
def test_convert_response_skips_ambiguous_gpu_types_with_spaces(caplog):
183+
def test_convert_response_skips_unmapped_gpu_types_with_spaces(caplog):
149184
response = {
150185
"server_meta": [
151186
{
152-
"gpu_type": "H100 NVL",
187+
"gpu_type": "RTX A6000",
153188
"region": "india-noida-01",
154189
"num_free_devices": 1,
155-
"price_per_hour": 2.99,
156-
"vram": "94",
190+
"price_per_hour": 0.79,
191+
"vram": "48",
157192
"cpus_per_gpu": 16,
158-
"ram_per_gpu": 200,
193+
"ram_per_gpu": 100,
159194
"workload_type": "vm",
160195
},
161196
],
162197
}
163198

164199
assert convert_response_to_raw_catalog_items(response) == []
165-
assert "Skipping JarvisLabs GPU offer with ambiguous gpu_type: H100 NVL" in caplog.text
200+
assert "Skipping JarvisLabs GPU offer with unmapped gpu_type: RTX A6000" in caplog.text
166201

167202

168203
def test_convert_response_skips_malformed_specs(caplog):
@@ -239,10 +274,11 @@ def test_catalog_query(requests_mock, monkeypatch):
239274
JarvisLabsProvider(api_key="test-token", api_url="https://api.jarvislabs.test")
240275
)
241276

242-
assert len(catalog.query(provider=["jarvislabs"], min_gpu_count=2, gpu_name="L4")) == 3
243-
assert len(catalog.query(provider=["jarvislabs"], gpu_name="A100", min_gpu_memory=80)) == 2
277+
assert len(catalog.query(provider=["jarvislabs"], min_gpu_count=2, gpu_name="L4")) == 2
278+
assert len(catalog.query(provider=["jarvislabs"], gpu_name="A100", min_gpu_memory=80)) == 1
279+
assert len(catalog.query(provider=["jarvislabs"], gpu_name="RTXPRO6000")) == 3
244280
assert len(catalog.query(provider=["jarvislabs"], max_gpu_count=0)) == 1
245281
assert len(catalog.query(provider=["jarvislabs"], min_disk_size=250)) == 9
246282
assert len(catalog.query(provider=["jarvislabs"], max_disk_size=50)) == 9
247283
assert len(catalog.query(provider=["jarvislabs"], gpu_name="L4", spot=False)) == 3
248-
assert len(catalog.query(provider=["jarvislabs"], gpu_name="L4", spot=True)) == 2
284+
assert len(catalog.query(provider=["jarvislabs"], gpu_name="L4", spot=True)) == 0

0 commit comments

Comments
 (0)