Skip to content

Commit 8861fe2

Browse files
author
Andrey Cheptsov
committed
Add JarvisLabs backend
1 parent 927e7f8 commit 8861fe2

17 files changed

Lines changed: 1421 additions & 1 deletion

File tree

mkdocs/docs/concepts/backends.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,26 @@ projects:
918918

919919
</div>
920920

921+
### JarvisLabs
922+
923+
Log into your [JarvisLabs](https://cloud.jarvislabs.ai/) account and create an API key.
924+
925+
Then, go ahead and configure the backend:
926+
927+
<div editor-title="~/.dstack/server/config.yml">
928+
929+
```yaml
930+
projects:
931+
- name: main
932+
backends:
933+
- type: jarvislabs
934+
creds:
935+
type: api_key
936+
api_key: ...
937+
```
938+
939+
</div>
940+
921941
### CloudRift
922942

923943
Log into your [CloudRift](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key.

mkdocs/docs/reference/server/config.yml.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,23 @@ to configure [backends](../../concepts/backends.md) and other [server-level sett
369369
type:
370370
required: true
371371

372+
##### `projects[n].backends[type=jarvislabs]` { #jarvislabs data-toc-label="jarvislabs" }
373+
374+
#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsBackendFileConfigWithCreds
375+
overrides:
376+
show_root_heading: false
377+
type:
378+
required: true
379+
item_id_prefix: jarvislabs-
380+
381+
###### `projects[n].backends[type=jarvislabs].creds` { #jarvislabs-creds data-toc-label="creds" }
382+
383+
#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsAPIKeyCreds
384+
overrides:
385+
show_root_heading: false
386+
type:
387+
required: true
388+
372389
##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }
373390

374391
#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"python-multipart>=0.0.16",
3333
"filelock",
3434
"psutil",
35-
"gpuhunt==0.1.21",
35+
"gpuhunt @ git+https://github.com/dstackai/gpuhunt.git@jarvislabs",
3636
"argcomplete>=3.5.0",
3737
"ignore-python>=0.2.0",
3838
"orjson",
@@ -67,6 +67,9 @@ artifacts = [
6767
"src/dstack/_internal/server/statics/**",
6868
]
6969

70+
[tool.hatch.metadata]
71+
allow-direct-references = true
72+
7073
[tool.hatch.metadata.hooks.fancy-pypi-readme]
7174
content-type = "text/markdown"
7275

src/dstack/_internal/core/backends/configurators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@
8787
except ImportError:
8888
pass
8989

90+
try:
91+
from dstack._internal.core.backends.jarvislabs.configurator import (
92+
JarvisLabsConfigurator,
93+
)
94+
95+
_CONFIGURATOR_CLASSES.append(JarvisLabsConfigurator)
96+
except ImportError:
97+
pass
98+
9099
try:
91100
from dstack._internal.core.backends.kubernetes.configurator import (
92101
KubernetesConfigurator,

src/dstack/_internal/core/backends/jarvislabs/__init__.py

Whitespace-only changes.
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
import hashlib
2+
from typing import Any, Dict, List, Optional
3+
4+
import requests
5+
from gpuhunt.providers.jarvislabs import API_URL, JARVISLABS_REGION_URLS
6+
7+
from dstack._internal.core.errors import (
8+
BackendError,
9+
BackendInvalidCredentialsError,
10+
NoCapacityError,
11+
)
12+
13+
TIMEOUT = 120
14+
15+
16+
class JarvisLabsNotFoundError(BackendError):
17+
pass
18+
19+
20+
class JarvisLabsAPIClient:
21+
def __init__(self, api_key: str):
22+
self.api_key = api_key
23+
24+
def validate_api_key(self) -> bool:
25+
try:
26+
self.get_user_info()
27+
except BackendInvalidCredentialsError:
28+
return False
29+
return True
30+
31+
def get_user_info(self) -> Dict[str, Any]:
32+
resp = self._make_request("GET", "users/user_info")
33+
if not isinstance(resp, dict):
34+
raise BackendError("Unexpected JarvisLabs user_info response")
35+
return resp
36+
37+
def list_ssh_keys(self) -> List[Dict[str, Any]]:
38+
resp = self._make_request("GET", "ssh/")
39+
if isinstance(resp, list):
40+
return resp
41+
raise BackendError("Unexpected JarvisLabs SSH key list response")
42+
43+
def add_ssh_key(self, public_key: str, key_name: str) -> None:
44+
resp = self._make_request(
45+
"POST",
46+
"ssh/",
47+
json={
48+
"ssh_key": public_key,
49+
"key_name": key_name,
50+
},
51+
)
52+
_raise_if_unsuccessful(resp, "Failed to add JarvisLabs SSH key")
53+
54+
def add_ssh_key_if_needed(self, public_key: str) -> None:
55+
normalized_key = _normalize_public_key(public_key)
56+
for ssh_key in self.list_ssh_keys():
57+
if _normalize_public_key(str(ssh_key.get("ssh_key", ""))) == normalized_key:
58+
return
59+
key_name = _get_ssh_key_name(normalized_key)
60+
self.add_ssh_key(public_key=public_key, key_name=key_name)
61+
62+
def create_gpu_vm(
63+
self,
64+
*,
65+
gpu_type: str,
66+
num_gpus: int,
67+
is_spot: bool,
68+
storage: int,
69+
region: str,
70+
name: str,
71+
) -> str:
72+
resp = self._make_request(
73+
"POST",
74+
"templates/vm/create",
75+
region=region,
76+
json={
77+
"gpu_type": gpu_type,
78+
"num_gpus": num_gpus,
79+
"hdd": storage,
80+
"region": region,
81+
"name": name,
82+
"is_spot": is_spot,
83+
"duration": "hour",
84+
"disk_type": "ssd",
85+
"http_ports": "",
86+
"script_id": None,
87+
"script_args": "",
88+
"fs_id": None,
89+
"arguments": "",
90+
},
91+
)
92+
return _get_created_machine_id(resp, "GPU VM creation")
93+
94+
def create_cpu_vm(
95+
self,
96+
*,
97+
vcpus: int,
98+
ram_gb: int,
99+
storage: int,
100+
region: str,
101+
name: str,
102+
) -> str:
103+
resp = self._make_request(
104+
"POST",
105+
"templates/vm/cpu/create",
106+
region=region,
107+
json={
108+
"num_cpus": 1,
109+
"vcpus": vcpus,
110+
"ram_gb": ram_gb,
111+
"hdd": storage,
112+
"region": region,
113+
"name": name,
114+
"duration": "hour",
115+
"disk_type": "ssd",
116+
},
117+
)
118+
return _get_created_machine_id(resp, "CPU VM creation")
119+
120+
def get_instance(self, machine_id: str) -> Optional[Dict[str, Any]]:
121+
try:
122+
resp = self._make_request("GET", f"users/fetch/{machine_id}")
123+
except JarvisLabsNotFoundError:
124+
return None
125+
if not _is_successful(resp):
126+
return None
127+
if isinstance(resp, dict):
128+
instance = resp.get("instance")
129+
if isinstance(instance, dict):
130+
return instance
131+
return None
132+
133+
def get_instance_status(self, *, machine_id: str, region: str) -> Optional[Dict[str, Any]]:
134+
try:
135+
resp = self._make_request(
136+
"GET",
137+
"misc/status",
138+
region=region,
139+
params={"machine_id": machine_id},
140+
)
141+
except JarvisLabsNotFoundError:
142+
return None
143+
if isinstance(resp, dict):
144+
return resp
145+
return None
146+
147+
def destroy_instance(self, *, machine_id: str, region: str) -> None:
148+
instance = self.get_instance(machine_id)
149+
if instance is None:
150+
return
151+
endpoint = "templates/vm/destroy"
152+
if is_cpu_vm(instance):
153+
endpoint = "templates/vm/cpu/destroy"
154+
elif _instance_template(instance) != "vm":
155+
endpoint = "misc/destroy"
156+
157+
try:
158+
resp = self._make_request(
159+
"POST",
160+
endpoint,
161+
region=instance.get("region") or region,
162+
params={"machine_id": machine_id},
163+
)
164+
except JarvisLabsNotFoundError:
165+
return
166+
_raise_if_unsuccessful(resp, "Failed to destroy JarvisLabs instance")
167+
168+
def _make_request(
169+
self,
170+
method: str,
171+
path: str,
172+
*,
173+
json: Optional[Dict[str, Any]] = None,
174+
params: Optional[Dict[str, Any]] = None,
175+
region: Optional[str] = None,
176+
) -> Any:
177+
try:
178+
response = requests.request(
179+
method=method,
180+
url=self._url(path=path, region=region),
181+
headers={"Authorization": f"Bearer {self.api_key}"},
182+
json=json,
183+
params=params,
184+
timeout=TIMEOUT,
185+
)
186+
except requests.RequestException as e:
187+
raise BackendError(f"JarvisLabs request failed: {e}") from e
188+
if response.ok:
189+
if not response.content:
190+
return {}
191+
try:
192+
return response.json()
193+
except ValueError as e:
194+
raise BackendError("Unexpected non-JSON JarvisLabs response") from e
195+
message = _get_response_error(response)
196+
if response.status_code in [401, 403]:
197+
raise BackendInvalidCredentialsError(fields=[["creds", "api_key"]])
198+
if response.status_code == 404:
199+
raise JarvisLabsNotFoundError(message)
200+
if response.status_code in [400, 409] and _looks_like_no_capacity(message):
201+
raise NoCapacityError(message)
202+
raise BackendError(message)
203+
204+
def _url(self, *, path: str, region: Optional[str] = None) -> str:
205+
if region is None:
206+
base_url = API_URL
207+
else:
208+
# gpuhunt owns this allowlist because it filters JarvisLabs offers. Do not
209+
# fall back for unknown regions: regional VM APIs use separate hosts and
210+
# JarvisLabs does not expose endpoint discovery in server_meta.
211+
base_url = JARVISLABS_REGION_URLS.get(region)
212+
if base_url is None:
213+
raise BackendError(
214+
f"Unsupported JarvisLabs region {region!r}. "
215+
"JarvisLabs does not expose provisioning endpoint discovery."
216+
)
217+
return base_url.rstrip("/") + "/" + path.lstrip("/")
218+
219+
220+
def is_cpu_vm(instance: Dict[str, Any]) -> bool:
221+
return _instance_template(instance) == "vm" and str(instance.get("gpu_type")).upper() == "CPU"
222+
223+
224+
def _instance_template(instance: Dict[str, Any]) -> str:
225+
return str(instance.get("template") or instance.get("framework") or "").lower()
226+
227+
228+
def _get_created_machine_id(resp: Any, operation: str) -> str:
229+
_raise_if_unsuccessful(resp, f"JarvisLabs {operation} failed")
230+
if isinstance(resp, dict):
231+
machine_id = resp.get("machine_id")
232+
if machine_id is not None:
233+
return str(machine_id)
234+
raise BackendError(f"JarvisLabs {operation} failed: missing machine_id")
235+
236+
237+
def _raise_if_unsuccessful(resp: Any, message: str) -> None:
238+
if _is_successful(resp):
239+
return
240+
backend_message = _backend_message(resp)
241+
if _looks_like_no_capacity(backend_message):
242+
raise NoCapacityError(backend_message)
243+
raise BackendError(f"{message}: {backend_message}")
244+
245+
246+
def _is_successful(resp: Any) -> bool:
247+
if not isinstance(resp, dict):
248+
return True
249+
if "success" in resp:
250+
return _coerce_bool(resp["success"])
251+
if "sucess" in resp:
252+
return _coerce_bool(resp["sucess"])
253+
return True
254+
255+
256+
def _coerce_bool(value: Any) -> bool:
257+
if isinstance(value, bool):
258+
return value
259+
if isinstance(value, str):
260+
return value.strip().lower() in {"1", "true", "yes", "success"}
261+
return bool(value)
262+
263+
264+
def _get_response_error(response: requests.Response) -> str:
265+
try:
266+
data = response.json()
267+
except ValueError:
268+
return response.text or f"HTTP {response.status_code}"
269+
message = _backend_message(data)
270+
return message or f"HTTP {response.status_code}"
271+
272+
273+
def _backend_message(resp: Any) -> str:
274+
if isinstance(resp, dict):
275+
detail = resp.get("detail")
276+
if isinstance(detail, list):
277+
return "; ".join(str(item.get("msg", item)) for item in detail)
278+
return str(
279+
resp.get("message")
280+
or resp.get("error")
281+
or resp.get("detail")
282+
or resp.get("msg")
283+
or resp
284+
)
285+
return str(resp)
286+
287+
288+
def _looks_like_no_capacity(message: str) -> bool:
289+
message = message.lower()
290+
return "capacity" in message or "available" in message or "stock" in message
291+
292+
293+
def _normalize_public_key(public_key: str) -> str:
294+
return " ".join(public_key.strip().split()[:2])
295+
296+
297+
def _get_ssh_key_name(public_key: str) -> str:
298+
return "dstack-" + hashlib.sha1(public_key.encode()).hexdigest()[:16]

0 commit comments

Comments
 (0)