1919 get_offers_disk_modifier ,
2020)
2121from dstack ._internal .core .backends .verda .models import VerdaConfig
22- from dstack ._internal .core .errors import NoCapacityError
22+ from dstack ._internal .core .errors import BackendError , NoCapacityError
2323from dstack ._internal .core .models .backends .base import BackendType
24+ from dstack ._internal .core .models .common import CoreModel
2425from dstack ._internal .core .models .instances import (
2526 InstanceAvailability ,
2627 InstanceConfiguration ,
3132from dstack ._internal .core .models .resources import Memory , Range
3233from dstack ._internal .core .models .runs import JobProvisioningData , Requirements
3334from dstack ._internal .utils .logging import get_logger
34- from dstack ._internal .utils .ssh import get_public_key_fingerprint
3535
3636logger = get_logger ("verda.compute" )
3737
@@ -101,54 +101,74 @@ def create_instance(
101101 instance_config , max_length = MAX_INSTANCE_NAME_LEN
102102 )
103103 public_keys = instance_config .get_public_keys ()
104- ssh_ids = []
105- for ssh_public_key in public_keys :
106- ssh_ids .append (
107- # verda allows you to use the same name
108- _get_or_create_ssh_key (
109- client = self .client ,
110- name = f"dstack-{ instance_config .instance_name } .key" ,
111- public_key = ssh_public_key ,
104+ ssh_ids : List [str ] = []
105+ startup_script_id : Optional [str ] = None
106+ try :
107+ for idx , ssh_public_key in enumerate (public_keys ):
108+ ssh_ids .append (
109+ _create_ssh_key (
110+ client = self .client ,
111+ name = f"dstack-{ instance_name } -{ idx } .key" ,
112+ public_key = ssh_public_key ,
113+ )
112114 )
113- )
114115
115- commands = get_shim_commands ()
116- startup_script = " " .join ([" && " .join (commands )])
117- script_name = f"dstack-{ instance_config . instance_name } .sh"
118- startup_script_ids = _get_or_create_startup_scrpit (
119- client = self .client ,
120- name = script_name ,
121- script = startup_script ,
122- )
116+ commands = get_shim_commands ()
117+ startup_script = " " .join ([" && " .join (commands )])
118+ script_name = f"dstack-{ instance_name } .sh"
119+ startup_script_id = _create_startup_script (
120+ client = self .client ,
121+ name = script_name ,
122+ script = startup_script ,
123+ )
123124
124- disk_size = round (instance_offer .instance .resources .disk .size_mib / 1024 )
125- image_id = _get_vm_image_id (instance_offer )
126-
127- logger .debug (
128- "Deploying Verda instance" ,
129- {
130- "instance_type" : instance_offer .instance .name ,
131- "ssh_key_ids" : ssh_ids ,
132- "startup_script_id" : startup_script_ids ,
133- "hostname" : instance_name ,
134- "description" : instance_name ,
135- "image" : image_id ,
136- "disk_size" : disk_size ,
137- "location" : instance_offer .region ,
138- },
139- )
140- instance = _deploy_instance (
141- client = self .client ,
142- instance_type = instance_offer .instance .name ,
143- ssh_key_ids = ssh_ids ,
144- startup_script_id = startup_script_ids ,
145- hostname = instance_name ,
146- description = instance_name ,
147- image = image_id ,
148- disk_size = disk_size ,
149- is_spot = instance_offer .instance .resources .spot ,
150- location = instance_offer .region ,
151- )
125+ disk_size = round (instance_offer .instance .resources .disk .size_mib / 1024 )
126+ image_id = _get_vm_image_id (instance_offer )
127+
128+ logger .debug (
129+ "Deploying Verda instance" ,
130+ {
131+ "instance_type" : instance_offer .instance .name ,
132+ "ssh_key_ids" : ssh_ids ,
133+ "startup_script_id" : startup_script_id ,
134+ "hostname" : instance_name ,
135+ "description" : instance_name ,
136+ "image" : image_id ,
137+ "disk_size" : disk_size ,
138+ "location" : instance_offer .region ,
139+ },
140+ )
141+ instance = _deploy_instance (
142+ client = self .client ,
143+ instance_type = instance_offer .instance .name ,
144+ ssh_key_ids = ssh_ids ,
145+ startup_script_id = startup_script_id ,
146+ hostname = instance_name ,
147+ description = instance_name ,
148+ image = image_id ,
149+ disk_size = disk_size ,
150+ is_spot = instance_offer .instance .resources .spot ,
151+ location = instance_offer .region ,
152+ )
153+ except Exception :
154+ # startup_script_id and ssh_key_ids are per-instance. Ensure no leaks on failures.
155+ try :
156+ _delete_startup_script (self .client , startup_script_id )
157+ except Exception :
158+ logger .warning (
159+ "Failed to cleanup startup script %s after provisioning failure." ,
160+ startup_script_id ,
161+ exc_info = True ,
162+ )
163+ try :
164+ _delete_ssh_keys (self .client , ssh_ids )
165+ except Exception :
166+ logger .warning (
167+ "Failed to cleanup ssh keys %s after provisioning failure." ,
168+ ssh_ids ,
169+ exc_info = True ,
170+ )
171+ raise
152172 return JobProvisioningData (
153173 backend = instance_offer .backend ,
154174 instance_type = instance_offer .instance ,
@@ -161,12 +181,16 @@ def create_instance(
161181 ssh_port = 22 ,
162182 dockerized = True ,
163183 ssh_proxy = None ,
164- backend_data = None ,
184+ backend_data = VerdaInstanceBackendData (
185+ startup_script_id = startup_script_id ,
186+ ssh_key_ids = ssh_ids ,
187+ ).json (),
165188 )
166189
167190 def terminate_instance (
168191 self , instance_id : str , region : str , backend_data : Optional [str ] = None
169192 ):
193+ backend_data_parsed = VerdaInstanceBackendData .load (backend_data )
170194 try :
171195 self .client .instances .action (id_list = [instance_id ], action = "delete" )
172196 except APIException as e :
@@ -175,8 +199,10 @@ def terminate_instance(
175199 "Can't discontinue a discontinued instance" ,
176200 ]:
177201 logger .debug ("Skipping instance %s termination. Instance not found." , instance_id )
178- return
179- raise
202+ else :
203+ raise
204+ _delete_startup_script (self .client , backend_data_parsed .startup_script_id )
205+ _delete_ssh_keys (self .client , backend_data_parsed .ssh_key_ids )
180206
181207 def update_provisioning_data (
182208 self ,
@@ -200,26 +226,86 @@ def _get_vm_image_id(instance_offer: InstanceOfferWithAvailability) -> str:
200226 return "77777777-4f48-4249-82b3-f199fb9b701b"
201227
202228
203- def _get_or_create_ssh_key (client : VerdaClient , name : str , public_key : str ) -> str :
204- fingerprint = get_public_key_fingerprint (public_key )
205- keys = client .ssh_keys .get ()
206- found_keys = [key for key in keys if fingerprint == get_public_key_fingerprint (key .public_key )]
207- if found_keys :
208- key = found_keys [0 ]
def _create_ssh_key(client: VerdaClient, name: str, public_key: str) -> str:
    """Register ``public_key`` with Verda under ``name`` and return the new key's id.

    A key is always created; existing keys are not looked up or reused.

    Raises:
        BackendError: if the Verda API rejects the request. The original
            ``APIException`` is attached as ``__cause__``.
    """
    # Keep the try body to the single API call that can raise APIException.
    try:
        key = client.ssh_keys.create(name, public_key)
    except APIException as e:
        # Chain explicitly so the API error stays visible in tracebacks.
        raise BackendError(f"Verda API error while creating SSH key: {e.message}") from e
    return key.id
212235
213236
214- def _get_or_create_startup_scrpit (client : VerdaClient , name : str , script : str ) -> str :
215- scripts = client .startup_scripts .get ()
216- found_scripts = [startup_script for startup_script in scripts if script == startup_script ]
217- if found_scripts :
218- startup_script = found_scripts [0 ]
def _create_startup_script(client: VerdaClient, name: str, script: str) -> str:
    """Upload ``script`` to Verda as a startup script called ``name`` and return its id.

    A script is always created; existing scripts are not looked up or reused.

    Raises:
        BackendError: if the Verda API rejects the request. The original
            ``APIException`` is attached as ``__cause__``.
    """
    # Keep the try body to the single API call that can raise APIException.
    try:
        startup_script = client.startup_scripts.create(name, script)
    except APIException as e:
        # Chain explicitly so the API error stays visible in tracebacks.
        raise BackendError(
            f"Verda API error while creating startup script: {e.message}"
        ) from e
    return startup_script.id
220243
221- startup_script = client .startup_scripts .create (name , script )
222- return startup_script .id
244+
def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None:
    """Delete the startup script with the given id; do nothing when the id is ``None``.

    An ``APIException`` that ``_is_startup_script_not_found_error`` classifies
    as "not found" is logged at debug level and swallowed. Any other
    ``APIException`` propagates to the caller.
    """
    if startup_script_id is not None:
        try:
            client.startup_scripts.delete_by_id(startup_script_id)
        except APIException as api_error:
            if not _is_startup_script_not_found_error(api_error):
                raise
            # Already gone: nothing left to clean up.
            logger.debug(
                "Skipping startup script %s deletion. Startup script not found.",
                startup_script_id,
            )
258+
259+
def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None:
    """Delete every SSH key in ``ssh_key_ids``; do nothing for ``None`` or an empty list.

    Deletion is attempted for all ids even if some fail, so one failing key
    does not leave the remaining keys behind. After the loop, the first error
    encountered is re-raised; later errors are logged as warnings.

    Raises:
        Exception: the first error raised by ``_delete_ssh_key``, if any.
    """
    if not ssh_key_ids:
        return
    first_error: Optional[Exception] = None
    for ssh_key_id in ssh_key_ids:
        try:
            _delete_ssh_key(client, ssh_key_id)
        except Exception as e:
            if first_error is None:
                # Remember it, but keep going so the other keys still get deleted.
                first_error = e
            else:
                # Only one error can be re-raised; log the rest so they are not lost.
                logger.warning("Failed to delete ssh key %s.", ssh_key_id, exc_info=True)
    if first_error is not None:
        raise first_error
265+
266+
def _delete_ssh_key(client: VerdaClient, ssh_key_id: Optional[str]) -> None:
    """Delete the SSH key with the given id; do nothing when the id is ``None``.

    An ``APIException`` that ``_is_ssh_key_not_found_error`` classifies as
    "not found" is logged at debug level and swallowed. Any other
    ``APIException`` propagates to the caller.
    """
    if ssh_key_id is not None:
        try:
            client.ssh_keys.delete_by_id(ssh_key_id)
        except APIException as api_error:
            if not _is_ssh_key_not_found_error(api_error):
                raise
            # Already gone: nothing left to clean up.
            logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id)
277+
278+
def _is_resource_not_found_error(error: APIException, *id_phrases: str) -> bool:
    """Shared classifier: decide whether ``error`` means the resource is already gone.

    Args:
        error: the API exception; only its ``code`` and ``message`` are read.
        id_phrases: lowercase phrases naming the resource id (e.g. ``"script id"``).
            A message that contains any of them together with ``"invalid"``
            counts as "not found".

    Returns:
        ``True`` when the code is ``not_found``, or when the code is empty or
        ``invalid_request`` and the message is exactly ``not found`` or reports
        an invalid id for the resource. All comparisons are case-insensitive.
    """
    # This runs inside `except` handlers, so it must never raise itself:
    # coerce to str in case a field is missing or not a string.
    code = str(error.code or "").lower()
    message = str(error.message or "").lower()
    if code == "not_found":
        return True
    if code not in {"", "invalid_request"}:
        return False
    if message == "not found":
        return True
    return "invalid" in message and any(phrase in message for phrase in id_phrases)


def _is_ssh_key_not_found_error(error: APIException) -> bool:
    """Return ``True`` if ``error`` indicates the SSH key does not exist."""
    return _is_resource_not_found_error(error, "ssh-key id", "ssh key id")


def _is_startup_script_not_found_error(error: APIException) -> bool:
    """Return ``True`` if ``error`` indicates the startup script does not exist."""
    # "script id" also matches messages that say "startup script id".
    return _is_resource_not_found_error(error, "script id")
223309
224310
225311def _get_instance_by_id (
@@ -264,3 +350,14 @@ def _deploy_instance(
264350 raise NoCapacityError (f"Verda API error: { e .message } " )
265351
266352 return instance
353+
354+
class VerdaInstanceBackendData(CoreModel):
    """Ids of the per-instance Verda resources created during provisioning.

    ``create_instance`` serializes this with ``.json()`` into the provisioning
    data's ``backend_data``; ``terminate_instance`` reads it back via ``load``
    to delete the startup script and SSH keys.
    """

    # Id of the startup script created for the instance, if any.
    startup_script_id: Optional[str] = None
    # Ids of the SSH keys created for the instance, if any.
    ssh_key_ids: Optional[List[str]] = None

    @classmethod
    def load(cls, raw: Optional[str]) -> "VerdaInstanceBackendData":
        """Parse serialized backend data; ``None`` yields an instance with default fields."""
        # NOTE(review): `__response__` is presumably the project's CoreModel
        # parsing variant - confirm against CoreModel.
        return cls() if raw is None else cls.__response__.parse_raw(raw)
0 commit comments