Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 158 additions & 65 deletions src/dstack/_internal/core/backends/verda/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
get_offers_disk_modifier,
)
from dstack._internal.core.backends.verda.models import VerdaConfig
from dstack._internal.core.errors import NoCapacityError
from dstack._internal.core.errors import BackendError, NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
Expand All @@ -31,7 +32,6 @@
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import get_public_key_fingerprint

logger = get_logger("verda.compute")

Expand Down Expand Up @@ -101,54 +101,72 @@ def create_instance(
instance_config, max_length=MAX_INSTANCE_NAME_LEN
)
public_keys = instance_config.get_public_keys()
ssh_ids = []
for ssh_public_key in public_keys:
ssh_ids.append(
# verda allows you to use the same name
_get_or_create_ssh_key(
client=self.client,
name=f"dstack-{instance_config.instance_name}.key",
public_key=ssh_public_key,
ssh_ids: List[str] = []
startup_script_id: Optional[str] = None
try:
for idx, ssh_public_key in enumerate(public_keys):
ssh_ids.append(
_create_ssh_key(
client=self.client,
name=f"{instance_name}-{idx}.key",
public_key=ssh_public_key,
)
)
)

commands = get_shim_commands()
startup_script = " ".join([" && ".join(commands)])
script_name = f"dstack-{instance_config.instance_name}.sh"
startup_script_ids = _get_or_create_startup_scrpit(
client=self.client,
name=script_name,
script=startup_script,
)
commands = get_shim_commands()
startup_script = " ".join([" && ".join(commands)])
script_name = f"{instance_name}.sh"
startup_script_id = _create_startup_script(
client=self.client,
name=script_name,
script=startup_script,
)

disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
image_id = _get_vm_image_id(instance_offer)

logger.debug(
"Deploying Verda instance",
{
"instance_type": instance_offer.instance.name,
"ssh_key_ids": ssh_ids,
"startup_script_id": startup_script_ids,
"hostname": instance_name,
"description": instance_name,
"image": image_id,
"disk_size": disk_size,
"location": instance_offer.region,
},
)
instance = _deploy_instance(
client=self.client,
instance_type=instance_offer.instance.name,
ssh_key_ids=ssh_ids,
startup_script_id=startup_script_ids,
hostname=instance_name,
description=instance_name,
image=image_id,
disk_size=disk_size,
is_spot=instance_offer.instance.resources.spot,
location=instance_offer.region,
)
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
image_id = _get_vm_image_id(instance_offer)

logger.debug(
"Deploying Verda instance",
{
"instance_type": instance_offer.instance.name,
"ssh_key_ids": ssh_ids,
"startup_script_id": startup_script_id,
"hostname": instance_name,
"description": instance_name,
"image": image_id,
"disk_size": disk_size,
"location": instance_offer.region,
},
)
instance = _deploy_instance(
client=self.client,
instance_type=instance_offer.instance.name,
ssh_key_ids=ssh_ids,
startup_script_id=startup_script_id,
hostname=instance_name,
description=instance_name,
image=image_id,
disk_size=disk_size,
is_spot=instance_offer.instance.resources.spot,
location=instance_offer.region,
)
except Exception:
# startup_script_id and ssh_key_ids are per-instance. Ensure no leaks on failures.
try:
_delete_startup_script(self.client, startup_script_id)
except Exception:
logger.exception(
"Failed to cleanup startup script %s after provisioning failure.",
startup_script_id,
)
try:
_delete_ssh_keys(self.client, ssh_ids)
except Exception:
logger.exception(
"Failed to cleanup ssh keys %s after provisioning failure.",
ssh_ids,
)
raise
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
Expand All @@ -161,12 +179,16 @@ def create_instance(
ssh_port=22,
dockerized=True,
ssh_proxy=None,
backend_data=None,
backend_data=VerdaInstanceBackendData(
startup_script_id=startup_script_id,
ssh_key_ids=ssh_ids,
).json(),
)

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
):
backend_data_parsed = VerdaInstanceBackendData.load(backend_data)
try:
self.client.instances.action(id_list=[instance_id], action="delete")
except APIException as e:
Expand All @@ -175,8 +197,10 @@ def terminate_instance(
"Can't discontinue a discontinued instance",
]:
logger.debug("Skipping instance %s termination. Instance not found.", instance_id)
return
raise
else:
raise
_delete_startup_script(self.client, backend_data_parsed.startup_script_id)
_delete_ssh_keys(self.client, backend_data_parsed.ssh_key_ids)

def update_provisioning_data(
self,
Expand All @@ -200,26 +224,84 @@ def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str:
return "77777777-4f48-4249-82b3-f199fb9b701b"


def _get_or_create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
fingerprint = get_public_key_fingerprint(public_key)
keys = client.ssh_keys.get()
found_keys = [key for key in keys if fingerprint == get_public_key_fingerprint(key.public_key)]
if found_keys:
key = found_keys[0]
def _create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
    """Register ``public_key`` with Verda under ``name`` and return the new key's id.

    Args:
        client: Verda API client used to issue the request.
        name: Display name to give the SSH key.
        public_key: Public key material to upload.

    Returns:
        The ``id`` attribute of the object returned by ``client.ssh_keys.create``.

    Raises:
        BackendError: If the Verda API call raises ``APIException``. The original
            exception is attached as ``__cause__``.
    """
    # Keep the try body limited to the API call so only API failures are translated.
    try:
        key = client.ssh_keys.create(name, public_key)
    except APIException as e:
        raise BackendError(f"Verda API error while creating SSH key: {e.message}") from e
    return key.id


def _get_or_create_startup_scrpit(client: VerdaClient, name: str, script: str) -> str:
scripts = client.startup_scripts.get()
found_scripts = [startup_script for startup_script in scripts if script == startup_script]
if found_scripts:
startup_script = found_scripts[0]
def _create_startup_script(client: VerdaClient, name: str, script: str) -> str:
    """Upload ``script`` to Verda as a startup script called ``name`` and return its id.

    Args:
        client: Verda API client used to issue the request.
        name: Display name to give the startup script.
        script: Script contents to upload.

    Returns:
        The ``id`` attribute of the object returned by
        ``client.startup_scripts.create``.

    Raises:
        BackendError: If the Verda API call raises ``APIException``. The original
            exception is attached as ``__cause__``.
    """
    # Keep the try body limited to the API call so only API failures are translated.
    try:
        startup_script = client.startup_scripts.create(name, script)
    except APIException as e:
        raise BackendError(
            f"Verda API error while creating startup script: {e.message}"
        ) from e
    return startup_script.id

startup_script = client.startup_scripts.create(name, script)
return startup_script.id

def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None:
if startup_script_id is None:
return
try:
client.startup_scripts.delete_by_id(startup_script_id)
except APIException as e:
if _is_startup_script_not_found_error(e):
logger.debug(
"Skipping startup script %s deletion. Startup script not found.",
startup_script_id,
)
return
raise


def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None:
if not ssh_key_ids:
return
for ssh_key_id in ssh_key_ids:
_delete_ssh_key(client, ssh_key_id)


def _delete_ssh_key(client: VerdaClient, ssh_key_id: str) -> None:
    """Delete a single SSH key by id.

    An ``APIException`` that ``_is_ssh_key_not_found_error`` classifies as
    "not found" is logged at debug level and swallowed; any other
    ``APIException`` propagates to the caller.
    """
    try:
        client.ssh_keys.delete_by_id(ssh_key_id)
    except APIException as exc:
        if not _is_ssh_key_not_found_error(exc):
            raise
        # Already gone: nothing left to clean up.
        logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id)


def _is_ssh_key_not_found_error(error: APIException) -> bool:
code = (error.code or "").lower()
message = (error.message or "").lower()
if code == "not_found":
return True
if code not in {"", "invalid_request"}:
return False
return (
message == "invalid ssh-key id"
or message == "invalid ssh key id"
or message == "not found"
or ("ssh-key id" in message and "invalid" in message)
or ("ssh key id" in message and "invalid" in message)
)


def _is_startup_script_not_found_error(error: APIException) -> bool:
code = (error.code or "").lower()
message = (error.message or "").lower()
if code == "not_found":
return True
if code not in {"", "invalid_request"}:
return False
return (
message == "invalid startup script id"
or message == "invalid script id"
or message == "not found"
or ("startup script id" in message and "invalid" in message)
or ("script id" in message and "invalid" in message)
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) In my testing, deleting a nonexistent key or script does not actually raise any exceptions. So I assume both _is_ssh_key_not_found_error and _is_startup_script_not_found_error are hallucinations and can be removed.



def _get_instance_by_id(
Expand Down Expand Up @@ -264,3 +346,14 @@ def _deploy_instance(
raise NoCapacityError(f"Verda API error: {e.message}")

return instance


class VerdaInstanceBackendData(CoreModel):
    """Per-instance Verda resource ids carried in the ``backend_data`` string.

    ``create_instance`` serializes this model with ``.json()`` and
    ``terminate_instance`` reads it back through ``load`` to know which startup
    script and SSH keys to delete.
    """

    # Id of the startup script created for the instance, if any.
    startup_script_id: Optional[str] = None
    # Ids of the SSH keys created for the instance, if any.
    ssh_key_ids: Optional[List[str]] = None

    @classmethod
    def load(cls, raw: Optional[str]) -> "VerdaInstanceBackendData":
        """Parse ``raw`` backend data; ``None`` yields a model with all fields unset.

        Non-``None`` input is parsed with ``cls.__response__.parse_raw``
        (``__response__`` is presumably provided by ``CoreModel`` - confirm).
        """
        return cls() if raw is None else cls.__response__.parse_raw(raw)
Loading
Loading