Skip to content

Commit fc9afa9

Browse files
peterschmidt85, Andrey Cheptsov, and jvstme
authored
Verda: make startup script and SSH key lifecycle per-instance with reliable cleanup (#3718)
* Make Verda startup scripts and SSH keys lifecycle symmetric
* Fix Verda test imports for Python 3.9 collection
* Update src/dstack/_internal/core/backends/verda/compute.py Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com>
* Update src/dstack/_internal/core/backends/verda/compute.py Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com>
* Update src/dstack/_internal/core/backends/verda/compute.py Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com>
* Update src/dstack/_internal/core/backends/verda/compute.py Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com>
* Fix Verda terminate tests for merge-base API args
---------
Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
Co-authored-by: jvstme <36324149+jvstme@users.noreply.github.com>
1 parent a325f56 commit fc9afa9

File tree

2 files changed

+636
-65
lines changed

2 files changed

+636
-65
lines changed

src/dstack/_internal/core/backends/verda/compute.py

Lines changed: 158 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
get_offers_disk_modifier,
2020
)
2121
from dstack._internal.core.backends.verda.models import VerdaConfig
22-
from dstack._internal.core.errors import NoCapacityError
22+
from dstack._internal.core.errors import BackendError, NoCapacityError
2323
from dstack._internal.core.models.backends.base import BackendType
24+
from dstack._internal.core.models.common import CoreModel
2425
from dstack._internal.core.models.instances import (
2526
InstanceAvailability,
2627
InstanceConfiguration,
@@ -31,7 +32,6 @@
3132
from dstack._internal.core.models.resources import Memory, Range
3233
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
3334
from dstack._internal.utils.logging import get_logger
34-
from dstack._internal.utils.ssh import get_public_key_fingerprint
3535

3636
logger = get_logger("verda.compute")
3737

@@ -101,54 +101,72 @@ def create_instance(
101101
instance_config, max_length=MAX_INSTANCE_NAME_LEN
102102
)
103103
public_keys = instance_config.get_public_keys()
104-
ssh_ids = []
105-
for ssh_public_key in public_keys:
106-
ssh_ids.append(
107-
# verda allows you to use the same name
108-
_get_or_create_ssh_key(
109-
client=self.client,
110-
name=f"dstack-{instance_config.instance_name}.key",
111-
public_key=ssh_public_key,
104+
ssh_ids: List[str] = []
105+
startup_script_id: Optional[str] = None
106+
try:
107+
for idx, ssh_public_key in enumerate(public_keys):
108+
ssh_ids.append(
109+
_create_ssh_key(
110+
client=self.client,
111+
name=f"{instance_name}-{idx}.key",
112+
public_key=ssh_public_key,
113+
)
112114
)
113-
)
114115

115-
commands = get_shim_commands()
116-
startup_script = " ".join([" && ".join(commands)])
117-
script_name = f"dstack-{instance_config.instance_name}.sh"
118-
startup_script_ids = _get_or_create_startup_scrpit(
119-
client=self.client,
120-
name=script_name,
121-
script=startup_script,
122-
)
116+
commands = get_shim_commands()
117+
startup_script = " ".join([" && ".join(commands)])
118+
script_name = f"{instance_name}.sh"
119+
startup_script_id = _create_startup_script(
120+
client=self.client,
121+
name=script_name,
122+
script=startup_script,
123+
)
123124

124-
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
125-
image_id = _get_vm_image_id(instance_offer)
126-
127-
logger.debug(
128-
"Deploying Verda instance",
129-
{
130-
"instance_type": instance_offer.instance.name,
131-
"ssh_key_ids": ssh_ids,
132-
"startup_script_id": startup_script_ids,
133-
"hostname": instance_name,
134-
"description": instance_name,
135-
"image": image_id,
136-
"disk_size": disk_size,
137-
"location": instance_offer.region,
138-
},
139-
)
140-
instance = _deploy_instance(
141-
client=self.client,
142-
instance_type=instance_offer.instance.name,
143-
ssh_key_ids=ssh_ids,
144-
startup_script_id=startup_script_ids,
145-
hostname=instance_name,
146-
description=instance_name,
147-
image=image_id,
148-
disk_size=disk_size,
149-
is_spot=instance_offer.instance.resources.spot,
150-
location=instance_offer.region,
151-
)
125+
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
126+
image_id = _get_vm_image_id(instance_offer)
127+
128+
logger.debug(
129+
"Deploying Verda instance",
130+
{
131+
"instance_type": instance_offer.instance.name,
132+
"ssh_key_ids": ssh_ids,
133+
"startup_script_id": startup_script_id,
134+
"hostname": instance_name,
135+
"description": instance_name,
136+
"image": image_id,
137+
"disk_size": disk_size,
138+
"location": instance_offer.region,
139+
},
140+
)
141+
instance = _deploy_instance(
142+
client=self.client,
143+
instance_type=instance_offer.instance.name,
144+
ssh_key_ids=ssh_ids,
145+
startup_script_id=startup_script_id,
146+
hostname=instance_name,
147+
description=instance_name,
148+
image=image_id,
149+
disk_size=disk_size,
150+
is_spot=instance_offer.instance.resources.spot,
151+
location=instance_offer.region,
152+
)
153+
except Exception:
154+
# startup_script_id and ssh_key_ids are per-instance. Ensure no leaks on failures.
155+
try:
156+
_delete_startup_script(self.client, startup_script_id)
157+
except Exception:
158+
logger.exception(
159+
"Failed to cleanup startup script %s after provisioning failure.",
160+
startup_script_id,
161+
)
162+
try:
163+
_delete_ssh_keys(self.client, ssh_ids)
164+
except Exception:
165+
logger.exception(
166+
"Failed to cleanup ssh keys %s after provisioning failure.",
167+
ssh_ids,
168+
)
169+
raise
152170
return JobProvisioningData(
153171
backend=instance_offer.backend,
154172
instance_type=instance_offer.instance,
@@ -161,12 +179,16 @@ def create_instance(
161179
ssh_port=22,
162180
dockerized=True,
163181
ssh_proxy=None,
164-
backend_data=None,
182+
backend_data=VerdaInstanceBackendData(
183+
startup_script_id=startup_script_id,
184+
ssh_key_ids=ssh_ids,
185+
).json(),
165186
)
166187

167188
def terminate_instance(
168189
self, instance_id: str, region: str, backend_data: Optional[str] = None
169190
):
191+
backend_data_parsed = VerdaInstanceBackendData.load(backend_data)
170192
try:
171193
self.client.instances.action(
172194
id_list=[instance_id],
@@ -179,8 +201,10 @@ def terminate_instance(
179201
"Can't discontinue a discontinued instance",
180202
]:
181203
logger.debug("Skipping instance %s termination. Instance not found.", instance_id)
182-
return
183-
raise
204+
else:
205+
raise
206+
_delete_startup_script(self.client, backend_data_parsed.startup_script_id)
207+
_delete_ssh_keys(self.client, backend_data_parsed.ssh_key_ids)
184208

185209
def update_provisioning_data(
186210
self,
@@ -204,26 +228,84 @@ def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str:
204228
return "77777777-4f48-4249-82b3-f199fb9b701b"
205229

206230

207-
def _get_or_create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
208-
fingerprint = get_public_key_fingerprint(public_key)
209-
keys = client.ssh_keys.get()
210-
found_keys = [key for key in keys if fingerprint == get_public_key_fingerprint(key.public_key)]
211-
if found_keys:
212-
key = found_keys[0]
231+
def _create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
    """Create a new SSH key through the Verda client and return its ID.

    Args:
        client: Verda API client used to create the key.
        name: Name to give the created key.
        public_key: Public key material to upload.

    Returns:
        The ``id`` attribute of the key object returned by the API.

    Raises:
        BackendError: If the API call raises ``APIException``. The original
            exception is attached as ``__cause__``.
    """
    try:
        key = client.ssh_keys.create(name, public_key)
    except APIException as e:
        # Chain explicitly so the underlying API error is reported as the cause
        # instead of as an unrelated error raised "during handling".
        raise BackendError(f"Verda API error while creating SSH key: {e.message}") from e
    return key.id
216237

217238

218-
def _get_or_create_startup_scrpit(client: VerdaClient, name: str, script: str) -> str:
219-
scripts = client.startup_scripts.get()
220-
found_scripts = [startup_script for startup_script in scripts if script == startup_script]
221-
if found_scripts:
222-
startup_script = found_scripts[0]
239+
def _create_startup_script(client: VerdaClient, name: str, script: str) -> str:
    """Create a new startup script through the Verda client and return its ID.

    Args:
        client: Verda API client used to create the script.
        name: Name to give the created script.
        script: Script contents to upload.

    Returns:
        The ``id`` attribute of the startup script object returned by the API.

    Raises:
        BackendError: If the API call raises ``APIException``. The original
            exception is attached as ``__cause__``.
    """
    try:
        startup_script = client.startup_scripts.create(name, script)
    except APIException as e:
        # Chain explicitly so the underlying API error is reported as the cause
        # instead of as an unrelated error raised "during handling".
        raise BackendError(
            f"Verda API error while creating startup script: {e.message}"
        ) from e
    return startup_script.id
224245

225-
startup_script = client.startup_scripts.create(name, script)
226-
return startup_script.id
246+
247+
def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None:
    """Delete a startup script by ID, tolerating a script that no longer exists.

    Does nothing when ``startup_script_id`` is ``None``. An ``APIException``
    that ``_is_startup_script_not_found_error`` classifies as "not found" is
    logged at debug level and swallowed; any other ``APIException`` propagates.
    """
    if startup_script_id is not None:
        try:
            client.startup_scripts.delete_by_id(startup_script_id)
        except APIException as api_error:
            if not _is_startup_script_not_found_error(api_error):
                raise
            # Already gone: deletion is treated as successful.
            logger.debug(
                "Skipping startup script %s deletion. Startup script not found.",
                startup_script_id,
            )
260+
261+
262+
def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None:
    """Delete every SSH key in ``ssh_key_ids``, attempting all of them.

    Does nothing when ``ssh_key_ids`` is ``None`` or empty. A failure to delete
    one key does not stop the remaining keys from being attempted, so a single
    failing key cannot leave the others behind.

    Args:
        client: Verda API client used for the deletions.
        ssh_key_ids: IDs of the keys to delete, or ``None``.

    Raises:
        Exception: The first error raised by ``_delete_ssh_key``, re-raised
            after all keys have been attempted. Later errors are logged.
    """
    if not ssh_key_ids:
        return
    first_error: Optional[Exception] = None
    for ssh_key_id in ssh_key_ids:
        try:
            _delete_ssh_key(client, ssh_key_id)
        except Exception as e:
            if first_error is None:
                # Remember the first failure and keep going with the other keys.
                first_error = e
            else:
                # Only one error can be re-raised; log the rest so they are not lost.
                logger.warning("Failed to delete ssh key %s: %r", ssh_key_id, e)
    if first_error is not None:
        raise first_error
267+
268+
269+
def _delete_ssh_key(client: VerdaClient, ssh_key_id: str) -> None:
    """Delete one SSH key by ID, tolerating a key that no longer exists.

    An ``APIException`` that ``_is_ssh_key_not_found_error`` classifies as
    "not found" is logged at debug level and swallowed; any other
    ``APIException`` propagates.
    """
    try:
        client.ssh_keys.delete_by_id(ssh_key_id)
    except APIException as api_error:
        if not _is_ssh_key_not_found_error(api_error):
            raise
        # Already gone: deletion is treated as successful.
        logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id)
277+
278+
279+
def _is_ssh_key_not_found_error(error: APIException) -> bool:
    """Return True if ``error`` looks like a "SSH key does not exist" response.

    The check is case-insensitive. A ``not_found`` code matches on its own.
    Otherwise the code must be empty or ``invalid_request``, and the message
    must either be one of the known exact phrases or mention an invalid
    SSH key ID.
    """
    error_code = (error.code or "").lower()
    text = (error.message or "").lower()

    if error_code == "not_found":
        return True
    if error_code not in ("", "invalid_request"):
        # Any other explicit code is a different kind of failure.
        return False

    if text in ("invalid ssh-key id", "invalid ssh key id", "not found"):
        return True
    return "invalid" in text and ("ssh-key id" in text or "ssh key id" in text)
293+
294+
295+
def _is_startup_script_not_found_error(error: APIException) -> bool:
    """Return True if ``error`` looks like a "startup script does not exist" response.

    The check is case-insensitive. A ``not_found`` code matches on its own.
    Otherwise the code must be empty or ``invalid_request``, and the message
    must either be one of the known exact phrases or mention an invalid
    (startup) script ID.
    """
    error_code = (error.code or "").lower()
    text = (error.message or "").lower()

    if error_code == "not_found":
        return True
    if error_code not in ("", "invalid_request"):
        # Any other explicit code is a different kind of failure.
        return False

    if text in ("invalid startup script id", "invalid script id", "not found"):
        return True
    return "invalid" in text and ("startup script id" in text or "script id" in text)
227309

228310

229311
def _get_instance_by_id(
@@ -269,3 +351,14 @@ def _deploy_instance(
269351
raise NoCapacityError(f"Verda API error: {e.message}")
270352

271353
return instance
354+
355+
356+
class VerdaInstanceBackendData(CoreModel):
    """IDs of the per-instance Verda resources, serialized into ``backend_data``."""

    # ID of the startup script created for the instance, if any.
    startup_script_id: Optional[str] = None
    # IDs of the SSH keys created for the instance, if any.
    ssh_key_ids: Optional[List[str]] = None

    @classmethod
    def load(cls, raw: Optional[str]) -> "VerdaInstanceBackendData":
        """Parse serialized backend data; ``None`` yields a model with default fields."""
        return cls() if raw is None else cls.__response__.parse_raw(raw)

0 commit comments

Comments
 (0)