Skip to content

Commit 2e54d24

Browse files
author
Andrey Cheptsov
committed
Make Verda startup scripts and SSH keys lifecycle symmetric
1 parent 3277143 commit 2e54d24

File tree

2 files changed

+646
-65
lines changed

2 files changed

+646
-65
lines changed

src/dstack/_internal/core/backends/verda/compute.py

Lines changed: 162 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
get_offers_disk_modifier,
2020
)
2121
from dstack._internal.core.backends.verda.models import VerdaConfig
22-
from dstack._internal.core.errors import NoCapacityError
22+
from dstack._internal.core.errors import BackendError, NoCapacityError
2323
from dstack._internal.core.models.backends.base import BackendType
24+
from dstack._internal.core.models.common import CoreModel
2425
from dstack._internal.core.models.instances import (
2526
InstanceAvailability,
2627
InstanceConfiguration,
@@ -31,7 +32,6 @@
3132
from dstack._internal.core.models.resources import Memory, Range
3233
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
3334
from dstack._internal.utils.logging import get_logger
34-
from dstack._internal.utils.ssh import get_public_key_fingerprint
3535

3636
logger = get_logger("verda.compute")
3737

@@ -101,54 +101,74 @@ def create_instance(
101101
instance_config, max_length=MAX_INSTANCE_NAME_LEN
102102
)
103103
public_keys = instance_config.get_public_keys()
104-
ssh_ids = []
105-
for ssh_public_key in public_keys:
106-
ssh_ids.append(
107-
# verda allows you to use the same name
108-
_get_or_create_ssh_key(
109-
client=self.client,
110-
name=f"dstack-{instance_config.instance_name}.key",
111-
public_key=ssh_public_key,
104+
ssh_ids: List[str] = []
105+
startup_script_id: Optional[str] = None
106+
try:
107+
for idx, ssh_public_key in enumerate(public_keys):
108+
ssh_ids.append(
109+
_create_ssh_key(
110+
client=self.client,
111+
name=f"dstack-{instance_name}-{idx}.key",
112+
public_key=ssh_public_key,
113+
)
112114
)
113-
)
114115

115-
commands = get_shim_commands()
116-
startup_script = " ".join([" && ".join(commands)])
117-
script_name = f"dstack-{instance_config.instance_name}.sh"
118-
startup_script_ids = _get_or_create_startup_scrpit(
119-
client=self.client,
120-
name=script_name,
121-
script=startup_script,
122-
)
116+
commands = get_shim_commands()
117+
startup_script = " ".join([" && ".join(commands)])
118+
script_name = f"dstack-{instance_name}.sh"
119+
startup_script_id = _create_startup_script(
120+
client=self.client,
121+
name=script_name,
122+
script=startup_script,
123+
)
123124

124-
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
125-
image_id = _get_vm_image_id(instance_offer)
126-
127-
logger.debug(
128-
"Deploying Verda instance",
129-
{
130-
"instance_type": instance_offer.instance.name,
131-
"ssh_key_ids": ssh_ids,
132-
"startup_script_id": startup_script_ids,
133-
"hostname": instance_name,
134-
"description": instance_name,
135-
"image": image_id,
136-
"disk_size": disk_size,
137-
"location": instance_offer.region,
138-
},
139-
)
140-
instance = _deploy_instance(
141-
client=self.client,
142-
instance_type=instance_offer.instance.name,
143-
ssh_key_ids=ssh_ids,
144-
startup_script_id=startup_script_ids,
145-
hostname=instance_name,
146-
description=instance_name,
147-
image=image_id,
148-
disk_size=disk_size,
149-
is_spot=instance_offer.instance.resources.spot,
150-
location=instance_offer.region,
151-
)
125+
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
126+
image_id = _get_vm_image_id(instance_offer)
127+
128+
logger.debug(
129+
"Deploying Verda instance",
130+
{
131+
"instance_type": instance_offer.instance.name,
132+
"ssh_key_ids": ssh_ids,
133+
"startup_script_id": startup_script_id,
134+
"hostname": instance_name,
135+
"description": instance_name,
136+
"image": image_id,
137+
"disk_size": disk_size,
138+
"location": instance_offer.region,
139+
},
140+
)
141+
instance = _deploy_instance(
142+
client=self.client,
143+
instance_type=instance_offer.instance.name,
144+
ssh_key_ids=ssh_ids,
145+
startup_script_id=startup_script_id,
146+
hostname=instance_name,
147+
description=instance_name,
148+
image=image_id,
149+
disk_size=disk_size,
150+
is_spot=instance_offer.instance.resources.spot,
151+
location=instance_offer.region,
152+
)
153+
except Exception:
154+
# startup_script_id and ssh_key_ids are per-instance. Ensure no leaks on failures.
155+
try:
156+
_delete_startup_script(self.client, startup_script_id)
157+
except Exception:
158+
logger.warning(
159+
"Failed to cleanup startup script %s after provisioning failure.",
160+
startup_script_id,
161+
exc_info=True,
162+
)
163+
try:
164+
_delete_ssh_keys(self.client, ssh_ids)
165+
except Exception:
166+
logger.warning(
167+
"Failed to cleanup ssh keys %s after provisioning failure.",
168+
ssh_ids,
169+
exc_info=True,
170+
)
171+
raise
152172
return JobProvisioningData(
153173
backend=instance_offer.backend,
154174
instance_type=instance_offer.instance,
@@ -161,12 +181,16 @@ def create_instance(
161181
ssh_port=22,
162182
dockerized=True,
163183
ssh_proxy=None,
164-
backend_data=None,
184+
backend_data=VerdaInstanceBackendData(
185+
startup_script_id=startup_script_id,
186+
ssh_key_ids=ssh_ids,
187+
).json(),
165188
)
166189

167190
def terminate_instance(
168191
self, instance_id: str, region: str, backend_data: Optional[str] = None
169192
):
193+
backend_data_parsed = VerdaInstanceBackendData.load(backend_data)
170194
try:
171195
self.client.instances.action(id_list=[instance_id], action="delete")
172196
except APIException as e:
@@ -175,8 +199,10 @@ def terminate_instance(
175199
"Can't discontinue a discontinued instance",
176200
]:
177201
logger.debug("Skipping instance %s termination. Instance not found.", instance_id)
178-
return
179-
raise
202+
else:
203+
raise
204+
_delete_startup_script(self.client, backend_data_parsed.startup_script_id)
205+
_delete_ssh_keys(self.client, backend_data_parsed.ssh_key_ids)
180206

181207
def update_provisioning_data(
182208
self,
@@ -200,26 +226,86 @@ def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str:
200226
return "77777777-4f48-4249-82b3-f199fb9b701b"
201227

202228

203-
def _get_or_create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
204-
fingerprint = get_public_key_fingerprint(public_key)
205-
keys = client.ssh_keys.get()
206-
found_keys = [key for key in keys if fingerprint == get_public_key_fingerprint(key.public_key)]
207-
if found_keys:
208-
key = found_keys[0]
229+
def _create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
    """Register ``public_key`` with Verda under ``name`` and return the new key's ID.

    A new key is always created; no lookup of existing keys is performed.

    Raises:
        BackendError: If the Verda API call fails. The original
            ``APIException`` is chained as ``__cause__``.
    """
    # Only the API call can raise APIException, so keep the try body minimal.
    try:
        key = client.ssh_keys.create(name, public_key)
    except APIException as e:
        raise BackendError(f"Verda API error while creating SSH key: {e.message}") from e
    return key.id
212235

213236

214-
def _get_or_create_startup_scrpit(client: VerdaClient, name: str, script: str) -> str:
215-
scripts = client.startup_scripts.get()
216-
found_scripts = [startup_script for startup_script in scripts if script == startup_script]
217-
if found_scripts:
218-
startup_script = found_scripts[0]
237+
def _create_startup_script(client: VerdaClient, name: str, script: str) -> str:
    """Upload ``script`` to Verda as a startup script called ``name`` and return its ID.

    A new startup script is always created; no lookup of existing scripts is
    performed.

    Raises:
        BackendError: If the Verda API call fails. The original
            ``APIException`` is chained as ``__cause__``.
    """
    # Only the API call can raise APIException, so keep the try body minimal.
    try:
        startup_script = client.startup_scripts.create(name, script)
    except APIException as e:
        raise BackendError(f"Verda API error while creating startup script: {e.message}") from e
    return startup_script.id
220243

221-
startup_script = client.startup_scripts.create(name, script)
222-
return startup_script.id
244+
245+
def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None:
    """Delete the startup script with the given ID.

    Does nothing when ``startup_script_id`` is ``None``. An ``APIException``
    that ``_is_startup_script_not_found_error`` classifies as "not found" is
    logged at debug level and swallowed; any other ``APIException`` is
    re-raised unchanged.
    """
    if startup_script_id is not None:
        try:
            client.startup_scripts.delete_by_id(startup_script_id)
        except APIException as api_error:
            if not _is_startup_script_not_found_error(api_error):
                raise
            # The script is already gone, so there is nothing left to clean up.
            logger.debug(
                "Skipping startup script %s deletion. Startup script not found.",
                startup_script_id,
            )
258+
259+
260+
def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None:
    """Delete every SSH key listed in ``ssh_key_ids``.

    Does nothing when ``ssh_key_ids`` is ``None`` or empty. Deletion is
    attempted for every key even if an earlier one fails, so a single failing
    key does not leave the remaining ones behind.

    Raises:
        Exception: The first error raised by ``_delete_ssh_key``, re-raised
            after all keys have been attempted.
    """
    if not ssh_key_ids:
        return
    first_error: Optional[Exception] = None
    for ssh_key_id in ssh_key_ids:
        try:
            _delete_ssh_key(client, ssh_key_id)
        except Exception as e:
            if first_error is None:
                # Remember the first failure and keep going; it is re-raised below.
                first_error = e
            else:
                # Later failures would otherwise be lost, so record them here.
                logger.warning("Failed to delete ssh key %s.", ssh_key_id, exc_info=True)
    if first_error is not None:
        raise first_error
265+
266+
267+
def _delete_ssh_key(client: VerdaClient, ssh_key_id: Optional[str]) -> None:
    """Delete the SSH key with the given ID.

    Does nothing when ``ssh_key_id`` is ``None``. An ``APIException`` that
    ``_is_ssh_key_not_found_error`` classifies as "not found" is logged at
    debug level and swallowed; any other ``APIException`` is re-raised
    unchanged.
    """
    if ssh_key_id is not None:
        try:
            client.ssh_keys.delete_by_id(ssh_key_id)
        except APIException as api_error:
            if not _is_ssh_key_not_found_error(api_error):
                raise
            # The key is already gone, so there is nothing left to clean up.
            logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id)
277+
278+
279+
def _is_ssh_key_not_found_error(error: APIException) -> bool:
    """Return True when ``error`` indicates that the SSH key does not exist.

    The check is case-insensitive. A ``not_found`` code always matches. With
    an empty or ``invalid_request`` code, the message must either be exactly
    "not found" or mention "invalid" together with "ssh-key id" or
    "ssh key id". Any other code never matches.
    """
    error_code = (error.code or "").lower()
    text = (error.message or "").lower()
    if error_code == "not_found":
        return True
    if error_code not in {"", "invalid_request"}:
        return False
    if text == "not found":
        return True
    # Both spellings of the key ID phrase are accepted.
    id_markers = ("ssh-key id", "ssh key id")
    return "invalid" in text and any(marker in text for marker in id_markers)
293+
294+
295+
def _is_startup_script_not_found_error(error: APIException) -> bool:
    """Return True when ``error`` indicates that the startup script does not exist.

    The check is case-insensitive. A ``not_found`` code always matches. With
    an empty or ``invalid_request`` code, the message must either be exactly
    "not found" or mention "invalid" together with "startup script id" or
    "script id". Any other code never matches.
    """
    error_code = (error.code or "").lower()
    text = (error.message or "").lower()
    if error_code == "not_found":
        return True
    if error_code not in {"", "invalid_request"}:
        return False
    if text == "not found":
        return True
    # Both the long and the short form of the script ID phrase are accepted.
    id_markers = ("startup script id", "script id")
    return "invalid" in text and any(marker in text for marker in id_markers)
223309

224310

225311
def _get_instance_by_id(
@@ -264,3 +350,14 @@ def _deploy_instance(
264350
raise NoCapacityError(f"Verda API error: {e.message}")
265351

266352
return instance
353+
354+
355+
class VerdaInstanceBackendData(CoreModel):
    """IDs of the per-instance Verda resources created alongside an instance.

    ``create_instance`` serializes this model with ``.json()`` into
    ``JobProvisioningData.backend_data``; ``terminate_instance`` reads it back
    with ``load`` to delete the startup script and SSH keys.
    """

    # ID of the startup script created for the instance, if known.
    startup_script_id: Optional[str] = None
    # IDs of the SSH keys created for the instance, if known.
    ssh_key_ids: Optional[List[str]] = None

    @classmethod
    def load(cls, raw: Optional[str]) -> "VerdaInstanceBackendData":
        """Parse ``raw`` backend data, or return an empty model when it is ``None``.

        NOTE(review): ``__response__`` is presumably supplied by ``CoreModel``;
        confirm against its definition.
        """
        if raw is not None:
            return cls.__response__.parse_raw(raw)
        return cls()

0 commit comments

Comments
 (0)