diff --git a/mkdocs/docs/concepts/backends.md b/mkdocs/docs/concepts/backends.md
index 8f7b3325d8..e6483fd3ab 100644
--- a/mkdocs/docs/concepts/backends.md
+++ b/mkdocs/docs/concepts/backends.md
@@ -918,6 +918,26 @@ projects:
+### JarvisLabs
+
+Log into your [JarvisLabs](https://cloud.jarvislabs.ai/) account and create an API key.
+
+Then, go ahead and configure the backend:
+
+
+
+```yaml
+projects:
+- name: main
+ backends:
+ - type: jarvislabs
+ creds:
+ type: api_key
+ api_key: ...
+```
+
+
+
### CloudRift
Log into your [CloudRift](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key.
diff --git a/mkdocs/docs/reference/server/config.yml.md b/mkdocs/docs/reference/server/config.yml.md
index 80e48b028e..4e515aae5a 100644
--- a/mkdocs/docs/reference/server/config.yml.md
+++ b/mkdocs/docs/reference/server/config.yml.md
@@ -369,6 +369,23 @@ to configure [backends](../../concepts/backends.md) and other [server-level sett
type:
required: true
+##### `projects[n].backends[type=jarvislabs]` { #jarvislabs data-toc-label="jarvislabs" }
+
+#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsBackendFileConfigWithCreds
+ overrides:
+ show_root_heading: false
+ type:
+ required: true
+ item_id_prefix: jarvislabs-
+
+###### `projects[n].backends[type=jarvislabs].creds` { #jarvislabs-creds data-toc-label="creds" }
+
+#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsAPIKeyCreds
+ overrides:
+ show_root_heading: false
+ type:
+ required: true
+
##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }
#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds
diff --git a/pyproject.toml b/pyproject.toml
index 4d19b23d75..55b31f3572 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
"python-multipart>=0.0.16",
"filelock",
"psutil",
- "gpuhunt==0.1.21",
+ "gpuhunt==0.1.22",
"argcomplete>=3.5.0",
"ignore-python>=0.2.0",
"orjson",
diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py
index 75a4a86abb..cdeac7f608 100644
--- a/src/dstack/_internal/core/backends/configurators.py
+++ b/src/dstack/_internal/core/backends/configurators.py
@@ -87,6 +87,15 @@
except ImportError:
pass
+try:
+ from dstack._internal.core.backends.jarvislabs.configurator import (
+ JarvisLabsConfigurator,
+ )
+
+ _CONFIGURATOR_CLASSES.append(JarvisLabsConfigurator)
+except ImportError:
+ pass
+
try:
from dstack._internal.core.backends.kubernetes.configurator import (
KubernetesConfigurator,
diff --git a/src/dstack/_internal/core/backends/jarvislabs/__init__.py b/src/dstack/_internal/core/backends/jarvislabs/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/dstack/_internal/core/backends/jarvislabs/api_client.py b/src/dstack/_internal/core/backends/jarvislabs/api_client.py
new file mode 100644
index 0000000000..e1a2a022ec
--- /dev/null
+++ b/src/dstack/_internal/core/backends/jarvislabs/api_client.py
@@ -0,0 +1,301 @@
+import hashlib
+from typing import Any, Dict, List, Optional
+
+import requests
+from gpuhunt.providers.jarvislabs import API_URL, JARVISLABS_REGION_URLS
+
+from dstack._internal.core.errors import (
+ BackendError,
+ BackendInvalidCredentialsError,
+ NoCapacityError,
+)
+
+TIMEOUT = 120
+
+
+class JarvisLabsNotFoundError(BackendError):
+ pass
+
+
+class JarvisLabsAPIClient:
+ def __init__(self, api_key: str):
+ self.api_key = api_key
+
+ def validate_api_key(self) -> bool:
+ try:
+ self.get_user_info()
+ except BackendInvalidCredentialsError:
+ return False
+ return True
+
+ def get_user_info(self) -> Dict[str, Any]:
+ resp = self._make_request("GET", "users/user_info")
+ if not isinstance(resp, dict):
+ raise BackendError("Unexpected JarvisLabs user_info response")
+ return resp
+
+ def list_ssh_keys(self) -> List[Dict[str, Any]]:
+ resp = self._make_request("GET", "ssh/")
+ if isinstance(resp, list):
+ return resp
+ raise BackendError("Unexpected JarvisLabs SSH key list response")
+
+ def add_ssh_key(self, public_key: str, key_name: str) -> None:
+ resp = self._make_request(
+ "POST",
+ "ssh/",
+ json={
+ "ssh_key": public_key,
+ "key_name": key_name,
+ },
+ )
+ _raise_if_unsuccessful(resp, "Failed to add JarvisLabs SSH key")
+
+ def add_ssh_key_if_needed(self, public_key: str) -> None:
+ normalized_key = _normalize_public_key(public_key)
+ for ssh_key in self.list_ssh_keys():
+ if _normalize_public_key(str(ssh_key.get("ssh_key", ""))) == normalized_key:
+ return
+ key_name = _get_ssh_key_name(normalized_key)
+ self.add_ssh_key(public_key=public_key, key_name=key_name)
+
+ def create_gpu_vm(
+ self,
+ *,
+ gpu_type: str,
+ num_gpus: int,
+ is_spot: bool,
+ storage: int,
+ region: str,
+ name: str,
+ ) -> str:
+ resp = self._make_request(
+ "POST",
+ "templates/vm/create",
+ region=region,
+ json={
+ "gpu_type": gpu_type,
+ "num_gpus": num_gpus,
+ "hdd": storage,
+ "region": region,
+ "name": name,
+ "is_spot": is_spot,
+ "duration": "hour",
+ "disk_type": "ssd",
+ "http_ports": "",
+ # JarvisLabs accepts script_id for VM creates, but live CPU/GPU VM tests
+ # showed it is not injected into cloud-init user-data/runcmd.
+ "script_id": None,
+ "script_args": "",
+ "fs_id": None,
+ "arguments": "",
+ },
+ )
+ return _get_created_machine_id(resp, "GPU VM creation")
+
+ def create_cpu_vm(
+ self,
+ *,
+ vcpus: int,
+ ram_gb: int,
+ storage: int,
+ region: str,
+ name: str,
+ ) -> str:
+ resp = self._make_request(
+ "POST",
+ "templates/vm/cpu/create",
+ region=region,
+ json={
+ "num_cpus": 1,
+ "vcpus": vcpus,
+ "ram_gb": ram_gb,
+ "hdd": storage,
+ "region": region,
+ "name": name,
+ "duration": "hour",
+ "disk_type": "ssd",
+ # Do not pass script_id here either; CPU VM create accepts it but ignores it.
+ },
+ )
+ return _get_created_machine_id(resp, "CPU VM creation")
+
+ def get_instance(self, machine_id: str) -> Optional[Dict[str, Any]]:
+ try:
+ resp = self._make_request("GET", f"users/fetch/{machine_id}")
+ except JarvisLabsNotFoundError:
+ return None
+ if not _is_successful(resp):
+ return None
+ if isinstance(resp, dict):
+ instance = resp.get("instance")
+ if isinstance(instance, dict):
+ return instance
+ return None
+
+ def get_instance_status(self, *, machine_id: str, region: str) -> Optional[Dict[str, Any]]:
+ try:
+ resp = self._make_request(
+ "GET",
+ "misc/status",
+ region=region,
+ params={"machine_id": machine_id},
+ )
+ except JarvisLabsNotFoundError:
+ return None
+ if isinstance(resp, dict):
+ return resp
+ return None
+
+ def destroy_instance(self, *, machine_id: str, region: str) -> None:
+ instance = self.get_instance(machine_id)
+ if instance is None:
+ return
+ endpoint = "templates/vm/destroy"
+ if is_cpu_vm(instance):
+ endpoint = "templates/vm/cpu/destroy"
+ elif _instance_template(instance) != "vm":
+ endpoint = "misc/destroy"
+
+ try:
+ resp = self._make_request(
+ "POST",
+ endpoint,
+ region=instance.get("region") or region,
+ params={"machine_id": machine_id},
+ )
+ except JarvisLabsNotFoundError:
+ return
+ _raise_if_unsuccessful(resp, "Failed to destroy JarvisLabs instance")
+
+ def _make_request(
+ self,
+ method: str,
+ path: str,
+ *,
+ json: Optional[Dict[str, Any]] = None,
+ params: Optional[Dict[str, Any]] = None,
+ region: Optional[str] = None,
+ ) -> Any:
+ try:
+ response = requests.request(
+ method=method,
+ url=self._url(path=path, region=region),
+ headers={"Authorization": f"Bearer {self.api_key}"},
+ json=json,
+ params=params,
+ timeout=TIMEOUT,
+ )
+ except requests.RequestException as e:
+ raise BackendError(f"JarvisLabs request failed: {e}") from e
+ if response.ok:
+ if not response.content:
+ return {}
+ try:
+ return response.json()
+ except ValueError as e:
+ raise BackendError("Unexpected non-JSON JarvisLabs response") from e
+ message = _get_response_error(response)
+ if response.status_code in [401, 403]:
+ raise BackendInvalidCredentialsError(fields=[["creds", "api_key"]])
+ if response.status_code == 404:
+ raise JarvisLabsNotFoundError(message)
+ if response.status_code in [400, 409] and _looks_like_no_capacity(message):
+ raise NoCapacityError(message)
+ raise BackendError(message)
+
+ def _url(self, *, path: str, region: Optional[str] = None) -> str:
+ if region is None:
+ base_url = API_URL
+ else:
+ # gpuhunt owns this allowlist because it filters JarvisLabs offers. Do not
+ # fall back for unknown regions: regional VM APIs use separate hosts and
+ # JarvisLabs does not expose endpoint discovery in server_meta.
+ base_url = JARVISLABS_REGION_URLS.get(region)
+ if base_url is None:
+ raise BackendError(
+ f"Unsupported JarvisLabs region {region!r}. "
+ "JarvisLabs does not expose provisioning endpoint discovery."
+ )
+ return base_url.rstrip("/") + "/" + path.lstrip("/")
+
+
+def is_cpu_vm(instance: Dict[str, Any]) -> bool:
+ return _instance_template(instance) == "vm" and str(instance.get("gpu_type")).upper() == "CPU"
+
+
+def _instance_template(instance: Dict[str, Any]) -> str:
+ return str(instance.get("template") or instance.get("framework") or "").lower()
+
+
+def _get_created_machine_id(resp: Any, operation: str) -> str:
+ _raise_if_unsuccessful(resp, f"JarvisLabs {operation} failed")
+ if isinstance(resp, dict):
+ machine_id = resp.get("machine_id")
+ if machine_id is not None:
+ return str(machine_id)
+ raise BackendError(f"JarvisLabs {operation} failed: missing machine_id")
+
+
+def _raise_if_unsuccessful(resp: Any, message: str) -> None:
+ if _is_successful(resp):
+ return
+ backend_message = _backend_message(resp)
+ if _looks_like_no_capacity(backend_message):
+ raise NoCapacityError(backend_message)
+ raise BackendError(f"{message}: {backend_message}")
+
+
+def _is_successful(resp: Any) -> bool:
+ if not isinstance(resp, dict):
+ return True
+ if "success" in resp:
+ return _coerce_bool(resp["success"])
+ if "sucess" in resp:
+ return _coerce_bool(resp["sucess"])
+ return True
+
+
+def _coerce_bool(value: Any) -> bool:
+ if isinstance(value, bool):
+ return value
+ if isinstance(value, str):
+ return value.strip().lower() in {"1", "true", "yes", "success"}
+ return bool(value)
+
+
+def _get_response_error(response: requests.Response) -> str:
+ try:
+ data = response.json()
+ except ValueError:
+ return response.text or f"HTTP {response.status_code}"
+ message = _backend_message(data)
+ return message or f"HTTP {response.status_code}"
+
+
+def _backend_message(resp: Any) -> str:
+ if isinstance(resp, dict):
+ detail = resp.get("detail")
+ if isinstance(detail, list):
+ return "; ".join(str(item.get("msg", item)) for item in detail)
+ return str(
+ resp.get("message")
+ or resp.get("error")
+ or resp.get("detail")
+ or resp.get("msg")
+ or resp
+ )
+ return str(resp)
+
+
+def _looks_like_no_capacity(message: str) -> bool:
+ message = message.lower()
+ return "capacity" in message or "available" in message or "stock" in message
+
+
+def _normalize_public_key(public_key: str) -> str:
+ return " ".join(public_key.strip().split()[:2])
+
+
+def _get_ssh_key_name(public_key: str) -> str:
+ return "dstack-" + hashlib.sha1(public_key.encode()).hexdigest()[:16]
diff --git a/src/dstack/_internal/core/backends/jarvislabs/backend.py b/src/dstack/_internal/core/backends/jarvislabs/backend.py
new file mode 100644
index 0000000000..ac47171bd6
--- /dev/null
+++ b/src/dstack/_internal/core/backends/jarvislabs/backend.py
@@ -0,0 +1,16 @@
+from dstack._internal.core.backends.base.backend import Backend
+from dstack._internal.core.backends.jarvislabs.compute import JarvisLabsCompute
+from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig
+from dstack._internal.core.models.backends.base import BackendType
+
+
+class JarvisLabsBackend(Backend):
+ TYPE = BackendType.JARVISLABS
+ COMPUTE_CLASS = JarvisLabsCompute
+
+ def __init__(self, config: JarvisLabsConfig):
+ self.config = config
+ self._compute = JarvisLabsCompute(self.config)
+
+ def compute(self) -> JarvisLabsCompute:
+ return self._compute
diff --git a/src/dstack/_internal/core/backends/jarvislabs/compute.py b/src/dstack/_internal/core/backends/jarvislabs/compute.py
new file mode 100644
index 0000000000..2e79902b45
--- /dev/null
+++ b/src/dstack/_internal/core/backends/jarvislabs/compute.py
@@ -0,0 +1,331 @@
+import shlex
+import subprocess
+import tempfile
+from collections.abc import Iterable
+from typing import List, Optional
+
+import gpuhunt
+from gpuhunt.providers.jarvislabs import JarvisLabsProvider
+
+from dstack._internal.core.backends.base.backend import Compute
+from dstack._internal.core.backends.base.compute import (
+ ComputeWithAllOffersCached,
+ ComputeWithCreateInstanceSupport,
+ ComputeWithInstanceVolumesSupport,
+ ComputeWithPrivilegedSupport,
+ generate_unique_instance_name,
+ get_shim_commands,
+)
+from dstack._internal.core.backends.base.offers import (
+ OfferModifier,
+ get_catalog_offers,
+ get_offers_disk_modifier,
+)
+from dstack._internal.core.backends.jarvislabs.api_client import JarvisLabsAPIClient
+from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig
+from dstack._internal.core.errors import ProvisioningError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.instances import (
+ InstanceAvailability,
+ InstanceConfiguration,
+ InstanceOfferWithAvailability,
+)
+from dstack._internal.core.models.placement import PlacementGroup
+from dstack._internal.core.models.resources import Memory, Range
+from dstack._internal.core.models.runs import JobProvisioningData, Requirements
+from dstack._internal.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+MAX_INSTANCE_NAME_LEN = 40
+# JarvisLabs VM storage is configurable through the `hdd` create parameter.
+MIN_DISK_SIZE = Memory.parse("100GB")
+CONFIGURABLE_DISK_SIZE = Range[Memory](min=MIN_DISK_SIZE, max=None)
+DEFAULT_USERNAME = "ubuntu"
+SSH_CONNECT_TIMEOUT_SECONDS = 10
+SSH_SETUP_TIMEOUT_SECONDS = 240
+SSH_LAUNCH_TIMEOUT_SECONDS = 60
+
+
+class JarvisLabsCompute(
+ ComputeWithAllOffersCached,
+ ComputeWithCreateInstanceSupport,
+ ComputeWithPrivilegedSupport,
+ ComputeWithInstanceVolumesSupport,
+ Compute,
+):
+ def __init__(self, config: JarvisLabsConfig):
+ super().__init__()
+ self.config = config
+ self.api_client = JarvisLabsAPIClient(config.creds.api_key)
+ self._catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
+ self._catalog.add_provider(JarvisLabsProvider(api_key=self.config.creds.api_key))
+
+ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
+ offers = get_catalog_offers(
+ backend=BackendType.JARVISLABS,
+ locations=self.config.regions or None,
+ catalog=self._catalog,
+ configurable_disk_size=CONFIGURABLE_DISK_SIZE,
+ )
+ return [
+ offer.with_availability(availability=InstanceAvailability.AVAILABLE)
+ for offer in offers
+ ]
+
+ def get_offers_modifiers(self, requirements: Requirements) -> Iterable[OfferModifier]:
+ return [get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, requirements)]
+
+ def create_instance(
+ self,
+ instance_offer: InstanceOfferWithAvailability,
+ instance_config: InstanceConfiguration,
+ placement_group: Optional[PlacementGroup],
+ ) -> JobProvisioningData:
+ instance_name = generate_unique_instance_name(
+ instance_config, max_length=MAX_INSTANCE_NAME_LEN
+ )
+ self.api_client.add_ssh_key_if_needed(instance_config.ssh_keys[0].public)
+ instance_id = None
+ try:
+ if instance_offer.instance.resources.gpus:
+ instance_id = self.api_client.create_gpu_vm(
+ gpu_type=_get_jarvislabs_gpu_type(instance_offer),
+ num_gpus=len(instance_offer.instance.resources.gpus),
+ is_spot=instance_offer.instance.resources.spot,
+ storage=_get_disk_size_gb(instance_offer),
+ region=instance_offer.region,
+ name=instance_name,
+ )
+ else:
+ instance_id = self.api_client.create_cpu_vm(
+ vcpus=instance_offer.instance.resources.cpus,
+ ram_gb=round(instance_offer.instance.resources.memory_mib / 1024),
+ storage=_get_disk_size_gb(instance_offer),
+ region=instance_offer.region,
+ name=instance_name,
+ )
+
+ except BaseException:
+ if instance_id is not None:
+ try:
+ self.api_client.destroy_instance(
+ machine_id=instance_id,
+ region=instance_offer.region,
+ )
+ except Exception:
+ logger.exception(
+ "Could not destroy failed JarvisLabs instance %s", instance_id
+ )
+ raise
+ return JobProvisioningData(
+ backend=instance_offer.backend,
+ instance_type=instance_offer.instance,
+ instance_id=instance_id,
+ hostname=None,
+ internal_ip=None,
+ region=instance_offer.region,
+ price=instance_offer.price,
+ username=DEFAULT_USERNAME,
+ ssh_port=22,
+ dockerized=True,
+ ssh_proxy=None,
+ backend_data=None,
+ )
+
+ def update_provisioning_data(
+ self,
+ provisioning_data: JobProvisioningData,
+ project_ssh_public_key: str,
+ project_ssh_private_key: str,
+ ):
+ instance = self.api_client.get_instance(provisioning_data.instance_id)
+ if instance is None:
+ status = self.api_client.get_instance_status(
+ machine_id=provisioning_data.instance_id,
+ region=provisioning_data.region,
+ )
+ if status is not None and str(status.get("status")).lower() == "failed":
+ _raise_failed_status(status)
+ return
+
+ status = str(instance.get("status")).lower()
+ if status == "failed":
+ _raise_failed_status(instance)
+ if status != "running":
+ return
+
+ hostname = instance.get("public_ip")
+ if not hostname:
+ return
+ username = _get_ssh_username(instance)
+ if not _start_runner(
+ hostname=hostname,
+ username=username,
+ project_ssh_private_key=project_ssh_private_key,
+ arch=provisioning_data.instance_type.resources.cpu_arch,
+ ):
+ return
+ provisioning_data.hostname = hostname
+ provisioning_data.username = username
+
+ def terminate_instance(
+ self, instance_id: str, region: str, backend_data: Optional[str] = None
+ ):
+ self.api_client.destroy_instance(machine_id=instance_id, region=region)
+
+
+def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> str:
+ gpu = instance_offer.instance.resources.gpus[0]
+ memory_gb = round(gpu.memory_mib / 1024)
+ if gpu.name == "A100" and memory_gb == 80:
+ return "A100-80GB"
+ return gpu.name
+
+
+def _get_disk_size_gb(instance_offer: InstanceOfferWithAvailability) -> int:
+ disk_size_gb = round(instance_offer.instance.resources.disk.size_mib / 1024)
+ return max(round(MIN_DISK_SIZE), disk_size_gb)
+
+
+def _format_failed_status(status: dict) -> str:
+ message = status.get("error") or "unknown error"
+ code = status.get("code")
+ if code is not None:
+ return f"JarvisLabs instance creation failed: {message} (code={code})"
+ return f"JarvisLabs instance creation failed: {message}"
+
+
+def _raise_failed_status(status: dict) -> None:
+ raise ProvisioningError(_format_failed_status(status), status)
+
+
+def _get_ssh_username(instance: dict) -> str:
+ ssh_command = instance.get("ssh_str") or instance.get("ssh_command")
+ if not isinstance(ssh_command, str):
+ return DEFAULT_USERNAME
+ try:
+ parts = shlex.split(ssh_command)
+ except ValueError:
+ return DEFAULT_USERNAME
+ for part in parts[1:]:
+ if part.startswith("-") or "@" not in part:
+ continue
+ return part.rsplit("@", 1)[0]
+ return DEFAULT_USERNAME
+
+
+def _start_runner(
+ hostname: str,
+ username: str,
+ project_ssh_private_key: str,
+ arch: Optional[str],
+) -> bool:
+ commands = get_shim_commands(arch=arch)
+ launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
+ try:
+ if not _setup_instance(
+ hostname=hostname,
+ username=username,
+ ssh_private_key=project_ssh_private_key,
+ ):
+ return False
+ return _launch_runner(
+ hostname=hostname,
+ username=username,
+ ssh_private_key=project_ssh_private_key,
+ launch_command=launch_command,
+ )
+ except Exception:
+ logger.exception("Failed to start dstack shim on JarvisLabs instance %s", hostname)
+ return False
+
+
+def _setup_instance(
+ hostname: str,
+ username: str,
+ ssh_private_key: str,
+) -> bool:
+ setup_commands = [
+ "mkdir -p ~/.dstack",
+ "if ! command -v curl >/dev/null 2>&1 || ! command -v docker >/dev/null 2>&1 || ! command -v jq >/dev/null 2>&1; then sudo apt-get update; fi",
+ "if ! command -v curl >/dev/null 2>&1; then sudo DEBIAN_FRONTEND=noninteractive apt-get install -y curl; fi",
+ "if ! command -v docker >/dev/null 2>&1; then sudo apt-get update && sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker.io; fi",
+ "if ! command -v jq >/dev/null 2>&1; then sudo apt-get update && sudo DEBIAN_FRONTEND=noninteractive apt-get install -y jq; fi",
+ "sudo systemctl enable --now docker || sudo service docker start || true",
+ ]
+ return _run_ssh_command(
+ hostname=hostname,
+ username=username,
+ ssh_private_key=ssh_private_key,
+ command=" && ".join(setup_commands),
+ timeout=SSH_SETUP_TIMEOUT_SECONDS,
+ )
+
+
+def _launch_runner(
+ hostname: str,
+ username: str,
+ ssh_private_key: str,
+ launch_command: str,
+) -> bool:
+ daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown"
+ return _run_ssh_command(
+ hostname=hostname,
+ username=username,
+ ssh_private_key=ssh_private_key,
+ command=daemonized_command,
+ timeout=SSH_LAUNCH_TIMEOUT_SECONDS,
+ )
+
+
+def _run_ssh_command(
+ hostname: str,
+ username: str,
+ ssh_private_key: str,
+ command: str,
+ timeout: int,
+) -> bool:
+ with tempfile.NamedTemporaryFile("w+") as f:
+ f.write(ssh_private_key)
+ f.flush()
+ try:
+ proc = subprocess.run(
+ [
+ "ssh",
+ "-F",
+ "none",
+ "-o",
+ "BatchMode=yes",
+ "-o",
+ f"ConnectTimeout={SSH_CONNECT_TIMEOUT_SECONDS}",
+ "-o",
+ "ConnectionAttempts=1",
+ "-o",
+ "StrictHostKeyChecking=no",
+ "-o",
+ "UserKnownHostsFile=/dev/null",
+ "-o",
+ "LogLevel=ERROR",
+ "-i",
+ f.name,
+ f"{username}@{hostname}",
+ command,
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ timeout=timeout,
+ )
+ except subprocess.TimeoutExpired:
+ logger.debug("Timed out running SSH command on JarvisLabs instance %s", hostname)
+ return False
+ if proc.returncode != 0:
+ logger.debug(
+ "SSH command failed on JarvisLabs instance %s: exit_code=%s stderr=%r",
+ hostname,
+ proc.returncode,
+ proc.stderr[-1000:],
+ )
+ return False
+ return True
diff --git a/src/dstack/_internal/core/backends/jarvislabs/configurator.py b/src/dstack/_internal/core/backends/jarvislabs/configurator.py
new file mode 100644
index 0000000000..ceaebeedf9
--- /dev/null
+++ b/src/dstack/_internal/core/backends/jarvislabs/configurator.py
@@ -0,0 +1,85 @@
+import json
+
+from gpuhunt.providers.jarvislabs import JARVISLABS_REGION_URLS
+
+from dstack._internal.core.backends.base.configurator import (
+ BackendRecord,
+ Configurator,
+ raise_invalid_credentials_error,
+)
+from dstack._internal.core.backends.jarvislabs import api_client
+from dstack._internal.core.backends.jarvislabs.backend import JarvisLabsBackend
+from dstack._internal.core.backends.jarvislabs.models import (
+ JarvisLabsBackendConfig,
+ JarvisLabsBackendConfigWithCreds,
+ JarvisLabsConfig,
+ JarvisLabsCreds,
+ JarvisLabsStoredConfig,
+)
+from dstack._internal.core.errors import ServerClientError
+from dstack._internal.core.models.backends.base import BackendType
+
+
+class JarvisLabsConfigurator(
+ Configurator[
+ JarvisLabsBackendConfig,
+ JarvisLabsBackendConfigWithCreds,
+ ]
+):
+ TYPE = BackendType.JARVISLABS
+ BACKEND_CLASS = JarvisLabsBackend
+
+ def validate_config(
+ self, config: JarvisLabsBackendConfigWithCreds, default_creds_enabled: bool
+ ):
+ self._validate_api_key(config.creds.api_key)
+ self._validate_regions(config.regions)
+
+ def create_backend(
+ self, project_name: str, config: JarvisLabsBackendConfigWithCreds
+ ) -> BackendRecord:
+ return BackendRecord(
+ config=JarvisLabsStoredConfig(
+ **JarvisLabsBackendConfig.__response__.parse_obj(config).dict()
+ ).json(),
+ auth=JarvisLabsCreds.parse_obj(config.creds).json(),
+ )
+
+ def get_backend_config_with_creds(
+ self, record: BackendRecord
+ ) -> JarvisLabsBackendConfigWithCreds:
+ config = self._get_config(record)
+ return JarvisLabsBackendConfigWithCreds.__response__.parse_obj(config)
+
+ def get_backend_config_without_creds(self, record: BackendRecord) -> JarvisLabsBackendConfig:
+ config = self._get_config(record)
+ return JarvisLabsBackendConfig.__response__.parse_obj(config)
+
+ def get_backend(self, record: BackendRecord) -> JarvisLabsBackend:
+ config = self._get_config(record)
+ return JarvisLabsBackend(config=config)
+
+ def _get_config(self, record: BackendRecord) -> JarvisLabsConfig:
+ return JarvisLabsConfig.__response__(
+ **json.loads(record.config),
+ creds=JarvisLabsCreds.parse_raw(record.auth),
+ )
+
+ def _validate_api_key(self, api_key: str):
+ client = api_client.JarvisLabsAPIClient(api_key=api_key)
+ if not client.validate_api_key():
+ raise_invalid_credentials_error(fields=[["creds", "api_key"]])
+
+ def _validate_regions(self, regions: list[str] | None):
+ if not regions:
+ return
+ invalid_regions = sorted(set(regions) - set(JARVISLABS_REGION_URLS))
+ if invalid_regions:
+ raise ServerClientError(
+ msg=(
+ f"Unsupported JarvisLabs regions: {invalid_regions}. "
+ f"Supported regions: {sorted(JARVISLABS_REGION_URLS)}. "
+ "JarvisLabs does not expose provisioning endpoint discovery."
+ ),
+ fields=[["regions"]],
+ )
diff --git a/src/dstack/_internal/core/backends/jarvislabs/models.py b/src/dstack/_internal/core/backends/jarvislabs/models.py
new file mode 100644
index 0000000000..dae710089e
--- /dev/null
+++ b/src/dstack/_internal/core/backends/jarvislabs/models.py
@@ -0,0 +1,47 @@
+from typing import Annotated, List, Literal, Optional, Union
+
+from pydantic import Field
+
+from dstack._internal.core.models.common import CoreModel
+
+
+class JarvisLabsAPIKeyCreds(CoreModel):
+ type: Annotated[Literal["api_key"], Field(description="The type of credentials")] = "api_key"
+ api_key: Annotated[str, Field(description="The JarvisLabs API key")]
+
+
+AnyJarvisLabsCreds = JarvisLabsAPIKeyCreds
+JarvisLabsCreds = AnyJarvisLabsCreds
+
+
+class JarvisLabsBackendConfig(CoreModel):
+ type: Annotated[
+ Literal["jarvislabs"],
+ Field(description="The type of backend"),
+ ] = "jarvislabs"
+ regions: Annotated[
+ Optional[List[str]],
+ Field(description="The list of JarvisLabs regions. Omit to use all regions"),
+ ] = None
+
+
+class JarvisLabsBackendConfigWithCreds(JarvisLabsBackendConfig):
+ creds: Annotated[AnyJarvisLabsCreds, Field(description="The credentials")]
+
+
+AnyJarvisLabsBackendConfig = Union[
+ JarvisLabsBackendConfig,
+ JarvisLabsBackendConfigWithCreds,
+]
+
+
+class JarvisLabsBackendFileConfigWithCreds(JarvisLabsBackendConfig):
+ creds: Annotated[AnyJarvisLabsCreds, Field(description="The credentials")]
+
+
+class JarvisLabsStoredConfig(JarvisLabsBackendConfig):
+ pass
+
+
+class JarvisLabsConfig(JarvisLabsStoredConfig):
+ creds: AnyJarvisLabsCreds
diff --git a/src/dstack/_internal/core/backends/models.py b/src/dstack/_internal/core/backends/models.py
index 36a7856e38..c21141378e 100644
--- a/src/dstack/_internal/core/backends/models.py
+++ b/src/dstack/_internal/core/backends/models.py
@@ -39,6 +39,11 @@
HotAisleBackendConfigWithCreds,
HotAisleBackendFileConfigWithCreds,
)
+from dstack._internal.core.backends.jarvislabs.models import (
+ JarvisLabsBackendConfig,
+ JarvisLabsBackendConfigWithCreds,
+ JarvisLabsBackendFileConfigWithCreds,
+)
from dstack._internal.core.backends.kubernetes.models import (
KubernetesBackendConfig,
KubernetesBackendConfigWithCreds,
@@ -89,6 +94,7 @@
BaseDigitalOceanBackendConfig,
GCPBackendConfig,
HotAisleBackendConfig,
+ JarvisLabsBackendConfig,
KubernetesBackendConfig,
LambdaBackendConfig,
NebiusBackendConfig,
@@ -115,6 +121,7 @@
BaseDigitalOceanBackendConfigWithCreds,
GCPBackendConfigWithCreds,
HotAisleBackendConfigWithCreds,
+ JarvisLabsBackendConfigWithCreds,
KubernetesBackendConfigWithCreds,
LambdaBackendConfigWithCreds,
OCIBackendConfigWithCreds,
@@ -139,6 +146,7 @@
BaseDigitalOceanBackendConfigWithCreds,
GCPBackendFileConfigWithCreds,
HotAisleBackendFileConfigWithCreds,
+ JarvisLabsBackendFileConfigWithCreds,
KubernetesBackendFileConfigWithCreds,
LambdaBackendConfigWithCreds,
OCIBackendConfigWithCreds,
diff --git a/src/dstack/_internal/core/models/backends/base.py b/src/dstack/_internal/core/models/backends/base.py
index 2e8eb898ee..dd11ae67a3 100644
--- a/src/dstack/_internal/core/models/backends/base.py
+++ b/src/dstack/_internal/core/models/backends/base.py
@@ -15,6 +15,7 @@ class BackendType(str, enum.Enum):
DSTACK (BackendType): dstack Sky
GCP (BackendType): Google Cloud Platform
HOTAISLE (BackendType): Hot Aisle
+ JARVISLABS (BackendType): JarvisLabs
KUBERNETES (BackendType): Kubernetes
LAMBDA (BackendType): Lambda Cloud
NEBIUS (BackendType): Nebius AI Cloud
@@ -38,6 +39,7 @@ class BackendType(str, enum.Enum):
DSTACK = "dstack"
GCP = "gcp"
HOTAISLE = "hotaisle"
+ JARVISLABS = "jarvislabs"
KUBERNETES = "kubernetes"
LAMBDA = "lambda"
LOCAL = "local"
diff --git a/src/tests/_internal/core/backends/jarvislabs/__init__.py b/src/tests/_internal/core/backends/jarvislabs/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/tests/_internal/core/backends/jarvislabs/test_api_client.py b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py
new file mode 100644
index 0000000000..09980d6ed1
--- /dev/null
+++ b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py
@@ -0,0 +1,186 @@
+import pytest
+import requests
+
+from dstack._internal.core.backends.jarvislabs.api_client import (
+ JarvisLabsAPIClient,
+ is_cpu_vm,
+)
+from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError
+
+
+def test_validate_api_key_returns_false_on_unauthorized(requests_mock):
+ requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=401)
+
+ assert JarvisLabsAPIClient("bad").validate_api_key() is False
+
+
+def test_get_user_info_raises_invalid_credentials_on_forbidden(requests_mock):
+ requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=403)
+
+ with pytest.raises(BackendInvalidCredentialsError):
+ JarvisLabsAPIClient("bad").get_user_info()
+
+
+def test_make_request_wraps_request_errors(requests_mock):
+ requests_mock.get(
+ "https://backendprod.jarvislabs.net/users/user_info",
+ exc=requests.ConnectTimeout("timed out"),
+ )
+
+ with pytest.raises(BackendError, match="JarvisLabs request failed"):
+ JarvisLabsAPIClient("token").get_user_info()
+
+
+def test_get_user_info_rejects_non_json_success_response(requests_mock):
+ requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", text="ok")
+
+ with pytest.raises(BackendError, match="Unexpected non-JSON JarvisLabs response"):
+ JarvisLabsAPIClient("token").get_user_info()
+
+
+def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock):
+ public_key = "ssh-rsa AAAA test-comment"
+ requests_mock.get(
+ "https://backendprod.jarvislabs.net/ssh/",
+ json=[{"ssh_key": "ssh-rsa AAAA another-comment", "key_name": "existing"}],
+ )
+
+ JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key)
+
+ assert requests_mock.call_count == 1
+
+
+def test_add_ssh_key_if_needed_adds_missing_key(requests_mock):
+ public_key = "ssh-rsa AAAA test-comment"
+ requests_mock.get("https://backendprod.jarvislabs.net/ssh/", json=[])
+ requests_mock.post("https://backendprod.jarvislabs.net/ssh/", json={"success": True})
+
+ JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key)
+
+ assert requests_mock.last_request.json() == {
+ "ssh_key": public_key,
+ "key_name": "dstack-36deb09319b2204c",
+ }
+
+
+def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock):
+ requests_mock.post(
+ "https://backendn.jarvislabs.net/templates/vm/create",
+ json={"machine_id": 123},
+ )
+
+ machine_id = JarvisLabsAPIClient("token").create_gpu_vm(
+ gpu_type="A100-80GB",
+ num_gpus=1,
+ is_spot=False,
+ storage=250,
+ region="india-noida-01",
+ name="dstack-test",
+ )
+
+ assert machine_id == "123"
+ assert requests_mock.last_request.headers["Authorization"] == "Bearer token"
+ assert requests_mock.last_request.json() == {
+ "gpu_type": "A100-80GB",
+ "num_gpus": 1,
+ "hdd": 250,
+ "region": "india-noida-01",
+ "name": "dstack-test",
+ "is_spot": False,
+ "duration": "hour",
+ "disk_type": "ssd",
+ "http_ports": "",
+ "script_id": None,
+ "script_args": "",
+ "fs_id": None,
+ "arguments": "",
+ }
+
+
+def test_create_gpu_vm_rejects_unsupported_region(requests_mock):
+ with pytest.raises(BackendError, match="Unsupported JarvisLabs region"):
+ JarvisLabsAPIClient("token").create_gpu_vm(
+ gpu_type="H100",
+ num_gpus=1,
+ is_spot=False,
+ storage=100,
+ region="unknown-region",
+ name="dstack-test",
+ )
+
+ assert requests_mock.call_count == 0
+
+
+def test_create_gpu_vm_sets_spot_flag(requests_mock):
+ requests_mock.post(
+ "https://backendn.jarvislabs.net/templates/vm/create",
+ json={"machine_id": 123},
+ )
+
+ JarvisLabsAPIClient("token").create_gpu_vm(
+ gpu_type="L4",
+ num_gpus=1,
+ is_spot=True,
+ storage=100,
+ region="india-noida-01",
+ name="dstack-spot",
+ )
+
+ assert requests_mock.last_request.json()["is_spot"] is True
+
+
+def test_create_cpu_vm_posts_to_regional_cpu_vm_endpoint(requests_mock):
+ requests_mock.post(
+ "https://backendn.jarvislabs.net/templates/vm/cpu/create",
+ json={"machine_id": 456},
+ )
+
+ machine_id = JarvisLabsAPIClient("token").create_cpu_vm(
+ vcpus=4,
+ ram_gb=16,
+ storage=100,
+ region="india-noida-01",
+ name="dstack-cpu",
+ )
+
+ assert machine_id == "456"
+ assert requests_mock.last_request.json() == {
+ "num_cpus": 1,
+ "vcpus": 4,
+ "ram_gb": 16,
+ "hdd": 100,
+ "region": "india-noida-01",
+ "name": "dstack-cpu",
+ "duration": "hour",
+ "disk_type": "ssd",
+ }
+
+
+def test_destroy_instance_uses_cpu_vm_endpoint_for_cpu_vm(requests_mock):
+ requests_mock.get(
+ "https://backendprod.jarvislabs.net/users/fetch/456",
+ json={
+ "success": True,
+ "instance": {
+ "machine_id": 456,
+ "template": "vm",
+ "gpu_type": "CPU",
+ "region": "india-noida-01",
+ },
+ },
+ )
+ requests_mock.post(
+ "https://backendn.jarvislabs.net/templates/vm/cpu/destroy",
+ json={"success": True},
+ )
+
+ JarvisLabsAPIClient("token").destroy_instance(machine_id="456", region="india-noida-01")
+
+ assert requests_mock.last_request.qs == {"machine_id": ["456"]}
+
+
+def test_is_cpu_vm_requires_vm_template_and_cpu_gpu_type():
+ assert is_cpu_vm({"template": "vm", "gpu_type": "CPU"})
+ assert is_cpu_vm({"framework": "VM", "gpu_type": "CPU"})
+ assert not is_cpu_vm({"template": "pytorch", "gpu_type": "CPU"})
+ assert not is_cpu_vm({"template": "vm", "gpu_type": "H100"})
diff --git a/src/tests/_internal/core/backends/jarvislabs/test_compute.py b/src/tests/_internal/core/backends/jarvislabs/test_compute.py
new file mode 100644
index 0000000000..27be084e1f
--- /dev/null
+++ b/src/tests/_internal/core/backends/jarvislabs/test_compute.py
@@ -0,0 +1,372 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from dstack._internal.core.backends.jarvislabs.compute import (
+ CONFIGURABLE_DISK_SIZE,
+ JarvisLabsCompute,
+ _get_disk_size_gb,
+ _get_jarvislabs_gpu_type,
+ _get_ssh_username,
+)
+from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig, JarvisLabsCreds
+from dstack._internal.core.errors import NoCapacityError, ProvisioningError
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.instances import (
+ Disk,
+ Gpu,
+ InstanceAvailability,
+ InstanceConfiguration,
+ InstanceOffer,
+ InstanceOfferWithAvailability,
+ InstanceType,
+ Resources,
+ SSHKey,
+)
+from dstack._internal.core.models.resources import ResourcesSpec
+from dstack._internal.core.models.runs import JobProvisioningData, Requirements
+
+
+def _compute() -> JarvisLabsCompute:
+ compute = JarvisLabsCompute(
+ JarvisLabsConfig(creds=JarvisLabsCreds(api_key="test"), regions=["india-noida-01"])
+ )
+ compute.api_client = MagicMock()
+ compute.api_client.get_instance_status.return_value = {"status": "Running"}
+ return compute
+
+
+def _instance_config() -> InstanceConfiguration:
+ return InstanceConfiguration(
+ project_name="test-project",
+ instance_name="jarvislabs-test",
+ user="test-user",
+ ssh_keys=[SSHKey(public="ssh-rsa AAAA test")],
+ )
+
+
+def _gpu_offer(
+ *,
+ gpu_name: str = "A100",
+ gpu_memory_mib: int = 80 * 1024,
+ disk_size_mib: int = 250 * 1024,
+ spot: bool = False,
+) -> InstanceOfferWithAvailability:
+ return InstanceOfferWithAvailability(
+ backend=BackendType.JARVISLABS,
+ instance=InstanceType(
+ name=f"{gpu_name}-1x",
+ resources=Resources(
+ cpus=28,
+ memory_mib=112 * 1024,
+ gpus=[Gpu(name=gpu_name, memory_mib=gpu_memory_mib)],
+ spot=spot,
+ disk=Disk(size_mib=disk_size_mib),
+ ),
+ ),
+ region="india-noida-01",
+ price=1.49,
+ availability=InstanceAvailability.AVAILABLE,
+ )
+
+
+def _cpu_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOfferWithAvailability:
+ return InstanceOfferWithAvailability(
+ backend=BackendType.JARVISLABS,
+ instance=InstanceType(
+ name="cpu-4x16",
+ resources=Resources(
+ cpus=4,
+ memory_mib=16 * 1024,
+ gpus=[],
+ spot=False,
+ disk=Disk(size_mib=disk_size_mib),
+ ),
+ ),
+ region="india-noida-01",
+ price=0.0992,
+ availability=InstanceAvailability.AVAILABLE,
+ )
+
+
+def _cpu_catalog_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOffer:
+ offer = _cpu_offer(disk_size_mib=disk_size_mib)
+ return InstanceOffer(
+ backend=offer.backend,
+ instance=offer.instance,
+ region=offer.region,
+ price=offer.price,
+ )
+
+
+def test_get_jarvislabs_gpu_type_reconstructs_a100_80gb():
+ assert _get_jarvislabs_gpu_type(_gpu_offer()) == "A100-80GB"
+ assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_memory_mib=40 * 1024)) == "A100"
+ assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_name="H100")) == "H100"
+
+
+def test_get_disk_size_gb_clamps_to_jarvislabs_vm_minimum():
+ assert _get_disk_size_gb(_cpu_offer(disk_size_mib=10 * 1024)) == 100
+ assert _get_disk_size_gb(_gpu_offer(disk_size_mib=250 * 1024)) == 250
+
+
+def test_get_all_offers_uses_configurable_disk_size():
+ compute = _compute()
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute.get_catalog_offers",
+ return_value=[_cpu_catalog_offer()],
+ ) as m:
+ offers = compute.get_all_offers_with_availability()
+
+ assert len(offers) == 1
+ assert offers[0].availability == InstanceAvailability.AVAILABLE
+ m.assert_called_once_with(
+ backend=BackendType.JARVISLABS,
+ locations=["india-noida-01"],
+ catalog=compute._catalog,
+ configurable_disk_size=CONFIGURABLE_DISK_SIZE,
+ )
+
+
+def test_get_offers_reuses_all_offers_cache_and_modifies_disk_size():
+ compute = _compute()
+ compute.get_all_offers_with_availability = MagicMock(
+ return_value=[_cpu_offer(disk_size_mib=100 * 1024)]
+ )
+
+ offers_250gb = list(compute.get_offers(Requirements(resources=ResourcesSpec(disk="250GB"))))
+ offers_300gb = list(compute.get_offers(Requirements(resources=ResourcesSpec(disk="300GB"))))
+
+ assert len(offers_250gb) == 1
+ assert offers_250gb[0].instance.resources.disk.size_mib == 250 * 1024
+ assert len(offers_300gb) == 1
+ assert offers_300gb[0].instance.resources.disk.size_mib == 300 * 1024
+ compute.get_all_offers_with_availability.assert_called_once()
+
+
+def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm():
+ compute = _compute()
+ compute.api_client.create_gpu_vm.return_value = "123"
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name",
+ return_value="dstack-test",
+ ):
+ provisioning_data = compute.create_instance(_gpu_offer(), _instance_config(), None)
+
+ compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test")
+ compute.api_client.create_gpu_vm.assert_called_once_with(
+ gpu_type="A100-80GB",
+ num_gpus=1,
+ is_spot=False,
+ storage=250,
+ region="india-noida-01",
+ name="dstack-test",
+ )
+ assert provisioning_data.instance_id == "123"
+ assert provisioning_data.username == "ubuntu"
+ assert provisioning_data.dockerized is True
+ assert provisioning_data.backend_data is None
+ compute.api_client.get_instance_status.assert_not_called()
+
+
+def test_create_gpu_instance_passes_spot_flag():
+ compute = _compute()
+ compute.api_client.create_gpu_vm.return_value = "123"
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name",
+ return_value="dstack-test",
+ ):
+ compute.create_instance(_gpu_offer(spot=True), _instance_config(), None)
+
+ compute.api_client.create_gpu_vm.assert_called_once_with(
+ gpu_type="A100-80GB",
+ num_gpus=1,
+ is_spot=True,
+ storage=250,
+ region="india-noida-01",
+ name="dstack-test",
+ )
+
+
+def test_create_cpu_instance_registers_ssh_key_and_creates_cpu_vm():
+ compute = _compute()
+ compute.api_client.create_cpu_vm.return_value = "456"
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name",
+ return_value="dstack-cpu",
+ ):
+ provisioning_data = compute.create_instance(_cpu_offer(), _instance_config(), None)
+
+ compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test")
+ compute.api_client.create_cpu_vm.assert_called_once_with(
+ vcpus=4,
+ ram_gb=16,
+ storage=100,
+ region="india-noida-01",
+ name="dstack-cpu",
+ )
+ assert provisioning_data.instance_id == "456"
+ assert provisioning_data.backend_data is None
+
+
+def test_update_provisioning_data_sets_hostname_and_starts_runner():
+ compute = _compute()
+ compute.api_client.get_instance.return_value = {
+ "machine_id": 123,
+ "status": "Running",
+ "public_ip": "203.0.113.10",
+ "ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10",
+ }
+ provisioning_data = JobProvisioningData(
+ backend=BackendType.JARVISLABS,
+ instance_type=_gpu_offer().instance,
+ instance_id="123",
+ region="india-noida-01",
+ price=1.49,
+ username="ubuntu",
+ ssh_port=22,
+ dockerized=True,
+ )
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute._start_runner", return_value=True
+ ) as m:
+ compute.update_provisioning_data(
+ provisioning_data,
+ project_ssh_public_key="ssh-rsa AAAA test",
+ project_ssh_private_key="private-key",
+ )
+
+ assert provisioning_data.hostname == "203.0.113.10"
+ assert provisioning_data.username == "ubuntu"
+ m.assert_called_once_with(
+ hostname="203.0.113.10",
+ username="ubuntu",
+ project_ssh_private_key="private-key",
+ arch=None,
+ )
+
+
+def test_update_provisioning_data_does_not_set_hostname_until_runner_starts():
+ compute = _compute()
+ compute.api_client.get_instance.return_value = {
+ "machine_id": 123,
+ "status": "Running",
+ "public_ip": "203.0.113.10",
+ "ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10",
+ }
+ provisioning_data = JobProvisioningData(
+ backend=BackendType.JARVISLABS,
+ instance_type=_gpu_offer().instance,
+ instance_id="123",
+ region="india-noida-01",
+ price=1.49,
+ username="ubuntu",
+ ssh_port=22,
+ dockerized=True,
+ )
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute._start_runner", return_value=False
+ ):
+ compute.update_provisioning_data(
+ provisioning_data,
+ project_ssh_public_key="ssh-rsa AAAA test",
+ project_ssh_private_key="private-key",
+ )
+
+ assert provisioning_data.hostname is None
+
+
+def test_get_ssh_username_parses_jarvislabs_ssh_command():
+ assert (
+ _get_ssh_username({"ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10"})
+ == "ubuntu"
+ )
+ assert _get_ssh_username({"ssh_str": "ssh -p 22 root@203.0.113.10"}) == "root"
+ assert _get_ssh_username({}) == "ubuntu"
+
+
+def test_terminate_instance_delegates_to_api_client():
+ compute = _compute()
+
+ compute.terminate_instance("123", "india-noida-01")
+
+ compute.api_client.destroy_instance.assert_called_once_with(
+ machine_id="123",
+ region="india-noida-01",
+ )
+
+
+def test_create_instance_propagates_create_failure_without_cleanup():
+ compute = _compute()
+ compute.api_client.create_gpu_vm.side_effect = NoCapacityError(
+ "L4 not available at this moment, please try again later"
+ )
+
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name",
+ return_value="dstack-test",
+ ):
+ with pytest.raises(NoCapacityError):
+ compute.create_instance(_gpu_offer(spot=True), _instance_config(), None)
+
+ compute.api_client.destroy_instance.assert_not_called()
+
+
+def test_update_provisioning_data_raises_provisioning_error_from_failed_capacity_status():
+ compute = _compute()
+ compute.api_client.get_instance.return_value = None
+ compute.api_client.get_instance_status.return_value = {
+ "status": "Failed",
+ "error": "L4 not available at this moment, please try again later",
+ "code": 404,
+ }
+ provisioning_data = JobProvisioningData(
+ backend=BackendType.JARVISLABS,
+ instance_type=_gpu_offer().instance,
+ instance_id="123",
+ region="india-noida-01",
+ price=1.49,
+ username="ubuntu",
+ ssh_port=22,
+ dockerized=True,
+ )
+
+ with pytest.raises(ProvisioningError):
+ compute.update_provisioning_data(
+ provisioning_data,
+ project_ssh_public_key="ssh-rsa AAAA test",
+ project_ssh_private_key="private-key",
+ )
+
+
+def test_update_provisioning_data_raises_provisioning_error_from_failed_status():
+ compute = _compute()
+ compute.api_client.get_instance.return_value = None
+ compute.api_client.get_instance_status.return_value = {
+ "status": "Failed",
+ "error": "image setup failed",
+ "code": 500,
+ }
+ provisioning_data = JobProvisioningData(
+ backend=BackendType.JARVISLABS,
+ instance_type=_gpu_offer().instance,
+ instance_id="123",
+ region="india-noida-01",
+ price=1.49,
+ username="ubuntu",
+ ssh_port=22,
+ dockerized=True,
+ )
+
+ with pytest.raises(ProvisioningError):
+ compute.update_provisioning_data(
+ provisioning_data,
+ project_ssh_public_key="ssh-rsa AAAA test",
+ project_ssh_private_key="private-key",
+ )
diff --git a/src/tests/_internal/core/backends/jarvislabs/test_configurator.py b/src/tests/_internal/core/backends/jarvislabs/test_configurator.py
new file mode 100644
index 0000000000..1da92e2600
--- /dev/null
+++ b/src/tests/_internal/core/backends/jarvislabs/test_configurator.py
@@ -0,0 +1,54 @@
+from unittest.mock import patch
+
+import pytest
+
+from dstack._internal.core.backends.jarvislabs.configurator import JarvisLabsConfigurator
+from dstack._internal.core.backends.jarvislabs.models import (
+ JarvisLabsBackendConfigWithCreds,
+ JarvisLabsCreds,
+)
+from dstack._internal.core.errors import BackendInvalidCredentialsError, ServerClientError
+
+
+class TestJarvisLabsConfigurator:
+ def test_validate_config_valid(self):
+ config = JarvisLabsBackendConfigWithCreds(
+ creds=JarvisLabsCreds(api_key="valid"),
+ regions=["india-noida-01"],
+ )
+ with patch(
+ "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key"
+ ) as validate_mock:
+ validate_mock.return_value = True
+ JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True)
+
+ def test_validate_config_invalid_creds(self):
+ config = JarvisLabsBackendConfigWithCreds(
+ creds=JarvisLabsCreds(api_key="invalid"),
+ regions=["india-noida-01"],
+ )
+ with (
+ patch(
+ "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key"
+ ) as validate_mock,
+ pytest.raises(BackendInvalidCredentialsError) as exc_info,
+ ):
+ validate_mock.return_value = False
+ JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True)
+ assert exc_info.value.fields == [["creds", "api_key"]]
+
+ def test_validate_config_unsupported_region(self):
+ config = JarvisLabsBackendConfigWithCreds(
+ creds=JarvisLabsCreds(api_key="valid"),
+ regions=["unknown-region"],
+ )
+ with (
+ patch(
+ "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key"
+ ) as validate_mock,
+ pytest.raises(ServerClientError) as exc_info,
+ ):
+ validate_mock.return_value = True
+ JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True)
+ assert exc_info.value.fields == [["regions"]]
+ assert "Unsupported JarvisLabs regions" in exc_info.value.msg
diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py
index 47295748d3..79fb13667e 100644
--- a/src/tests/_internal/server/routers/test_backends.py
+++ b/src/tests/_internal/server/routers/test_backends.py
@@ -92,6 +92,7 @@ async def test_returns_backend_types(self, client: AsyncClient):
"digitalocean",
"gcp",
"hotaisle",
+ "jarvislabs",
"kubernetes",
"lambda",
"nebius",