From 8861fe29aa0926790a681129803158a9eafdbe2b Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Tue, 12 May 2026 22:18:02 +0200 Subject: [PATCH 1/5] Add JarvisLabs backend --- mkdocs/docs/concepts/backends.md | 20 + mkdocs/docs/reference/server/config.yml.md | 17 + pyproject.toml | 5 +- .../_internal/core/backends/configurators.py | 9 + .../core/backends/jarvislabs/__init__.py | 0 .../core/backends/jarvislabs/api_client.py | 298 ++++++++++++++ .../core/backends/jarvislabs/backend.py | 16 + .../core/backends/jarvislabs/compute.py | 364 ++++++++++++++++++ .../core/backends/jarvislabs/configurator.py | 85 ++++ .../core/backends/jarvislabs/models.py | 47 +++ src/dstack/_internal/core/backends/models.py | 8 + .../_internal/core/models/backends/base.py | 2 + .../core/backends/jarvislabs/__init__.py | 0 .../backends/jarvislabs/test_api_client.py | 186 +++++++++ .../core/backends/jarvislabs/test_compute.py | 310 +++++++++++++++ .../backends/jarvislabs/test_configurator.py | 54 +++ .../_internal/server/routers/test_backends.py | 1 + 17 files changed, 1421 insertions(+), 1 deletion(-) create mode 100644 src/dstack/_internal/core/backends/jarvislabs/__init__.py create mode 100644 src/dstack/_internal/core/backends/jarvislabs/api_client.py create mode 100644 src/dstack/_internal/core/backends/jarvislabs/backend.py create mode 100644 src/dstack/_internal/core/backends/jarvislabs/compute.py create mode 100644 src/dstack/_internal/core/backends/jarvislabs/configurator.py create mode 100644 src/dstack/_internal/core/backends/jarvislabs/models.py create mode 100644 src/tests/_internal/core/backends/jarvislabs/__init__.py create mode 100644 src/tests/_internal/core/backends/jarvislabs/test_api_client.py create mode 100644 src/tests/_internal/core/backends/jarvislabs/test_compute.py create mode 100644 src/tests/_internal/core/backends/jarvislabs/test_configurator.py diff --git a/mkdocs/docs/concepts/backends.md b/mkdocs/docs/concepts/backends.md index 8f7b3325d8..e6483fd3ab 100644 
--- a/mkdocs/docs/concepts/backends.md +++ b/mkdocs/docs/concepts/backends.md @@ -918,6 +918,26 @@ projects: +### JarvisLabs + +Log into your [JarvisLabs](https://cloud.jarvislabs.ai/) account and create an API key. + +Then, go ahead and configure the backend: + +
+ +```yaml +projects: +- name: main + backends: + - type: jarvislabs + creds: + type: api_key + api_key: ... +``` + +
+ ### CloudRift Log into your [CloudRift](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar and click the button to create a new API key. diff --git a/mkdocs/docs/reference/server/config.yml.md b/mkdocs/docs/reference/server/config.yml.md index 80e48b028e..4e515aae5a 100644 --- a/mkdocs/docs/reference/server/config.yml.md +++ b/mkdocs/docs/reference/server/config.yml.md @@ -369,6 +369,23 @@ to configure [backends](../../concepts/backends.md) and other [server-level sett type: required: true +##### `projects[n].backends[type=jarvislabs]` { #jarvislabs data-toc-label="jarvislabs" } + +#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsBackendFileConfigWithCreds + overrides: + show_root_heading: false + type: + required: true + item_id_prefix: jarvislabs- + +###### `projects[n].backends[type=jarvislabs].creds` { #jarvislabs-creds data-toc-label="creds" } + +#SCHEMA# dstack._internal.core.backends.jarvislabs.models.JarvisLabsAPIKeyCreds + overrides: + show_root_heading: false + type: + required: true + ##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" } #SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds diff --git a/pyproject.toml b/pyproject.toml index 4d19b23d75..617f750bbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "python-multipart>=0.0.16", "filelock", "psutil", - "gpuhunt==0.1.21", + "gpuhunt @ git+https://github.com/dstackai/gpuhunt.git@jarvislabs", "argcomplete>=3.5.0", "ignore-python>=0.2.0", "orjson", @@ -67,6 +67,9 @@ artifacts = [ "src/dstack/_internal/server/statics/**", ] +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.metadata.hooks.fancy-pypi-readme] content-type = "text/markdown" diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py index 75a4a86abb..cdeac7f608 100644 --- 
a/src/dstack/_internal/core/backends/configurators.py +++ b/src/dstack/_internal/core/backends/configurators.py @@ -87,6 +87,15 @@ except ImportError: pass +try: + from dstack._internal.core.backends.jarvislabs.configurator import ( + JarvisLabsConfigurator, + ) + + _CONFIGURATOR_CLASSES.append(JarvisLabsConfigurator) +except ImportError: + pass + try: from dstack._internal.core.backends.kubernetes.configurator import ( KubernetesConfigurator, diff --git a/src/dstack/_internal/core/backends/jarvislabs/__init__.py b/src/dstack/_internal/core/backends/jarvislabs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/_internal/core/backends/jarvislabs/api_client.py b/src/dstack/_internal/core/backends/jarvislabs/api_client.py new file mode 100644 index 0000000000..d55319e5bf --- /dev/null +++ b/src/dstack/_internal/core/backends/jarvislabs/api_client.py @@ -0,0 +1,298 @@ +import hashlib +from typing import Any, Dict, List, Optional + +import requests +from gpuhunt.providers.jarvislabs import API_URL, JARVISLABS_REGION_URLS + +from dstack._internal.core.errors import ( + BackendError, + BackendInvalidCredentialsError, + NoCapacityError, +) + +TIMEOUT = 120 + + +class JarvisLabsNotFoundError(BackendError): + pass + + +class JarvisLabsAPIClient: + def __init__(self, api_key: str): + self.api_key = api_key + + def validate_api_key(self) -> bool: + try: + self.get_user_info() + except BackendInvalidCredentialsError: + return False + return True + + def get_user_info(self) -> Dict[str, Any]: + resp = self._make_request("GET", "users/user_info") + if not isinstance(resp, dict): + raise BackendError("Unexpected JarvisLabs user_info response") + return resp + + def list_ssh_keys(self) -> List[Dict[str, Any]]: + resp = self._make_request("GET", "ssh/") + if isinstance(resp, list): + return resp + raise BackendError("Unexpected JarvisLabs SSH key list response") + + def add_ssh_key(self, public_key: str, key_name: str) -> None: + resp = 
def add_ssh_key_if_needed(self, public_key: str) -> None:
    """Upload ``public_key`` unless an equivalent key is already registered.

    Keys are compared in the normalized form produced by
    ``_normalize_public_key``, so two keys that differ only in the part
    dropped by that helper are treated as the same key. When no registered
    key matches, the original (un-normalized) key is uploaded under a name
    derived from the normalized key by ``_get_ssh_key_name``.

    Args:
        public_key: The public key to make available on the account.
    """
    wanted = _normalize_public_key(public_key)
    already_registered = any(
        _normalize_public_key(str(entry.get("ssh_key", ""))) == wanted
        for entry in self.list_ssh_keys()
    )
    if already_registered:
        return
    self.add_ssh_key(public_key=public_key, key_name=_get_ssh_key_name(wanted))
def destroy_instance(self, *, machine_id: str, region: str) -> None:
    """Destroy the instance identified by ``machine_id``.

    The instance record is looked up first via ``get_instance``; when that
    lookup yields ``None`` the method returns without issuing a destroy
    request. The destroy endpoint depends on the record:

    * CPU VM (per ``is_cpu_vm``)   -> ``templates/vm/cpu/destroy``
    * any other ``vm`` template    -> ``templates/vm/destroy``
    * every other template         -> ``misc/destroy``

    Args:
        machine_id: Identifier of the instance to destroy.
        region: Region used for the request when the instance record does
            not carry a (truthy) ``region`` of its own.

    Raises:
        Whatever ``_make_request`` raises, except ``JarvisLabsNotFoundError``
        (which makes the method return silently), and whatever
        ``_raise_if_unsuccessful`` raises for an unsuccessful response body.
    """
    record = self.get_instance(machine_id)
    if record is None:
        return

    if is_cpu_vm(record):
        destroy_path = "templates/vm/cpu/destroy"
    elif _instance_template(record) == "vm":
        destroy_path = "templates/vm/destroy"
    else:
        destroy_path = "misc/destroy"

    # Prefer the region recorded on the instance over the caller's value.
    target_region = record.get("region") or region
    try:
        result = self._make_request(
            "POST",
            destroy_path,
            region=target_region,
            params={"machine_id": machine_id},
        )
    except JarvisLabsNotFoundError:
        return
    _raise_if_unsuccessful(result, "Failed to destroy JarvisLabs instance")
allowlist because it filters JarvisLabs offers. Do not + # fall back for unknown regions: regional VM APIs use separate hosts and + # JarvisLabs does not expose endpoint discovery in server_meta. + base_url = JARVISLABS_REGION_URLS.get(region) + if base_url is None: + raise BackendError( + f"Unsupported JarvisLabs region {region!r}. " + "JarvisLabs does not expose provisioning endpoint discovery." + ) + return base_url.rstrip("/") + "/" + path.lstrip("/") + + +def is_cpu_vm(instance: Dict[str, Any]) -> bool: + return _instance_template(instance) == "vm" and str(instance.get("gpu_type")).upper() == "CPU" + + +def _instance_template(instance: Dict[str, Any]) -> str: + return str(instance.get("template") or instance.get("framework") or "").lower() + + +def _get_created_machine_id(resp: Any, operation: str) -> str: + _raise_if_unsuccessful(resp, f"JarvisLabs {operation} failed") + if isinstance(resp, dict): + machine_id = resp.get("machine_id") + if machine_id is not None: + return str(machine_id) + raise BackendError(f"JarvisLabs {operation} failed: missing machine_id") + + +def _raise_if_unsuccessful(resp: Any, message: str) -> None: + if _is_successful(resp): + return + backend_message = _backend_message(resp) + if _looks_like_no_capacity(backend_message): + raise NoCapacityError(backend_message) + raise BackendError(f"{message}: {backend_message}") + + +def _is_successful(resp: Any) -> bool: + if not isinstance(resp, dict): + return True + if "success" in resp: + return _coerce_bool(resp["success"]) + if "sucess" in resp: + return _coerce_bool(resp["sucess"]) + return True + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "success"} + return bool(value) + + +def _get_response_error(response: requests.Response) -> str: + try: + data = response.json() + except ValueError: + return response.text or f"HTTP {response.status_code}" + message = 
_backend_message(data) + return message or f"HTTP {response.status_code}" + + +def _backend_message(resp: Any) -> str: + if isinstance(resp, dict): + detail = resp.get("detail") + if isinstance(detail, list): + return "; ".join(str(item.get("msg", item)) for item in detail) + return str( + resp.get("message") + or resp.get("error") + or resp.get("detail") + or resp.get("msg") + or resp + ) + return str(resp) + + +def _looks_like_no_capacity(message: str) -> bool: + message = message.lower() + return "capacity" in message or "available" in message or "stock" in message + + +def _normalize_public_key(public_key: str) -> str: + return " ".join(public_key.strip().split()[:2]) + + +def _get_ssh_key_name(public_key: str) -> str: + return "dstack-" + hashlib.sha1(public_key.encode()).hexdigest()[:16] diff --git a/src/dstack/_internal/core/backends/jarvislabs/backend.py b/src/dstack/_internal/core/backends/jarvislabs/backend.py new file mode 100644 index 0000000000..ac47171bd6 --- /dev/null +++ b/src/dstack/_internal/core/backends/jarvislabs/backend.py @@ -0,0 +1,16 @@ +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.jarvislabs.compute import JarvisLabsCompute +from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig +from dstack._internal.core.models.backends.base import BackendType + + +class JarvisLabsBackend(Backend): + TYPE = BackendType.JARVISLABS + COMPUTE_CLASS = JarvisLabsCompute + + def __init__(self, config: JarvisLabsConfig): + self.config = config + self._compute = JarvisLabsCompute(self.config) + + def compute(self) -> JarvisLabsCompute: + return self._compute diff --git a/src/dstack/_internal/core/backends/jarvislabs/compute.py b/src/dstack/_internal/core/backends/jarvislabs/compute.py new file mode 100644 index 0000000000..4f6151dde0 --- /dev/null +++ b/src/dstack/_internal/core/backends/jarvislabs/compute.py @@ -0,0 +1,364 @@ +import shlex +import subprocess +import tempfile +import 
time +from typing import List, Optional + +import gpuhunt + +from dstack._internal.core.backends.base.backend import Compute +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithFilteredOffersCached, + ComputeWithInstanceVolumesSupport, + ComputeWithPrivilegedSupport, + generate_unique_instance_name, + get_shim_commands, +) +from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.jarvislabs.api_client import JarvisLabsAPIClient +from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig +from dstack._internal.core.errors import BackendError, NoCapacityError, ProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + InstanceAvailability, + InstanceConfiguration, + InstanceOfferWithAvailability, +) +from dstack._internal.core.models.placement import PlacementGroup +from dstack._internal.core.models.runs import JobProvisioningData, Requirements +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + +MAX_INSTANCE_NAME_LEN = 40 +DEFAULT_DISK_SIZE_GB = 100 +DEFAULT_USERNAME = "ubuntu" +SSH_CONNECT_TIMEOUT_SECONDS = 10 +SSH_SETUP_TIMEOUT_SECONDS = 240 +SSH_LAUNCH_TIMEOUT_SECONDS = 60 +CREATE_FAILURE_POLL_INTERVAL_SECONDS = 5 +CREATE_FAILURE_POLL_TIMEOUT_SECONDS = 30 + + +class JarvisLabsCompute( + ComputeWithFilteredOffersCached, + ComputeWithCreateInstanceSupport, + ComputeWithPrivilegedSupport, + ComputeWithInstanceVolumesSupport, + Compute, +): + def __init__(self, config: JarvisLabsConfig): + super().__init__() + self.config = config + self.api_client = JarvisLabsAPIClient(config.creds.api_key) + self._catalog: Optional[gpuhunt.Catalog] = None + + def get_offers_by_requirements( + self, requirements: Requirements + ) -> List[InstanceOfferWithAvailability]: + offers = get_catalog_offers( + backend=BackendType.JARVISLABS, + 
locations=self.config.regions or None, + requirements=requirements, + catalog=self._get_catalog(), + ) + return [ + offer.with_availability(availability=InstanceAvailability.AVAILABLE) + for offer in offers + ] + + def create_instance( + self, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + placement_group: Optional[PlacementGroup], + ) -> JobProvisioningData: + instance_name = generate_unique_instance_name( + instance_config, max_length=MAX_INSTANCE_NAME_LEN + ) + self.api_client.add_ssh_key_if_needed(instance_config.ssh_keys[0].public) + instance_id = None + try: + if instance_offer.instance.resources.gpus: + instance_id = self.api_client.create_gpu_vm( + gpu_type=_get_jarvislabs_gpu_type(instance_offer), + num_gpus=len(instance_offer.instance.resources.gpus), + is_spot=instance_offer.instance.resources.spot, + storage=_get_disk_size_gb(instance_offer), + region=instance_offer.region, + name=instance_name, + ) + else: + instance_id = self.api_client.create_cpu_vm( + vcpus=instance_offer.instance.resources.cpus, + ram_gb=round(instance_offer.instance.resources.memory_mib / 1024), + storage=_get_disk_size_gb(instance_offer), + region=instance_offer.region, + name=instance_name, + ) + + _raise_if_create_failed( + api_client=self.api_client, + machine_id=instance_id, + region=instance_offer.region, + ) + except BaseException: + if instance_id is not None: + try: + self.api_client.destroy_instance( + machine_id=instance_id, + region=instance_offer.region, + ) + except Exception: + logger.exception( + "Could not destroy failed JarvisLabs instance %s", instance_id + ) + raise + return JobProvisioningData( + backend=instance_offer.backend, + instance_type=instance_offer.instance, + instance_id=instance_id, + hostname=None, + internal_ip=None, + region=instance_offer.region, + price=instance_offer.price, + username=DEFAULT_USERNAME, + ssh_port=22, + dockerized=True, + ssh_proxy=None, + backend_data=None, + ) + + def 
update_provisioning_data( + self, + provisioning_data: JobProvisioningData, + project_ssh_public_key: str, + project_ssh_private_key: str, + ): + instance = self.api_client.get_instance(provisioning_data.instance_id) + if instance is None: + status = self.api_client.get_instance_status( + machine_id=provisioning_data.instance_id, + region=provisioning_data.region, + ) + if status is not None and str(status.get("status")).lower() == "failed": + raise ProvisioningError(_format_failed_status(status), status) + return + + status = str(instance.get("status")).lower() + if status == "failed": + raise ProvisioningError("JarvisLabs instance entered Failed state", instance) + if status != "running": + return + + hostname = instance.get("public_ip") + if not hostname: + return + username = _get_ssh_username(instance) + if not _start_runner( + hostname=hostname, + username=username, + project_ssh_private_key=project_ssh_private_key, + arch=provisioning_data.instance_type.resources.cpu_arch, + ): + return + provisioning_data.hostname = hostname + provisioning_data.username = username + + def terminate_instance( + self, instance_id: str, region: str, backend_data: Optional[str] = None + ): + self.api_client.destroy_instance(machine_id=instance_id, region=region) + + def _get_catalog(self) -> gpuhunt.Catalog: + if self._catalog is None: + try: + from gpuhunt.providers.jarvislabs import JarvisLabsProvider + except ImportError as e: + raise BackendError( + "JarvisLabs backend requires gpuhunt with JarvisLabs provider support" + ) from e + catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False) + catalog.add_provider(JarvisLabsProvider(api_key=self.config.creds.api_key)) + self._catalog = catalog + return self._catalog + + +def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> str: + gpu = instance_offer.instance.resources.gpus[0] + memory_gb = round(gpu.memory_mib / 1024) + if gpu.name == "A100" and memory_gb == 80: + return "A100-80GB" + 
return gpu.name + + +def _get_disk_size_gb(instance_offer: InstanceOfferWithAvailability) -> int: + disk_size_gb = round(instance_offer.instance.resources.disk.size_mib / 1024) + return max(DEFAULT_DISK_SIZE_GB, disk_size_gb) + + +def _format_failed_status(status: dict) -> str: + message = status.get("error") or "unknown error" + code = status.get("code") + if code is not None: + return f"JarvisLabs instance creation failed: {message} (code={code})" + return f"JarvisLabs instance creation failed: {message}" + + +def _raise_if_create_failed( + *, + api_client: JarvisLabsAPIClient, + machine_id: str, + region: str, +): + deadline = time.monotonic() + CREATE_FAILURE_POLL_TIMEOUT_SECONDS + while time.monotonic() < deadline: + status = api_client.get_instance_status(machine_id=machine_id, region=region) + if status is None: + return + status_value = str(status.get("status")).lower() + if status_value == "failed": + message = _format_failed_status(status) + if _looks_like_no_capacity(message): + raise NoCapacityError(message) + raise ProvisioningError(message) + if status_value == "running": + return + time.sleep(CREATE_FAILURE_POLL_INTERVAL_SECONDS) + + +def _looks_like_no_capacity(message: str) -> bool: + message = message.lower() + return "capacity" in message or "available" in message or "stock" in message + + +def _get_ssh_username(instance: dict) -> str: + ssh_command = instance.get("ssh_str") or instance.get("ssh_command") + if not isinstance(ssh_command, str): + return DEFAULT_USERNAME + try: + parts = shlex.split(ssh_command) + except ValueError: + return DEFAULT_USERNAME + for part in parts[1:]: + if part.startswith("-") or "@" not in part: + continue + return part.rsplit("@", 1)[0] + return DEFAULT_USERNAME + + +def _start_runner( + hostname: str, + username: str, + project_ssh_private_key: str, + arch: Optional[str], +) -> bool: + commands = get_shim_commands(arch=arch) + launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands)) + try: + if not 
_setup_instance( + hostname=hostname, + username=username, + ssh_private_key=project_ssh_private_key, + ): + return False + return _launch_runner( + hostname=hostname, + username=username, + ssh_private_key=project_ssh_private_key, + launch_command=launch_command, + ) + except Exception: + logger.exception("Failed to start dstack shim on JarvisLabs instance %s", hostname) + return False + + +def _setup_instance( + hostname: str, + username: str, + ssh_private_key: str, +) -> bool: + setup_commands = [ + "mkdir -p ~/.dstack", + "if ! command -v curl >/dev/null 2>&1 || ! command -v docker >/dev/null 2>&1 || ! command -v jq >/dev/null 2>&1; then sudo apt-get update; fi", + "if ! command -v curl >/dev/null 2>&1; then sudo DEBIAN_FRONTEND=noninteractive apt-get install -y curl; fi", + "if ! command -v docker >/dev/null 2>&1; then sudo apt-get update && sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker.io; fi", + "if ! command -v jq >/dev/null 2>&1; then sudo apt-get update && sudo DEBIAN_FRONTEND=noninteractive apt-get install -y jq; fi", + "sudo systemctl enable --now docker || sudo service docker start || true", + ] + return _run_ssh_command( + hostname=hostname, + username=username, + ssh_private_key=ssh_private_key, + command=" && ".join(setup_commands), + timeout=SSH_SETUP_TIMEOUT_SECONDS, + ) + + +def _launch_runner( + hostname: str, + username: str, + ssh_private_key: str, + launch_command: str, +) -> bool: + daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown" + return _run_ssh_command( + hostname=hostname, + username=username, + ssh_private_key=ssh_private_key, + command=daemonized_command, + timeout=SSH_LAUNCH_TIMEOUT_SECONDS, + ) + + +def _run_ssh_command( + hostname: str, + username: str, + ssh_private_key: str, + command: str, + timeout: int, +) -> bool: + with tempfile.NamedTemporaryFile("w+") as f: + f.write(ssh_private_key) + f.flush() + try: + proc = subprocess.run( + [ + "ssh", + "-F", + "none", + "-o", + 
"BatchMode=yes", + "-o", + f"ConnectTimeout={SSH_CONNECT_TIMEOUT_SECONDS}", + "-o", + "ConnectionAttempts=1", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "LogLevel=ERROR", + "-i", + f.name, + f"{username}@{hostname}", + command, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + logger.debug("Timed out running SSH command on JarvisLabs instance %s", hostname) + return False + if proc.returncode != 0: + logger.debug( + "SSH command failed on JarvisLabs instance %s: exit_code=%s stderr=%r", + hostname, + proc.returncode, + proc.stderr[-1000:], + ) + return False + return True diff --git a/src/dstack/_internal/core/backends/jarvislabs/configurator.py b/src/dstack/_internal/core/backends/jarvislabs/configurator.py new file mode 100644 index 0000000000..ceaebeedf9 --- /dev/null +++ b/src/dstack/_internal/core/backends/jarvislabs/configurator.py @@ -0,0 +1,85 @@ +import json + +from gpuhunt.providers.jarvislabs import JARVISLABS_REGION_URLS + +from dstack._internal.core.backends.base.configurator import ( + BackendRecord, + Configurator, + raise_invalid_credentials_error, +) +from dstack._internal.core.backends.jarvislabs import api_client +from dstack._internal.core.backends.jarvislabs.backend import JarvisLabsBackend +from dstack._internal.core.backends.jarvislabs.models import ( + JarvisLabsBackendConfig, + JarvisLabsBackendConfigWithCreds, + JarvisLabsConfig, + JarvisLabsCreds, + JarvisLabsStoredConfig, +) +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.backends.base import BackendType + + +class JarvisLabsConfigurator( + Configurator[ + JarvisLabsBackendConfig, + JarvisLabsBackendConfigWithCreds, + ] +): + TYPE = BackendType.JARVISLABS + BACKEND_CLASS = JarvisLabsBackend + + def validate_config( + self, config: JarvisLabsBackendConfigWithCreds, default_creds_enabled: bool + ): + 
self._validate_api_key(config.creds.api_key) + self._validate_regions(config.regions) + + def create_backend( + self, project_name: str, config: JarvisLabsBackendConfigWithCreds + ) -> BackendRecord: + return BackendRecord( + config=JarvisLabsStoredConfig( + **JarvisLabsBackendConfig.__response__.parse_obj(config).dict() + ).json(), + auth=JarvisLabsCreds.parse_obj(config.creds).json(), + ) + + def get_backend_config_with_creds( + self, record: BackendRecord + ) -> JarvisLabsBackendConfigWithCreds: + config = self._get_config(record) + return JarvisLabsBackendConfigWithCreds.__response__.parse_obj(config) + + def get_backend_config_without_creds(self, record: BackendRecord) -> JarvisLabsBackendConfig: + config = self._get_config(record) + return JarvisLabsBackendConfig.__response__.parse_obj(config) + + def get_backend(self, record: BackendRecord) -> JarvisLabsBackend: + config = self._get_config(record) + return JarvisLabsBackend(config=config) + + def _get_config(self, record: BackendRecord) -> JarvisLabsConfig: + return JarvisLabsConfig.__response__( + **json.loads(record.config), + creds=JarvisLabsCreds.parse_raw(record.auth), + ) + + def _validate_api_key(self, api_key: str): + client = api_client.JarvisLabsAPIClient(api_key=api_key) + if not client.validate_api_key(): + raise_invalid_credentials_error(fields=[["creds", "api_key"]]) + + def _validate_regions(self, regions: list[str] | None): + if not regions: + return + invalid_regions = sorted(set(regions) - set(JARVISLABS_REGION_URLS)) + if invalid_regions: + raise ServerClientError( + msg=( + f"Unsupported JarvisLabs regions: {invalid_regions}. " + f"Supported regions: {sorted(JARVISLABS_REGION_URLS)}. " + "JarvisLabs does not expose provisioning endpoint discovery." 
+ ), + fields=[["regions"]], + ) diff --git a/src/dstack/_internal/core/backends/jarvislabs/models.py b/src/dstack/_internal/core/backends/jarvislabs/models.py new file mode 100644 index 0000000000..dae710089e --- /dev/null +++ b/src/dstack/_internal/core/backends/jarvislabs/models.py @@ -0,0 +1,47 @@ +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class JarvisLabsAPIKeyCreds(CoreModel): + type: Annotated[Literal["api_key"], Field(description="The type of credentials")] = "api_key" + api_key: Annotated[str, Field(description="The JarvisLabs API key")] + + +AnyJarvisLabsCreds = JarvisLabsAPIKeyCreds +JarvisLabsCreds = AnyJarvisLabsCreds + + +class JarvisLabsBackendConfig(CoreModel): + type: Annotated[ + Literal["jarvislabs"], + Field(description="The type of backend"), + ] = "jarvislabs" + regions: Annotated[ + Optional[List[str]], + Field(description="The list of JarvisLabs regions. 
Omit to use all regions"), + ] = None + + +class JarvisLabsBackendConfigWithCreds(JarvisLabsBackendConfig): + creds: Annotated[AnyJarvisLabsCreds, Field(description="The credentials")] + + +AnyJarvisLabsBackendConfig = Union[ + JarvisLabsBackendConfig, + JarvisLabsBackendConfigWithCreds, +] + + +class JarvisLabsBackendFileConfigWithCreds(JarvisLabsBackendConfig): + creds: Annotated[AnyJarvisLabsCreds, Field(description="The credentials")] + + +class JarvisLabsStoredConfig(JarvisLabsBackendConfig): + pass + + +class JarvisLabsConfig(JarvisLabsStoredConfig): + creds: AnyJarvisLabsCreds diff --git a/src/dstack/_internal/core/backends/models.py b/src/dstack/_internal/core/backends/models.py index 36a7856e38..c21141378e 100644 --- a/src/dstack/_internal/core/backends/models.py +++ b/src/dstack/_internal/core/backends/models.py @@ -39,6 +39,11 @@ HotAisleBackendConfigWithCreds, HotAisleBackendFileConfigWithCreds, ) +from dstack._internal.core.backends.jarvislabs.models import ( + JarvisLabsBackendConfig, + JarvisLabsBackendConfigWithCreds, + JarvisLabsBackendFileConfigWithCreds, +) from dstack._internal.core.backends.kubernetes.models import ( KubernetesBackendConfig, KubernetesBackendConfigWithCreds, @@ -89,6 +94,7 @@ BaseDigitalOceanBackendConfig, GCPBackendConfig, HotAisleBackendConfig, + JarvisLabsBackendConfig, KubernetesBackendConfig, LambdaBackendConfig, NebiusBackendConfig, @@ -115,6 +121,7 @@ BaseDigitalOceanBackendConfigWithCreds, GCPBackendConfigWithCreds, HotAisleBackendConfigWithCreds, + JarvisLabsBackendConfigWithCreds, KubernetesBackendConfigWithCreds, LambdaBackendConfigWithCreds, OCIBackendConfigWithCreds, @@ -139,6 +146,7 @@ BaseDigitalOceanBackendConfigWithCreds, GCPBackendFileConfigWithCreds, HotAisleBackendFileConfigWithCreds, + JarvisLabsBackendFileConfigWithCreds, KubernetesBackendFileConfigWithCreds, LambdaBackendConfigWithCreds, OCIBackendConfigWithCreds, diff --git a/src/dstack/_internal/core/models/backends/base.py 
b/src/dstack/_internal/core/models/backends/base.py index 2e8eb898ee..dd11ae67a3 100644 --- a/src/dstack/_internal/core/models/backends/base.py +++ b/src/dstack/_internal/core/models/backends/base.py @@ -15,6 +15,7 @@ class BackendType(str, enum.Enum): DSTACK (BackendType): dstack Sky GCP (BackendType): Google Cloud Platform HOTAISLE (BackendType): Hot Aisle + JARVISLABS (BackendType): JarvisLabs KUBERNETES (BackendType): Kubernetes LAMBDA (BackendType): Lambda Cloud NEBIUS (BackendType): Nebius AI Cloud @@ -38,6 +39,7 @@ class BackendType(str, enum.Enum): DSTACK = "dstack" GCP = "gcp" HOTAISLE = "hotaisle" + JARVISLABS = "jarvislabs" KUBERNETES = "kubernetes" LAMBDA = "lambda" LOCAL = "local" diff --git a/src/tests/_internal/core/backends/jarvislabs/__init__.py b/src/tests/_internal/core/backends/jarvislabs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/core/backends/jarvislabs/test_api_client.py b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py new file mode 100644 index 0000000000..09980d6ed1 --- /dev/null +++ b/src/tests/_internal/core/backends/jarvislabs/test_api_client.py @@ -0,0 +1,186 @@ +import pytest +import requests + +from dstack._internal.core.backends.jarvislabs.api_client import ( + JarvisLabsAPIClient, + is_cpu_vm, +) +from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError + + +def test_validate_api_key_returns_false_on_unauthorized(requests_mock): + requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=401) + + assert JarvisLabsAPIClient("bad").validate_api_key() is False + + +def test_get_user_info_raises_invalid_credentials_on_forbidden(requests_mock): + requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", status_code=403) + + with pytest.raises(BackendInvalidCredentialsError): + JarvisLabsAPIClient("bad").get_user_info() + + +def test_make_request_wraps_request_errors(requests_mock): + requests_mock.get( 
+ "https://backendprod.jarvislabs.net/users/user_info", + exc=requests.ConnectTimeout("timed out"), + ) + + with pytest.raises(BackendError, match="JarvisLabs request failed"): + JarvisLabsAPIClient("token").get_user_info() + + +def test_get_user_info_rejects_non_json_success_response(requests_mock): + requests_mock.get("https://backendprod.jarvislabs.net/users/user_info", text="ok") + + with pytest.raises(BackendError, match="Unexpected non-JSON JarvisLabs response"): + JarvisLabsAPIClient("token").get_user_info() + + +def test_add_ssh_key_if_needed_reuses_existing_key(requests_mock): + public_key = "ssh-rsa AAAA test-comment" + requests_mock.get( + "https://backendprod.jarvislabs.net/ssh/", + json=[{"ssh_key": "ssh-rsa AAAA another-comment", "key_name": "existing"}], + ) + + JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key) + + assert requests_mock.call_count == 1 + + +def test_add_ssh_key_if_needed_adds_missing_key(requests_mock): + public_key = "ssh-rsa AAAA test-comment" + requests_mock.get("https://backendprod.jarvislabs.net/ssh/", json=[]) + requests_mock.post("https://backendprod.jarvislabs.net/ssh/", json={"success": True}) + + JarvisLabsAPIClient("token").add_ssh_key_if_needed(public_key) + + assert requests_mock.last_request.json() == { + "ssh_key": public_key, + "key_name": "dstack-36deb09319b2204c", + } + + +def test_create_gpu_vm_posts_to_regional_vm_endpoint(requests_mock): + requests_mock.post( + "https://backendn.jarvislabs.net/templates/vm/create", + json={"machine_id": 123}, + ) + + machine_id = JarvisLabsAPIClient("token").create_gpu_vm( + gpu_type="A100-80GB", + num_gpus=1, + is_spot=False, + storage=250, + region="india-noida-01", + name="dstack-test", + ) + + assert machine_id == "123" + assert requests_mock.last_request.headers["Authorization"] == "Bearer token" + assert requests_mock.last_request.json() == { + "gpu_type": "A100-80GB", + "num_gpus": 1, + "hdd": 250, + "region": "india-noida-01", + "name": "dstack-test", + 
"is_spot": False, + "duration": "hour", + "disk_type": "ssd", + "http_ports": "", + "script_id": None, + "script_args": "", + "fs_id": None, + "arguments": "", + } + + +def test_create_gpu_vm_rejects_unsupported_region(requests_mock): + with pytest.raises(BackendError, match="Unsupported JarvisLabs region"): + JarvisLabsAPIClient("token").create_gpu_vm( + gpu_type="H100", + num_gpus=1, + is_spot=False, + storage=100, + region="unknown-region", + name="dstack-test", + ) + + assert requests_mock.call_count == 0 + + +def test_create_gpu_vm_sets_spot_flag(requests_mock): + requests_mock.post( + "https://backendn.jarvislabs.net/templates/vm/create", + json={"machine_id": 123}, + ) + + JarvisLabsAPIClient("token").create_gpu_vm( + gpu_type="L4", + num_gpus=1, + is_spot=True, + storage=100, + region="india-noida-01", + name="dstack-spot", + ) + + assert requests_mock.last_request.json()["is_spot"] is True + + +def test_create_cpu_vm_posts_to_regional_cpu_vm_endpoint(requests_mock): + requests_mock.post( + "https://backendn.jarvislabs.net/templates/vm/cpu/create", + json={"machine_id": 456}, + ) + + machine_id = JarvisLabsAPIClient("token").create_cpu_vm( + vcpus=4, + ram_gb=16, + storage=100, + region="india-noida-01", + name="dstack-cpu", + ) + + assert machine_id == "456" + assert requests_mock.last_request.json() == { + "num_cpus": 1, + "vcpus": 4, + "ram_gb": 16, + "hdd": 100, + "region": "india-noida-01", + "name": "dstack-cpu", + "duration": "hour", + "disk_type": "ssd", + } + + +def test_destroy_instance_uses_cpu_vm_endpoint_for_cpu_vm(requests_mock): + requests_mock.get( + "https://backendprod.jarvislabs.net/users/fetch/456", + json={ + "success": True, + "instance": { + "machine_id": 456, + "template": "vm", + "gpu_type": "CPU", + "region": "india-noida-01", + }, + }, + ) + requests_mock.post( + "https://backendn.jarvislabs.net/templates/vm/cpu/destroy", + json={"success": True}, + ) + + JarvisLabsAPIClient("token").destroy_instance(machine_id="456", 
region="india-noida-01") + + assert requests_mock.last_request.qs == {"machine_id": ["456"]} + + +def test_is_cpu_vm_requires_vm_template_and_cpu_gpu_type(): + assert is_cpu_vm({"template": "vm", "gpu_type": "CPU"}) + assert is_cpu_vm({"framework": "VM", "gpu_type": "CPU"}) + assert not is_cpu_vm({"template": "pytorch", "gpu_type": "CPU"}) + assert not is_cpu_vm({"template": "vm", "gpu_type": "H100"}) diff --git a/src/tests/_internal/core/backends/jarvislabs/test_compute.py b/src/tests/_internal/core/backends/jarvislabs/test_compute.py new file mode 100644 index 0000000000..987fc3d1e0 --- /dev/null +++ b/src/tests/_internal/core/backends/jarvislabs/test_compute.py @@ -0,0 +1,310 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dstack._internal.core.backends.jarvislabs.compute import ( + JarvisLabsCompute, + _get_disk_size_gb, + _get_jarvislabs_gpu_type, + _get_ssh_username, + _raise_if_create_failed, +) +from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig, JarvisLabsCreds +from dstack._internal.core.errors import NoCapacityError, ProvisioningError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import ( + Disk, + Gpu, + InstanceAvailability, + InstanceConfiguration, + InstanceOfferWithAvailability, + InstanceType, + Resources, + SSHKey, +) +from dstack._internal.core.models.runs import JobProvisioningData + + +def _compute() -> JarvisLabsCompute: + compute = JarvisLabsCompute( + JarvisLabsConfig(creds=JarvisLabsCreds(api_key="test"), regions=["india-noida-01"]) + ) + compute.api_client = MagicMock() + compute.api_client.get_instance_status.return_value = {"status": "Running"} + return compute + + +def _instance_config() -> InstanceConfiguration: + return InstanceConfiguration( + project_name="test-project", + instance_name="jarvislabs-test", + user="test-user", + ssh_keys=[SSHKey(public="ssh-rsa AAAA test")], + ) + + +def _gpu_offer( + *, + gpu_name: 
str = "A100", + gpu_memory_mib: int = 80 * 1024, + disk_size_mib: int = 250 * 1024, + spot: bool = False, +) -> InstanceOfferWithAvailability: + return InstanceOfferWithAvailability( + backend=BackendType.JARVISLABS, + instance=InstanceType( + name=f"{gpu_name}-1x", + resources=Resources( + cpus=28, + memory_mib=112 * 1024, + gpus=[Gpu(name=gpu_name, memory_mib=gpu_memory_mib)], + spot=spot, + disk=Disk(size_mib=disk_size_mib), + ), + ), + region="india-noida-01", + price=1.49, + availability=InstanceAvailability.AVAILABLE, + ) + + +def _cpu_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOfferWithAvailability: + return InstanceOfferWithAvailability( + backend=BackendType.JARVISLABS, + instance=InstanceType( + name="cpu-4x16", + resources=Resources( + cpus=4, + memory_mib=16 * 1024, + gpus=[], + spot=False, + disk=Disk(size_mib=disk_size_mib), + ), + ), + region="india-noida-01", + price=0.0992, + availability=InstanceAvailability.AVAILABLE, + ) + + +def test_get_jarvislabs_gpu_type_reconstructs_a100_80gb(): + assert _get_jarvislabs_gpu_type(_gpu_offer()) == "A100-80GB" + assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_memory_mib=40 * 1024)) == "A100" + assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_name="H100")) == "H100" + + +def test_get_disk_size_gb_clamps_to_jarvislabs_vm_minimum(): + assert _get_disk_size_gb(_cpu_offer(disk_size_mib=10 * 1024)) == 100 + assert _get_disk_size_gb(_gpu_offer(disk_size_mib=250 * 1024)) == 250 + + +def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): + compute = _compute() + compute.api_client.create_gpu_vm.return_value = "123" + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-test", + ): + provisioning_data = compute.create_instance(_gpu_offer(), _instance_config(), None) + + compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test") + compute.api_client.create_gpu_vm.assert_called_once_with( + 
gpu_type="A100-80GB", + num_gpus=1, + is_spot=False, + storage=250, + region="india-noida-01", + name="dstack-test", + ) + assert provisioning_data.instance_id == "123" + assert provisioning_data.username == "ubuntu" + assert provisioning_data.dockerized is True + assert provisioning_data.backend_data is None + + +def test_create_gpu_instance_passes_spot_flag(): + compute = _compute() + compute.api_client.create_gpu_vm.return_value = "123" + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-test", + ): + compute.create_instance(_gpu_offer(spot=True), _instance_config(), None) + + compute.api_client.create_gpu_vm.assert_called_once_with( + gpu_type="A100-80GB", + num_gpus=1, + is_spot=True, + storage=250, + region="india-noida-01", + name="dstack-test", + ) + + +def test_create_cpu_instance_registers_ssh_key_and_creates_cpu_vm(): + compute = _compute() + compute.api_client.create_cpu_vm.return_value = "456" + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-cpu", + ): + provisioning_data = compute.create_instance(_cpu_offer(), _instance_config(), None) + + compute.api_client.add_ssh_key_if_needed.assert_called_once_with("ssh-rsa AAAA test") + compute.api_client.create_cpu_vm.assert_called_once_with( + vcpus=4, + ram_gb=16, + storage=100, + region="india-noida-01", + name="dstack-cpu", + ) + assert provisioning_data.instance_id == "456" + assert provisioning_data.backend_data is None + + +def test_update_provisioning_data_sets_hostname_and_starts_runner(): + compute = _compute() + compute.api_client.get_instance.return_value = { + "machine_id": 123, + "status": "Running", + "public_ip": "203.0.113.10", + "ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10", + } + provisioning_data = JobProvisioningData( + backend=BackendType.JARVISLABS, + instance_type=_gpu_offer().instance, + instance_id="123", + 
region="india-noida-01", + price=1.49, + username="ubuntu", + ssh_port=22, + dockerized=True, + ) + + with patch( + "dstack._internal.core.backends.jarvislabs.compute._start_runner", return_value=True + ) as m: + compute.update_provisioning_data( + provisioning_data, + project_ssh_public_key="ssh-rsa AAAA test", + project_ssh_private_key="private-key", + ) + + assert provisioning_data.hostname == "203.0.113.10" + assert provisioning_data.username == "ubuntu" + m.assert_called_once_with( + hostname="203.0.113.10", + username="ubuntu", + project_ssh_private_key="private-key", + arch=None, + ) + + +def test_update_provisioning_data_does_not_set_hostname_until_runner_starts(): + compute = _compute() + compute.api_client.get_instance.return_value = { + "machine_id": 123, + "status": "Running", + "public_ip": "203.0.113.10", + "ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10", + } + provisioning_data = JobProvisioningData( + backend=BackendType.JARVISLABS, + instance_type=_gpu_offer().instance, + instance_id="123", + region="india-noida-01", + price=1.49, + username="ubuntu", + ssh_port=22, + dockerized=True, + ) + + with patch( + "dstack._internal.core.backends.jarvislabs.compute._start_runner", return_value=False + ): + compute.update_provisioning_data( + provisioning_data, + project_ssh_public_key="ssh-rsa AAAA test", + project_ssh_private_key="private-key", + ) + + assert provisioning_data.hostname is None + + +def test_get_ssh_username_parses_jarvislabs_ssh_command(): + assert ( + _get_ssh_username({"ssh_str": "ssh -o StrictHostKeyChecking=no ubuntu@203.0.113.10"}) + == "ubuntu" + ) + assert _get_ssh_username({"ssh_str": "ssh -p 22 root@203.0.113.10"}) == "root" + assert _get_ssh_username({}) == "ubuntu" + + +def test_terminate_instance_delegates_to_api_client(): + compute = _compute() + + compute.terminate_instance("123", "india-noida-01") + + compute.api_client.destroy_instance.assert_called_once_with( + machine_id="123", + region="india-noida-01", 
+ ) + + +def test_create_instance_cleans_up_post_create_failure(): + compute = _compute() + compute.api_client.create_gpu_vm.return_value = "123" + compute.api_client.get_instance_status.return_value = { + "status": "Failed", + "error": "L4 not available at this moment, please try again later", + "code": 404, + } + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", + return_value="dstack-test", + ): + with pytest.raises(NoCapacityError): + compute.create_instance(_gpu_offer(spot=True), _instance_config(), None) + + compute.api_client.destroy_instance.assert_called_once_with( + machine_id="123", + region="india-noida-01", + ) + + +def test_raise_if_create_failed_due_to_no_capacity(): + api_client = MagicMock() + api_client.get_instance_status.return_value = { + "status": "Failed", + "error": "L4 not available at this moment, please try again later", + "code": 404, + } + + with patch("dstack._internal.core.backends.jarvislabs.compute.time.sleep"): + with pytest.raises(NoCapacityError): + _raise_if_create_failed( + api_client=api_client, + machine_id="123", + region="india-noida-01", + ) + + +def test_raise_if_create_failed_raises_provisioning_error(): + api_client = MagicMock() + api_client.get_instance_status.return_value = { + "status": "Failed", + "error": "image setup failed", + "code": 500, + } + + with patch("dstack._internal.core.backends.jarvislabs.compute.time.sleep"): + with pytest.raises(ProvisioningError): + _raise_if_create_failed( + api_client=api_client, + machine_id="123", + region="india-noida-01", + ) diff --git a/src/tests/_internal/core/backends/jarvislabs/test_configurator.py b/src/tests/_internal/core/backends/jarvislabs/test_configurator.py new file mode 100644 index 0000000000..1da92e2600 --- /dev/null +++ b/src/tests/_internal/core/backends/jarvislabs/test_configurator.py @@ -0,0 +1,54 @@ +from unittest.mock import patch + +import pytest + +from 
dstack._internal.core.backends.jarvislabs.configurator import JarvisLabsConfigurator +from dstack._internal.core.backends.jarvislabs.models import ( + JarvisLabsBackendConfigWithCreds, + JarvisLabsCreds, +) +from dstack._internal.core.errors import BackendInvalidCredentialsError, ServerClientError + + +class TestJarvisLabsConfigurator: + def test_validate_config_valid(self): + config = JarvisLabsBackendConfigWithCreds( + creds=JarvisLabsCreds(api_key="valid"), + regions=["india-noida-01"], + ) + with patch( + "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key" + ) as validate_mock: + validate_mock.return_value = True + JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True) + + def test_validate_config_invalid_creds(self): + config = JarvisLabsBackendConfigWithCreds( + creds=JarvisLabsCreds(api_key="invalid"), + regions=["india-noida-01"], + ) + with ( + patch( + "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key" + ) as validate_mock, + pytest.raises(BackendInvalidCredentialsError) as exc_info, + ): + validate_mock.return_value = False + JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True) + assert exc_info.value.fields == [["creds", "api_key"]] + + def test_validate_config_unsupported_region(self): + config = JarvisLabsBackendConfigWithCreds( + creds=JarvisLabsCreds(api_key="valid"), + regions=["unknown-region"], + ) + with ( + patch( + "dstack._internal.core.backends.jarvislabs.api_client.JarvisLabsAPIClient.validate_api_key" + ) as validate_mock, + pytest.raises(ServerClientError) as exc_info, + ): + validate_mock.return_value = True + JarvisLabsConfigurator().validate_config(config, default_creds_enabled=True) + assert exc_info.value.fields == [["regions"]] + assert "Unsupported JarvisLabs regions" in exc_info.value.msg diff --git a/src/tests/_internal/server/routers/test_backends.py b/src/tests/_internal/server/routers/test_backends.py 
index 47295748d3..79fb13667e 100644 --- a/src/tests/_internal/server/routers/test_backends.py +++ b/src/tests/_internal/server/routers/test_backends.py @@ -92,6 +92,7 @@ async def test_returns_backend_types(self, client: AsyncClient): "digitalocean", "gcp", "hotaisle", + "jarvislabs", "kubernetes", "lambda", "nebius", From 760dbd76cae2722f602065c61722a509ee5539db Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Wed, 13 May 2026 09:20:00 +0200 Subject: [PATCH 2/5] Document JarvisLabs VM startup script behavior --- src/dstack/_internal/core/backends/jarvislabs/api_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dstack/_internal/core/backends/jarvislabs/api_client.py b/src/dstack/_internal/core/backends/jarvislabs/api_client.py index d55319e5bf..e1a2a022ec 100644 --- a/src/dstack/_internal/core/backends/jarvislabs/api_client.py +++ b/src/dstack/_internal/core/backends/jarvislabs/api_client.py @@ -83,6 +83,8 @@ def create_gpu_vm( "duration": "hour", "disk_type": "ssd", "http_ports": "", + # JarvisLabs accepts script_id for VM creates, but live CPU/GPU VM tests + # showed it is not injected into cloud-init user-data/runcmd. "script_id": None, "script_args": "", "fs_id": None, @@ -113,6 +115,7 @@ def create_cpu_vm( "name": name, "duration": "hour", "disk_type": "ssd", + # Do not pass script_id here either; CPU VM create accepts it but ignores it. 
}, ) return _get_created_machine_id(resp, "CPU VM creation") From 537300853ea47250e83789f88a099a589275b8f5 Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Wed, 20 May 2026 17:12:28 +0200 Subject: [PATCH 3/5] Address JarvisLabs backend review --- .../core/backends/jarvislabs/compute.py | 87 ++++-------- .../core/backends/jarvislabs/test_compute.py | 128 +++++++++++++----- 2 files changed, 122 insertions(+), 93 deletions(-) diff --git a/src/dstack/_internal/core/backends/jarvislabs/compute.py b/src/dstack/_internal/core/backends/jarvislabs/compute.py index 4f6151dde0..2e79902b45 100644 --- a/src/dstack/_internal/core/backends/jarvislabs/compute.py +++ b/src/dstack/_internal/core/backends/jarvislabs/compute.py @@ -1,24 +1,29 @@ import shlex import subprocess import tempfile -import time +from collections.abc import Iterable from typing import List, Optional import gpuhunt +from gpuhunt.providers.jarvislabs import JarvisLabsProvider from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, - ComputeWithFilteredOffersCached, ComputeWithInstanceVolumesSupport, ComputeWithPrivilegedSupport, generate_unique_instance_name, get_shim_commands, ) -from dstack._internal.core.backends.base.offers import get_catalog_offers +from dstack._internal.core.backends.base.offers import ( + OfferModifier, + get_catalog_offers, + get_offers_disk_modifier, +) from dstack._internal.core.backends.jarvislabs.api_client import JarvisLabsAPIClient from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig -from dstack._internal.core.errors import BackendError, NoCapacityError, ProvisioningError +from dstack._internal.core.errors import ProvisioningError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( InstanceAvailability, @@ -26,23 +31,24 @@ InstanceOfferWithAvailability, ) 
from dstack._internal.core.models.placement import PlacementGroup +from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) MAX_INSTANCE_NAME_LEN = 40 -DEFAULT_DISK_SIZE_GB = 100 +# JarvisLabs VM storage is configurable through the `hdd` create parameter. +MIN_DISK_SIZE = Memory.parse("100GB") +CONFIGURABLE_DISK_SIZE = Range[Memory](min=MIN_DISK_SIZE, max=None) DEFAULT_USERNAME = "ubuntu" SSH_CONNECT_TIMEOUT_SECONDS = 10 SSH_SETUP_TIMEOUT_SECONDS = 240 SSH_LAUNCH_TIMEOUT_SECONDS = 60 -CREATE_FAILURE_POLL_INTERVAL_SECONDS = 5 -CREATE_FAILURE_POLL_TIMEOUT_SECONDS = 30 class JarvisLabsCompute( - ComputeWithFilteredOffersCached, + ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithPrivilegedSupport, ComputeWithInstanceVolumesSupport, @@ -52,22 +58,24 @@ def __init__(self, config: JarvisLabsConfig): super().__init__() self.config = config self.api_client = JarvisLabsAPIClient(config.creds.api_key) - self._catalog: Optional[gpuhunt.Catalog] = None + self._catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False) + self._catalog.add_provider(JarvisLabsProvider(api_key=self.config.creds.api_key)) - def get_offers_by_requirements( - self, requirements: Requirements - ) -> List[InstanceOfferWithAvailability]: + def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( backend=BackendType.JARVISLABS, locations=self.config.regions or None, - requirements=requirements, - catalog=self._get_catalog(), + catalog=self._catalog, + configurable_disk_size=CONFIGURABLE_DISK_SIZE, ) return [ offer.with_availability(availability=InstanceAvailability.AVAILABLE) for offer in offers ] + def get_offers_modifiers(self, requirements: Requirements) -> Iterable[OfferModifier]: + return [get_offers_disk_modifier(CONFIGURABLE_DISK_SIZE, 
requirements)] + def create_instance( self, instance_offer: InstanceOfferWithAvailability, @@ -98,11 +106,6 @@ def create_instance( name=instance_name, ) - _raise_if_create_failed( - api_client=self.api_client, - machine_id=instance_id, - region=instance_offer.region, - ) except BaseException: if instance_id is not None: try: @@ -143,12 +146,12 @@ def update_provisioning_data( region=provisioning_data.region, ) if status is not None and str(status.get("status")).lower() == "failed": - raise ProvisioningError(_format_failed_status(status), status) + _raise_failed_status(status) return status = str(instance.get("status")).lower() if status == "failed": - raise ProvisioningError("JarvisLabs instance entered Failed state", instance) + _raise_failed_status(instance) if status != "running": return @@ -171,19 +174,6 @@ def terminate_instance( ): self.api_client.destroy_instance(machine_id=instance_id, region=region) - def _get_catalog(self) -> gpuhunt.Catalog: - if self._catalog is None: - try: - from gpuhunt.providers.jarvislabs import JarvisLabsProvider - except ImportError as e: - raise BackendError( - "JarvisLabs backend requires gpuhunt with JarvisLabs provider support" - ) from e - catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False) - catalog.add_provider(JarvisLabsProvider(api_key=self.config.creds.api_key)) - self._catalog = catalog - return self._catalog - def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> str: gpu = instance_offer.instance.resources.gpus[0] @@ -195,7 +185,7 @@ def _get_jarvislabs_gpu_type(instance_offer: InstanceOfferWithAvailability) -> s def _get_disk_size_gb(instance_offer: InstanceOfferWithAvailability) -> int: disk_size_gb = round(instance_offer.instance.resources.disk.size_mib / 1024) - return max(DEFAULT_DISK_SIZE_GB, disk_size_gb) + return max(round(MIN_DISK_SIZE), disk_size_gb) def _format_failed_status(status: dict) -> str: @@ -206,31 +196,8 @@ def _format_failed_status(status: dict) -> 
str: return f"JarvisLabs instance creation failed: {message}" -def _raise_if_create_failed( - *, - api_client: JarvisLabsAPIClient, - machine_id: str, - region: str, -): - deadline = time.monotonic() + CREATE_FAILURE_POLL_TIMEOUT_SECONDS - while time.monotonic() < deadline: - status = api_client.get_instance_status(machine_id=machine_id, region=region) - if status is None: - return - status_value = str(status.get("status")).lower() - if status_value == "failed": - message = _format_failed_status(status) - if _looks_like_no_capacity(message): - raise NoCapacityError(message) - raise ProvisioningError(message) - if status_value == "running": - return - time.sleep(CREATE_FAILURE_POLL_INTERVAL_SECONDS) - - -def _looks_like_no_capacity(message: str) -> bool: - message = message.lower() - return "capacity" in message or "available" in message or "stock" in message +def _raise_failed_status(status: dict) -> None: + raise ProvisioningError(_format_failed_status(status), status) def _get_ssh_username(instance: dict) -> str: diff --git a/src/tests/_internal/core/backends/jarvislabs/test_compute.py b/src/tests/_internal/core/backends/jarvislabs/test_compute.py index 987fc3d1e0..27be084e1f 100644 --- a/src/tests/_internal/core/backends/jarvislabs/test_compute.py +++ b/src/tests/_internal/core/backends/jarvislabs/test_compute.py @@ -3,11 +3,11 @@ import pytest from dstack._internal.core.backends.jarvislabs.compute import ( + CONFIGURABLE_DISK_SIZE, JarvisLabsCompute, _get_disk_size_gb, _get_jarvislabs_gpu_type, _get_ssh_username, - _raise_if_create_failed, ) from dstack._internal.core.backends.jarvislabs.models import JarvisLabsConfig, JarvisLabsCreds from dstack._internal.core.errors import NoCapacityError, ProvisioningError @@ -17,12 +17,14 @@ Gpu, InstanceAvailability, InstanceConfiguration, + InstanceOffer, InstanceOfferWithAvailability, InstanceType, Resources, SSHKey, ) -from dstack._internal.core.models.runs import JobProvisioningData +from 
dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import JobProvisioningData, Requirements def _compute() -> JarvisLabsCompute: @@ -87,6 +89,16 @@ def _cpu_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOfferWithAvailabili ) +def _cpu_catalog_offer(*, disk_size_mib: int = 10 * 1024) -> InstanceOffer: + offer = _cpu_offer(disk_size_mib=disk_size_mib) + return InstanceOffer( + backend=offer.backend, + instance=offer.instance, + region=offer.region, + price=offer.price, + ) + + def test_get_jarvislabs_gpu_type_reconstructs_a100_80gb(): assert _get_jarvislabs_gpu_type(_gpu_offer()) == "A100-80GB" assert _get_jarvislabs_gpu_type(_gpu_offer(gpu_memory_mib=40 * 1024)) == "A100" @@ -98,6 +110,41 @@ def test_get_disk_size_gb_clamps_to_jarvislabs_vm_minimum(): assert _get_disk_size_gb(_gpu_offer(disk_size_mib=250 * 1024)) == 250 +def test_get_all_offers_uses_configurable_disk_size(): + compute = _compute() + + with patch( + "dstack._internal.core.backends.jarvislabs.compute.get_catalog_offers", + return_value=[_cpu_catalog_offer()], + ) as m: + offers = compute.get_all_offers_with_availability() + + assert len(offers) == 1 + assert offers[0].availability == InstanceAvailability.AVAILABLE + m.assert_called_once_with( + backend=BackendType.JARVISLABS, + locations=["india-noida-01"], + catalog=compute._catalog, + configurable_disk_size=CONFIGURABLE_DISK_SIZE, + ) + + +def test_get_offers_reuses_all_offers_cache_and_modifies_disk_size(): + compute = _compute() + compute.get_all_offers_with_availability = MagicMock( + return_value=[_cpu_offer(disk_size_mib=100 * 1024)] + ) + + offers_250gb = list(compute.get_offers(Requirements(resources=ResourcesSpec(disk="250GB")))) + offers_300gb = list(compute.get_offers(Requirements(resources=ResourcesSpec(disk="300GB")))) + + assert len(offers_250gb) == 1 + assert offers_250gb[0].instance.resources.disk.size_mib == 250 * 1024 + assert len(offers_300gb) == 1 + assert 
offers_300gb[0].instance.resources.disk.size_mib == 300 * 1024 + compute.get_all_offers_with_availability.assert_called_once() + + def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): compute = _compute() compute.api_client.create_gpu_vm.return_value = "123" @@ -121,6 +168,7 @@ def test_create_gpu_instance_registers_ssh_key_and_creates_gpu_vm(): assert provisioning_data.username == "ubuntu" assert provisioning_data.dockerized is True assert provisioning_data.backend_data is None + compute.api_client.get_instance_status.assert_not_called() def test_create_gpu_instance_passes_spot_flag(): @@ -254,14 +302,11 @@ def test_terminate_instance_delegates_to_api_client(): ) -def test_create_instance_cleans_up_post_create_failure(): +def test_create_instance_propagates_create_failure_without_cleanup(): compute = _compute() - compute.api_client.create_gpu_vm.return_value = "123" - compute.api_client.get_instance_status.return_value = { - "status": "Failed", - "error": "L4 not available at this moment, please try again later", - "code": 404, - } + compute.api_client.create_gpu_vm.side_effect = NoCapacityError( + "L4 not available at this moment, please try again later" + ) with patch( "dstack._internal.core.backends.jarvislabs.compute.generate_unique_instance_name", @@ -270,41 +315,58 @@ def test_create_instance_cleans_up_post_create_failure(): with pytest.raises(NoCapacityError): compute.create_instance(_gpu_offer(spot=True), _instance_config(), None) - compute.api_client.destroy_instance.assert_called_once_with( - machine_id="123", - region="india-noida-01", - ) + compute.api_client.destroy_instance.assert_not_called() -def test_raise_if_create_failed_due_to_no_capacity(): - api_client = MagicMock() - api_client.get_instance_status.return_value = { +def test_update_provisioning_data_raises_provisioning_error_from_failed_capacity_status(): + compute = _compute() + compute.api_client.get_instance.return_value = None + 
compute.api_client.get_instance_status.return_value = { "status": "Failed", "error": "L4 not available at this moment, please try again later", "code": 404, } + provisioning_data = JobProvisioningData( + backend=BackendType.JARVISLABS, + instance_type=_gpu_offer().instance, + instance_id="123", + region="india-noida-01", + price=1.49, + username="ubuntu", + ssh_port=22, + dockerized=True, + ) - with patch("dstack._internal.core.backends.jarvislabs.compute.time.sleep"): - with pytest.raises(NoCapacityError): - _raise_if_create_failed( - api_client=api_client, - machine_id="123", - region="india-noida-01", - ) + with pytest.raises(ProvisioningError): + compute.update_provisioning_data( + provisioning_data, + project_ssh_public_key="ssh-rsa AAAA test", + project_ssh_private_key="private-key", + ) -def test_raise_if_create_failed_raises_provisioning_error(): - api_client = MagicMock() - api_client.get_instance_status.return_value = { +def test_update_provisioning_data_raises_provisioning_error_from_failed_status(): + compute = _compute() + compute.api_client.get_instance.return_value = None + compute.api_client.get_instance_status.return_value = { "status": "Failed", "error": "image setup failed", "code": 500, } + provisioning_data = JobProvisioningData( + backend=BackendType.JARVISLABS, + instance_type=_gpu_offer().instance, + instance_id="123", + region="india-noida-01", + price=1.49, + username="ubuntu", + ssh_port=22, + dockerized=True, + ) - with patch("dstack._internal.core.backends.jarvislabs.compute.time.sleep"): - with pytest.raises(ProvisioningError): - _raise_if_create_failed( - api_client=api_client, - machine_id="123", - region="india-noida-01", - ) + with pytest.raises(ProvisioningError): + compute.update_provisioning_data( + provisioning_data, + project_ssh_public_key="ssh-rsa AAAA test", + project_ssh_private_key="private-key", + ) From 82b9069503be500f2c7d7bea8d7d72b67722f23a Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Thu, 21 May 2026 
09:34:56 +0200 Subject: [PATCH 4/5] Use gpuhunt 0.1.22 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 617f750bbd..7d36f320c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "python-multipart>=0.0.16", "filelock", "psutil", - "gpuhunt @ git+https://github.com/dstackai/gpuhunt.git@jarvislabs", + "gpuhunt==0.1.22", "argcomplete>=3.5.0", "ignore-python>=0.2.0", "orjson", From a2a38f2be36d7aff020d3453cfaa32ab24775a1a Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Thu, 21 May 2026 09:38:36 +0200 Subject: [PATCH 5/5] Remove stale Hatch direct references setting --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7d36f320c3..55b31f3572 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,9 +67,6 @@ artifacts = [ "src/dstack/_internal/server/statics/**", ] -[tool.hatch.metadata] -allow-direct-references = true - [tool.hatch.metadata.hooks.fancy-pypi-readme] content-type = "text/markdown"