-
-
Notifications
You must be signed in to change notification settings - Fork 222
add hotaisle backend #2935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add hotaisle backend #2935
Changes from 2 commits
f5f2f2e
b9ca0be
9558c29
1bf72e0
16e77e3
d2846ae
c19065b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # Hotaisle backend for dstack |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| import requests | ||
|
|
||
| from dstack._internal.utils.logging import get_logger | ||
|
|
||
| API_URL = "https://admin.hotaisle.app/api" | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class HotaisleAPIClient: | ||
| def __init__(self, api_key: str, team_handle: str): | ||
| self.api_key = api_key | ||
| self.team_handle = team_handle | ||
|
|
||
| def validate_api_key(self) -> bool: | ||
| try: | ||
| self._validate_user_and_team() | ||
| return True | ||
| except requests.HTTPError as e: | ||
| if e.response.status_code in [401, 403]: | ||
| return False | ||
| raise e | ||
| except ValueError: | ||
| return False | ||
|
|
||
| def _validate_user_and_team(self) -> None: | ||
| url = f"{API_URL}/user/" | ||
| response = self._make_request("GET", url) | ||
|
|
||
| if response.ok: | ||
| user_data = response.json() | ||
| else: | ||
| response.raise_for_status() | ||
|
|
||
| teams = user_data.get("teams", []) | ||
| if not teams: | ||
| raise ValueError("No Hotaisle teams found for this user") | ||
|
|
||
| available_teams = [team["handle"] for team in teams] | ||
| if self.team_handle not in available_teams: | ||
| raise ValueError(f"Hotaisle Team '{self.team_handle}' not found.") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (optional) This validation is already better than in most our backends, but we can further improve it by validating the roles assigned to the key, so that users can see permission-related errors earlier - when configuring the backend rather than when creating instances. It should be possible to validate everything (the key, the user role, and the team roles) by calling only
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you. Will plan to update it in the next iteration. |
||
|
|
||
| def upload_ssh_key(self, public_key: str) -> bool: | ||
| url = f"{API_URL}/user/ssh_keys/" | ||
| payload = {"authorized_key": public_key} | ||
|
|
||
| response = self._make_request("POST", url, json=payload) | ||
|
|
||
| if response.status_code == 409: | ||
| return True # Key already exists - success | ||
| if not response.ok: | ||
| response.raise_for_status() | ||
| return True | ||
|
|
||
| def create_virtual_machine( | ||
| self, vm_payload: Dict[str, Any], instance_name: str | ||
| ) -> Dict[str, Any]: | ||
| url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/" | ||
| response = self._make_request("POST", url, json=vm_payload) | ||
|
|
||
| if not response.ok: | ||
| response.raise_for_status() | ||
|
|
||
| vm_data = response.json() | ||
| return vm_data | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
|
|
||
| def get_vm_state(self, vm_name: str) -> str: | ||
| url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/state/" | ||
| response = self._make_request("GET", url) | ||
|
|
||
| if not response.ok: | ||
| response.raise_for_status() | ||
|
|
||
| state_data = response.json() | ||
| return state_data["state"] | ||
|
|
||
| def terminate_virtual_machine(self, vm_name: str) -> bool: | ||
| url = f"{API_URL}/teams/{self.team_handle}/virtual_machines/{vm_name}/" | ||
| response = self._make_request("DELETE", url) | ||
|
|
||
| if response.status_code == 204: | ||
| return True | ||
| else: | ||
| response.raise_for_status() | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
|
|
||
| def _make_request( | ||
| self, method: str, url: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30 | ||
| ) -> requests.Response: | ||
| headers = { | ||
| "accept": "application/json", | ||
| "Authorization": self.api_key, | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
| } | ||
| if json is not None: | ||
| headers["Content-Type"] = "application/json" | ||
|
|
||
| return requests.request( | ||
| method=method, | ||
| url=url, | ||
| headers=headers, | ||
| json=json, | ||
| timeout=timeout, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from dstack._internal.core.backends.base.backend import Backend | ||
| from dstack._internal.core.backends.hotaisle.compute import HotaisleCompute | ||
| from dstack._internal.core.backends.hotaisle.models import HotaisleConfig | ||
| from dstack._internal.core.models.backends.base import BackendType | ||
|
|
||
|
|
||
| class HotaisleBackend(Backend): | ||
| TYPE = BackendType.HOTAISLE | ||
| COMPUTE_CLASS = HotaisleCompute | ||
|
|
||
| def __init__(self, config: HotaisleConfig): | ||
| self.config = config | ||
| self._compute = HotaisleCompute(self.config) | ||
|
|
||
| def compute(self) -> HotaisleCompute: | ||
| return self._compute |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| import shlex | ||
| import subprocess | ||
| import tempfile | ||
| from threading import Thread | ||
| from typing import List, Optional | ||
|
|
||
| import gpuhunt | ||
| from gpuhunt.providers.hotaisle import HotAisleProvider | ||
|
|
||
| from dstack._internal.core.backends.base.compute import ( | ||
| Compute, | ||
| ComputeWithCreateInstanceSupport, | ||
| generate_unique_instance_name, | ||
| get_shim_commands, | ||
| ) | ||
| from dstack._internal.core.backends.base.offers import get_catalog_offers | ||
| from dstack._internal.core.backends.hotaisle.api_client import HotaisleAPIClient | ||
| from dstack._internal.core.backends.hotaisle.models import HotaisleConfig | ||
| from dstack._internal.core.models.backends.base import BackendType | ||
| from dstack._internal.core.models.instances import ( | ||
| InstanceAvailability, | ||
| InstanceConfiguration, | ||
| InstanceOfferWithAvailability, | ||
| ) | ||
| from dstack._internal.core.models.placement import PlacementGroup | ||
| from dstack._internal.core.models.runs import JobProvisioningData, Requirements | ||
| from dstack._internal.utils.logging import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
| MAX_INSTANCE_NAME_LEN = 60 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit) Unused |
||
|
|
||
|
|
||
| class HotaisleCompute( | ||
| ComputeWithCreateInstanceSupport, | ||
| Compute, | ||
| ): | ||
| def __init__(self, config: HotaisleConfig): | ||
| super().__init__() | ||
| self.config = config | ||
| self.api_client = HotaisleAPIClient(config.creds.api_key, config.team_handle) | ||
| self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False) | ||
| self.catalog.add_provider( | ||
| HotAisleProvider(api_key=config.creds.api_key, team_handle=config.team_handle) | ||
| ) | ||
|
|
||
| def get_offers( | ||
| self, requirements: Optional[Requirements] = None | ||
| ) -> List[InstanceOfferWithAvailability]: | ||
| offers = get_catalog_offers( | ||
| backend=BackendType.HOTAISLE, | ||
| locations=self.config.regions or None, | ||
| requirements=requirements, | ||
| catalog=self.catalog, | ||
| ) | ||
| offers = [ | ||
| InstanceOfferWithAvailability( | ||
| **offer.dict(), availability=InstanceAvailability.AVAILABLE | ||
| ) | ||
| for offer in offers | ||
| ] | ||
| return offers | ||
|
|
||
| def get_payload_from_offer(self, instance_type) -> dict: | ||
| # Only two instance types are available in Hotaisle with CPUs: 8-core and 13-core. Other fields are | ||
| # not configurable. | ||
| cpu_cores = instance_type.resources.cpus | ||
| if cpu_cores == 8: | ||
| cpu_model = "Xeon Platinum 8462Y+" | ||
| frequency = 2800000000 | ||
| else: # cpu_cores == 13 | ||
| cpu_model = "Xeon Platinum 8470" | ||
| frequency = 2000000000 | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
|
|
||
| return { | ||
| "cpu_cores": cpu_cores, | ||
| "cpus": { | ||
| "count": 1, | ||
| "manufacturer": "Intel", | ||
| "model": cpu_model, | ||
| "cores": cpu_cores, | ||
| "frequency": frequency, | ||
| }, | ||
| "disk_capacity": 13194139533312, | ||
| "ram_capacity": 240518168576, | ||
| "gpus": [ | ||
| { | ||
| "count": len(instance_type.resources.gpus), | ||
| "manufacturer": "AMD", | ||
| "model": "MI300X", | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
| } | ||
| ], | ||
| } | ||
|
|
||
| def create_instance( | ||
| self, | ||
| instance_offer: InstanceOfferWithAvailability, | ||
| instance_config: InstanceConfiguration, | ||
| placement_group: Optional[PlacementGroup], | ||
| ) -> JobProvisioningData: | ||
| instance_name = generate_unique_instance_name( | ||
| instance_config, max_length=MAX_INSTANCE_NAME_LEN | ||
| ) | ||
| project_ssh_key = instance_config.ssh_keys[0] | ||
| self.api_client.upload_ssh_key(project_ssh_key.public) | ||
| vm_payload = self.get_payload_from_offer(instance_offer.instance) | ||
| vm_data = self.api_client.create_virtual_machine(vm_payload, instance_name) | ||
| return JobProvisioningData( | ||
| backend=instance_offer.backend, | ||
| instance_type=instance_offer.instance, | ||
| instance_id=vm_data["name"], | ||
| hostname=None, | ||
| internal_ip=None, | ||
| region=instance_offer.region, | ||
| price=instance_offer.price, | ||
| username="hotaisle", | ||
| ssh_port=22, | ||
| dockerized=True, | ||
| ssh_proxy=None, | ||
| backend_data=vm_data["ip_address"], | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
| ) | ||
|
|
||
| def update_provisioning_data( | ||
| self, | ||
| provisioning_data: JobProvisioningData, | ||
| project_ssh_public_key: str, | ||
| project_ssh_private_key: str, | ||
| ): | ||
| vm_state = self.api_client.get_vm_state(provisioning_data.instance_id) | ||
| if vm_state == "running": | ||
| if provisioning_data.hostname is None and provisioning_data.backend_data: | ||
| provisioning_data.hostname = provisioning_data.backend_data | ||
| commands = get_shim_commands( | ||
| authorized_keys=[project_ssh_public_key], | ||
| arch=provisioning_data.instance_type.resources.cpu_arch, | ||
| ) | ||
| launch_command = "sudo sh -c " + shlex.quote(" && ".join(commands)) | ||
| thread = Thread( | ||
| target=_start_runner, | ||
| kwargs={ | ||
| "hostname": provisioning_data.hostname, | ||
| "project_ssh_private_key": project_ssh_private_key, | ||
| "launch_command": launch_command, | ||
| }, | ||
| daemon=True, | ||
| ) | ||
| thread.start() | ||
|
|
||
| def terminate_instance( | ||
| self, instance_id: str, region: str, backend_data: Optional[str] = None | ||
| ): | ||
| vm_name = instance_id | ||
| self.api_client.terminate_virtual_machine(vm_name) | ||
|
jvstme marked this conversation as resolved.
|
||
|
|
||
|
|
||
| def _start_runner( | ||
| hostname: str, | ||
| project_ssh_private_key: str, | ||
| launch_command: str, | ||
| ): | ||
| _setup_instance( | ||
| hostname=hostname, | ||
| ssh_private_key=project_ssh_private_key, | ||
| ) | ||
| _launch_runner( | ||
| hostname=hostname, | ||
| ssh_private_key=project_ssh_private_key, | ||
| launch_command=launch_command, | ||
| ) | ||
|
|
||
|
|
||
| def _setup_instance( | ||
| hostname: str, | ||
| ssh_private_key: str, | ||
| ): | ||
| setup_commands = ("sudo apt-get update",) | ||
| _run_ssh_command( | ||
| hostname=hostname, ssh_private_key=ssh_private_key, command=" && ".join(setup_commands) | ||
| ) | ||
|
jvstme marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def _launch_runner( | ||
| hostname: str, | ||
| ssh_private_key: str, | ||
| launch_command: str, | ||
| ): | ||
| daemonized_command = f"{launch_command.rstrip('&')} >/tmp/dstack-shim.log 2>&1 & disown" | ||
| _run_ssh_command( | ||
| hostname=hostname, | ||
| ssh_private_key=ssh_private_key, | ||
| command=daemonized_command, | ||
| ) | ||
|
|
||
|
|
||
| def _run_ssh_command(hostname: str, ssh_private_key: str, command: str): | ||
| with tempfile.NamedTemporaryFile("w+", 0o600) as f: | ||
| f.write(ssh_private_key) | ||
| f.flush() | ||
| subprocess.run( | ||
| [ | ||
| "ssh", | ||
| "-F", | ||
| "none", | ||
| "-o", | ||
| "StrictHostKeyChecking=no", | ||
| "-i", | ||
| f.name, | ||
| f"hotaisle@{hostname}", | ||
| command, | ||
| ], | ||
| stdout=subprocess.DEVNULL, | ||
| stderr=subprocess.DEVNULL, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.