Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/backends/configurators.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@
except ImportError:
pass

# Register the Hotaisle configurator only if its module can be imported;
# an ImportError is ignored and _CONFIGURATOR_CLASSES is left unchanged.
try:
from dstack._internal.core.backends.hotaisle.configurator import (
HotaisleConfigurator,
)

_CONFIGURATOR_CLASSES.append(HotaisleConfigurator)
except ImportError:
pass

try:
from dstack._internal.core.backends.kubernetes.configurator import (
KubernetesConfigurator,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/hotaisle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Hotaisle backend for dstack
104 changes: 104 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Any, Dict, Optional

import requests

from dstack._internal.utils.logging import get_logger

API_URL = "https://admin.hotaisle.app/api"

logger = get_logger(__name__)


class HotaisleAPIClient:
    """Client for the Hotaisle HTTP API, bound to one API key and one team."""

    def __init__(self, api_key: str, team_handle: str):
        """Store the credentials used by the request methods.

        Args:
            api_key: Value sent in the ``Authorization`` header.
            team_handle: Handle of the team whose resources are managed.
        """
        self.api_key, self.team_handle = api_key, team_handle

def validate_api_key(self) -> bool:
    """Report whether the configured API key and team handle are usable.

    Returns:
        True when ``_validate_user_and_team`` completes without error.
        False when the API answers 401 or 403, or when validation raises
        ``ValueError`` (no teams, or the team handle is not found).

    Raises:
        requests.HTTPError: For any other HTTP error status.
    """
    try:
        self._validate_user_and_team()
    except ValueError:
        return False
    except requests.HTTPError as e:
        # An HTTPError may be constructed without a response attached; guard
        # against that so the original error is not masked by AttributeError.
        if e.response is not None and e.response.status_code in (401, 403):
            return False
        # Bare raise re-raises the active exception without adding a frame.
        raise
    return True

def _validate_user_and_team(self) -> None:
    """Check that the API key resolves to a user belonging to the team.

    Fetches ``{API_URL}/user/`` and looks for ``self.team_handle`` among the
    handles listed under the response's ``"teams"`` key.

    Raises:
        requests.HTTPError: If the request returns an error status.
        ValueError: If the user has no teams, or ``self.team_handle`` is not
            one of them.
    """
    response = self._make_request("GET", f"{API_URL}/user/")
    # raise_for_status() does nothing on success, so the JSON body is always
    # bound below without needing a separate `response.ok` branch.
    response.raise_for_status()
    user_data = response.json()

    teams = user_data.get("teams") or []
    if not teams:
        raise ValueError("No Hotaisle teams found for this user")

    # .get() makes an entry without "handle" a non-match instead of a KeyError
    # that would bypass the caller's ValueError/HTTPError handling.
    available_teams = [team.get("handle") for team in teams]
    if self.team_handle not in available_teams:
        raise ValueError(f"Hotaisle Team '{self.team_handle}' not found.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(optional) This validation is already better than in most of our backends, but we can improve it further by validating the roles assigned to the key, so that users see permission-related errors earlier - when configuring the backend rather than when creating instances.

It should be possible to validate everything (the key, the user role, and the team roles) by calling only GET /user/api_keys/{prefix}/

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Will plan to update it in the next iteration.


def upload_ssh_key(self, public_key: str) -> bool:
    """POST ``public_key`` to the user's ``ssh_keys`` endpoint.

    Returns:
        True when the request succeeds or the API answers 409.

    Raises:
        requests.HTTPError: For any other error status.
    """
    response = self._make_request(
        "POST",
        f"{API_URL}/user/ssh_keys/",
        json={"authorized_key": public_key},
    )
    # 409 means the key already exists, which counts as success.
    if response.status_code != 409:
        response.raise_for_status()
    return True

def create_virtual_machine(
    self, vm_payload: Dict[str, Any], instance_name: str
) -> Dict[str, Any]:
    """POST ``vm_payload`` to the team's ``virtual_machines`` endpoint.

    ``instance_name`` is accepted but is not sent to the API by this method.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the API returns an error status.
    """
    endpoint = f"{API_URL}/teams/{self.team_handle}/virtual_machines/"
    reply = self._make_request("POST", endpoint, json=vm_payload)
    reply.raise_for_status()
    return reply.json()
Comment thread
jvstme marked this conversation as resolved.
Outdated

def get_vm_state(self, vm_name: str) -> str:
    """Return the ``"state"`` field reported for the VM named ``vm_name``.

    Raises:
        requests.HTTPError: If the API returns an error status.
    """
    endpoint = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/"
    reply = self._make_request("GET", endpoint)
    reply.raise_for_status()
    return reply.json()["state"]

def terminate_virtual_machine(self, vm_name: str) -> bool:
    """Send a DELETE request for the team's VM named ``vm_name``.

    Returns:
        True whenever the API does not report an error status.

    Raises:
        requests.HTTPError: If the API returns a 4xx or 5xx status.
    """
    url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/"
    response = self._make_request("DELETE", url)
    # Previously a non-204 success (e.g. 200) fell through and returned None
    # despite the declared bool return type; every non-error status is now
    # reported as True.
    response.raise_for_status()
    return True
Comment thread
jvstme marked this conversation as resolved.
Outdated

def _make_request(
    self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
) -> requests.Response:
    """Issue an HTTP request carrying the API key in ``Authorization``.

    Args:
        method: HTTP method name passed to ``requests.request``.
        url: Full URL to call.
        json: Optional JSON body; when given, a JSON ``Content-Type`` header
            is added.
        timeout: Forwarded to ``requests.request``.

    Returns:
        The response object; the status is not checked here.
    """
    request_headers = {"accept": "application/json", "Authorization": self.api_key}
    if json is not None:
        request_headers["Content-Type"] = "application/json"

    return requests.request(
        method=method, url=url, headers=request_headers, json=json, timeout=timeout
    )
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.hotaisle.compute import HotaisleCompute
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
from dstack._internal.core.models.backends.base import BackendType


class HotaisleBackend(Backend):
    """Backend of type ``BackendType.HOTAISLE`` backed by ``HotaisleCompute``."""

    TYPE = BackendType.HOTAISLE
    COMPUTE_CLASS = HotaisleCompute

    def __init__(self, config: HotaisleConfig):
        """Keep ``config`` and build the compute object from it once."""
        self.config = config
        self._compute = HotaisleCompute(config)

    def compute(self) -> HotaisleCompute:
        """Return the compute object created in ``__init__``."""
        return self._compute
213 changes: 213 additions & 0 deletions src/dstack/_internal/core/backends/hotaisle/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import shlex
import subprocess
import tempfile
from threading import Thread
from typing import List, Optional

import gpuhunt
from gpuhunt.providers.hotaisle import HotAisleProvider

from dstack._internal.core.backends.base.compute import (
Compute,
ComputeWithCreateInstanceSupport,
generate_unique_instance_name,
get_shim_commands,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.hotaisle.api_client import HotaisleAPIClient
from dstack._internal.core.backends.hotaisle.models import HotaisleConfig
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
InstanceOfferWithAvailability,
)
from dstack._internal.core.models.placement import PlacementGroup
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)

MAX_INSTANCE_NAME_LEN = 60
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Unused



class HotaisleCompute(
    ComputeWithCreateInstanceSupport,
    Compute,
):
    """Compute implementation that provisions instances through Hotaisle."""

    def __init__(self, config: HotaisleConfig):
        """Build the API client and a gpuhunt catalog from ``config``.

        The catalog is created with resource balancing and auto-reload both
        disabled, and a ``HotAisleProvider`` using the same API key and team
        handle as the API client is added to it.
        """
        super().__init__()
        self.config = config
        api_key = config.creds.api_key
        team_handle = config.team_handle
        self.api_client = HotaisleAPIClient(api_key, team_handle)
        self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
        provider = HotAisleProvider(api_key=api_key, team_handle=team_handle)
        self.catalog.add_provider(provider)

def get_offers(
    self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
    """Return catalog offers for this backend, each marked as available.

    Args:
        requirements: Optional filter forwarded to ``get_catalog_offers``.

    Returns:
        One ``InstanceOfferWithAvailability`` per catalog offer, with
        availability set to ``InstanceAvailability.AVAILABLE``.
    """
    catalog_offers = get_catalog_offers(
        backend=BackendType.HOTAISLE,
        # An empty or unset region list is passed on as None.
        locations=self.config.regions or None,
        requirements=requirements,
        catalog=self.catalog,
    )
    return [
        InstanceOfferWithAvailability(
            **catalog_offer.dict(), availability=InstanceAvailability.AVAILABLE
        )
        for catalog_offer in catalog_offers
    ]

def get_payload_from_offer(self, instance_type) -> dict:
    """Build the VM creation payload for ``instance_type``.

    Only the CPU core count and the number of GPUs are read from
    ``instance_type.resources``; every other field is a fixed value.

    Returns:
        A dict with ``cpu_cores``, ``cpus``, ``disk_capacity``,
        ``ram_capacity`` and ``gpus`` entries.
    """
    core_count = instance_type.resources.cpus
    # Hotaisle offers only 8-core and 13-core CPU configurations; anything
    # that is not 8 cores is described as the 13-core variant.
    model_name, clock_hz = (
        ("Xeon Platinum 8462Y+", 2800000000)
        if core_count == 8
        else ("Xeon Platinum 8470", 2000000000)
    )

    cpu_spec = {
        "count": 1,
        "manufacturer": "Intel",
        "model": model_name,
        "cores": core_count,
        "frequency": clock_hz,
    }
    gpu_spec = {
        "count": len(instance_type.resources.gpus),
        "manufacturer": "AMD",
        "model": "MI300X",
    }
    return {
        "cpu_cores": core_count,
        "cpus": cpu_spec,
        "disk_capacity": 13194139533312,
        "ram_capacity": 240518168576,
        "gpus": [gpu_spec],
    }

def create_instance(
    self,
    instance_offer: InstanceOfferWithAvailability,
    instance_config: InstanceConfiguration,
    placement_group: Optional[PlacementGroup],
) -> JobProvisioningData:
    """Upload the project SSH key, create a VM and describe it.

    Args:
        instance_offer: Offer whose instance type, region and price are
            copied into the result.
        instance_config: Supplies the instance name seed and the SSH keys;
            only the first key is uploaded.
        placement_group: Not used by this method.

    Returns:
        Provisioning data whose ``instance_id`` is the VM name returned by
        the API. ``hostname`` is left as None and the VM's ``ip_address`` is
        stored in ``backend_data``.
    """
    instance_name = generate_unique_instance_name(
        instance_config, max_length=MAX_INSTANCE_NAME_LEN
    )
    self.api_client.upload_ssh_key(instance_config.ssh_keys[0].public)
    vm_data = self.api_client.create_virtual_machine(
        self.get_payload_from_offer(instance_offer.instance), instance_name
    )
    return JobProvisioningData(
        backend=instance_offer.backend,
        instance_type=instance_offer.instance,
        instance_id=vm_data["name"],
        hostname=None,
        internal_ip=None,
        region=instance_offer.region,
        price=instance_offer.price,
        username="hotaisle",
        ssh_port=22,
        dockerized=True,
        ssh_proxy=None,
        backend_data=vm_data["ip_address"],
    )

def update_provisioning_data(
    self,
    provisioning_data: JobProvisioningData,
    project_ssh_public_key: str,
    project_ssh_private_key: str,
):
    """Fill in the hostname and start the shim once the VM is running.

    Does nothing unless the API reports the VM state as ``"running"``. When
    it is running and no hostname is set yet, the hostname is taken from
    ``provisioning_data.backend_data``. The shim is then launched over SSH
    from a daemon thread, so this method does not wait for it.

    Args:
        provisioning_data: Updated in place; ``instance_id`` identifies the VM.
        project_ssh_public_key: Passed to ``get_shim_commands`` as the
            authorized key.
        project_ssh_private_key: Key used for the SSH connection to the VM.
    """
    vm_state = self.api_client.get_vm_state(provisioning_data.instance_id)
    if vm_state != "running":
        return

    if provisioning_data.hostname is None and provisioning_data.backend_data:
        provisioning_data.hostname = provisioning_data.backend_data
    if provisioning_data.hostname is None:
        # No address is known, so starting the runner would only attempt an
        # SSH connection to the literal host "None".
        return

    commands = get_shim_commands(
        authorized_keys=[project_ssh_public_key],
        arch=provisioning_data.instance_type.resources.cpu_arch,
    )
    launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands))
    thread = Thread(
        target=_start_runner,
        kwargs={
            "hostname": provisioning_data.hostname,
            "project_ssh_private_key": project_ssh_private_key,
            "launch_command": launch_command,
        },
        daemon=True,
    )
    thread.start()

def terminate_instance(
    self, instance_id: str, region: str, backend_data: Optional[str] = None
):
    """Terminate the VM identified by ``instance_id``.

    ``region`` and ``backend_data`` are accepted but not used.
    """
    # The instance id is passed to the API client as the VM name.
    self.api_client.terminate_virtual_machine(instance_id)
Comment thread
jvstme marked this conversation as resolved.


def _start_runner(
    hostname: str,
    project_ssh_private_key: str,
    launch_command: str,
):
    """Run the instance setup step, then launch the runner, on ``hostname``.

    Both steps connect over SSH using ``project_ssh_private_key``.
    """
    ssh_target = {"hostname": hostname, "ssh_private_key": project_ssh_private_key}
    _setup_instance(**ssh_target)
    _launch_runner(launch_command=launch_command, **ssh_target)


def _setup_instance(
    hostname: str,
    ssh_private_key: str,
):
    """Run the setup commands on ``hostname`` over SSH as one ``&&`` chain."""
    steps = ["sudo apt-get update"]
    chained = " && ".join(steps)
    _run_ssh_command(hostname=hostname, ssh_private_key=ssh_private_key, command=chained)
Comment thread
jvstme marked this conversation as resolved.
Outdated


def _launch_runner(
    hostname: str,
    ssh_private_key: str,
    launch_command: str,
):
    """Start ``launch_command`` in the background on ``hostname`` over SSH.

    Trailing ``&`` characters are stripped from the command, its output is
    redirected to ``/tmp/dstack-shim.log``, and it is backgrounded and
    disowned.
    """
    background_command = launch_command.rstrip("&") + " >/tmp/dstack-shim.log 2>&1 & disown"
    _run_ssh_command(
        hostname=hostname,
        ssh_private_key=ssh_private_key,
        command=background_command,
    )


def _run_ssh_command(hostname: str, ssh_private_key: str, command: str):
    """Run ``command`` on ``hostname`` over SSH as the ``hotaisle`` user.

    The private key is written to a temporary file that exists only for the
    duration of the call. The user's SSH config is bypassed (``-F none``) and
    host key checking is disabled. Remote output is discarded; a non-zero
    exit status is logged rather than raised.

    Args:
        hostname: Host to connect to.
        ssh_private_key: Private key text used for authentication.
        command: Command line executed on the remote host.
    """
    # The second positional parameter of NamedTemporaryFile is `buffering`,
    # not a permission mask, so the former 0o600 argument only set a 384-byte
    # buffer. The tempfile module already creates the file accessible to the
    # creating user only.
    with tempfile.NamedTemporaryFile("w+") as key_file:
        key_file.write(ssh_private_key)
        key_file.flush()
        result = subprocess.run(
            [
                "ssh",
                "-F",
                "none",
                "-o",
                "StrictHostKeyChecking=no",
                "-i",
                key_file.name,
                f"hotaisle@{hostname}",
                command,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    if result.returncode != 0:
        # Output is discarded, so the exit status is the only failure signal.
        logger.warning(
            "SSH command on %s exited with status %s", hostname, result.returncode
        )
Loading
Loading