Skip to content

Commit fb4a4da

Browse files
authored
Optimize create instance on AWS (#3556)
* Implement update_provisioning_data() for AWS
* Fix catching and retrying ec2_client.cancel_spot_instance_requests()
* Handle ec2_client.cancel_spot_instance_requests() error
* Type check backends/aws
* Fix log level for "Requesting instance offers"
1 parent cfea44e commit fb4a4da

File tree

4 files changed

+97
-49
lines changed

4 files changed

+97
-49
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ include = [
101101
"src/dstack/plugins",
102102
"src/dstack/_internal/server",
103103
"src/dstack/_internal/core/services",
104+
"src/dstack/_internal/core/backends/aws",
104105
"src/dstack/_internal/core/backends/kubernetes",
105106
"src/dstack/_internal/core/backends/runpod",
106107
"src/dstack/_internal/cli/services/configurators",

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 93 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
NoCapacityError,
4949
PlacementGroupInUseError,
5050
PlacementGroupNotSupportedError,
51+
ProvisioningError,
5152
)
5253
from dstack._internal.core.models.backends.base import BackendType
5354
from dstack._internal.core.models.common import CoreModel
@@ -291,35 +292,35 @@ def create_instance(
291292
}
292293
if reservation.get("ReservationType") == "capacity-block":
293294
is_capacity_block = True
294-
295295
except botocore.exceptions.ClientError as e:
296296
logger.warning("Got botocore.exceptions.ClientError: %s", e)
297297
raise NoCapacityError()
298+
298299
tried_zones = set()
299300
for subnet_id, az in subnet_id_to_az_map.items():
300301
if az in tried_zones:
301302
continue
302303
tried_zones.add(az)
304+
logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
305+
image_id, username = self._get_image_id_and_username(
306+
ec2_client=ec2_client,
307+
region=instance_offer.region,
308+
gpu_name=(
309+
instance_offer.instance.resources.gpus[0].name
310+
if len(instance_offer.instance.resources.gpus) > 0
311+
else None
312+
),
313+
instance_type=instance_offer.instance.name,
314+
image_config=self.config.os_images,
315+
)
316+
security_group_id = self._create_security_group(
317+
ec2_client=ec2_client,
318+
region=instance_offer.region,
319+
project_id=project_name,
320+
vpc_id=vpc_id,
321+
)
303322
try:
304-
logger.debug("Trying provisioning %s in %s", instance_offer.instance.name, az)
305-
image_id, username = self._get_image_id_and_username(
306-
ec2_client=ec2_client,
307-
region=instance_offer.region,
308-
gpu_name=(
309-
instance_offer.instance.resources.gpus[0].name
310-
if len(instance_offer.instance.resources.gpus) > 0
311-
else None
312-
),
313-
instance_type=instance_offer.instance.name,
314-
image_config=self.config.os_images,
315-
)
316-
security_group_id = self._create_security_group(
317-
ec2_client=ec2_client,
318-
region=instance_offer.region,
319-
project_id=project_name,
320-
vpc_id=vpc_id,
321-
)
322-
response = ec2_resource.create_instances(
323+
response = ec2_resource.create_instances( # pyright: ignore[reportAttributeAccessIssue]
323324
**aws_resources.create_instances_struct(
324325
disk_size=disk_size,
325326
image_id=image_id,
@@ -343,39 +344,85 @@ def create_instance(
343344
is_capacity_block=is_capacity_block,
344345
)
345346
)
346-
instance = response[0]
347-
instance.wait_until_running()
348-
instance.reload() # populate instance.public_ip_address
349-
if instance_offer.instance.resources.spot: # it will not terminate the instance
350-
ec2_client.cancel_spot_instance_requests(
351-
SpotInstanceRequestIds=[instance.spot_instance_request_id]
352-
)
353-
hostname = _get_instance_ip(instance, allocate_public_ip)
354-
return JobProvisioningData(
355-
backend=instance_offer.backend,
356-
instance_type=instance_offer.instance,
357-
instance_id=instance.instance_id,
358-
public_ip_enabled=allocate_public_ip,
359-
hostname=hostname,
360-
internal_ip=instance.private_ip_address,
361-
region=instance_offer.region,
362-
availability_zone=az,
363-
reservation=instance.capacity_reservation_id,
364-
price=instance_offer.price,
365-
username=username,
366-
ssh_port=22,
367-
dockerized=True, # because `dstack-shim` is used
368-
ssh_proxy=None,
369-
backend_data=None,
370-
)
371347
except botocore.exceptions.ClientError as e:
372348
logger.warning("Got botocore.exceptions.ClientError: %s", e)
373349
if e.response["Error"]["Code"] == "InvalidParameterValue":
374350
msg = e.response["Error"].get("Message", "")
375351
raise ComputeError(f"Invalid AWS request: {msg}")
376352
continue
353+
instance = response[0]
354+
if instance_offer.instance.resources.spot:
355+
# it will not terminate the instance
356+
try:
357+
ec2_client.cancel_spot_instance_requests(
358+
SpotInstanceRequestIds=[instance.spot_instance_request_id]
359+
)
360+
except Exception:
361+
logger.exception(
362+
"Failed to cancel spot instance request. The instance will be terminated."
363+
)
364+
self.terminate_instance(
365+
instance_id=instance.instance_id, region=instance_offer.region
366+
)
367+
raise NoCapacityError()
368+
return JobProvisioningData(
369+
backend=instance_offer.backend,
370+
instance_type=instance_offer.instance,
371+
instance_id=instance.instance_id,
372+
public_ip_enabled=allocate_public_ip,
373+
hostname=None,
374+
internal_ip=None,
375+
region=instance_offer.region,
376+
availability_zone=az,
377+
reservation=instance.capacity_reservation_id,
378+
price=instance_offer.price,
379+
username=username,
380+
ssh_port=None,
381+
dockerized=True, # because `dstack-shim` is used
382+
ssh_proxy=None,
383+
backend_data=None,
384+
)
377385
raise NoCapacityError()
378386

387+
def update_provisioning_data(
    self,
    provisioning_data: JobProvisioningData,
    project_ssh_public_key: str,
    project_ssh_private_key: str,
):
    """
    Fill in the network details of an EC2 instance once it is running.

    Loads the EC2 instance referenced by ``provisioning_data.instance_id`` in
    ``provisioning_data.region`` and inspects its state:

    * ``running`` -- ``hostname``, ``internal_ip`` and ``ssh_port`` are set on
      ``provisioning_data`` in place.
    * ``pending`` or instance not found -- returns without modifying
      ``provisioning_data``.
    * anything else -- raises ``ProvisioningError``.

    Args:
        provisioning_data: Provisioning data to update in place.
        project_ssh_public_key: Not used by this implementation.
        project_ssh_private_key: Not used by this implementation.

    Raises:
        ProvisioningError: If the instance state is missing, is one of
            ``shutting-down``/``terminated``/``stopping``/``stopped``,
            or is any other unrecognized value.
        botocore.exceptions.ClientError: If loading the instance fails with
            an error code other than ``InvalidInstanceID.NotFound``.
    """
    ec2 = self.session.resource("ec2", region_name=provisioning_data.region)
    instance = ec2.Instance(provisioning_data.instance_id)  # pyright: ignore[reportAttributeAccessIssue]
    try:
        instance.load()
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] != "InvalidInstanceID.NotFound":
            raise err
        # The instance may have been created but not be visible yet due to
        # AWS eventual consistency, so we wait instead of failing immediately.
        logger.debug(
            "Instance %s not found. Waiting for the instance to appear"
            " or to timeout if the instance is manually deleted.",
            provisioning_data.instance_id,
        )
        return

    state = instance.state.get("Name")

    if state == "running":
        # Only a running instance has its addresses populated.
        provisioning_data.hostname = _get_instance_ip(
            instance, self.config.allocate_public_ips
        )
        provisioning_data.internal_ip = instance.private_ip_address
        provisioning_data.ssh_port = 22
        return

    if state == "pending":
        # Still starting up: leave the provisioning data untouched.
        return

    # States from which the instance will not become reachable.
    failed_states = (None, "shutting-down", "terminated", "stopping", "stopped")
    if state in failed_states:
        raise ProvisioningError(
            f"Failed to get instance IP address. Instance state is {state}."
        )

    raise ProvisioningError(
        f"Failed to get instance IP address. Unknown instance state {state}."
    )
425+
379426
def create_placement_group(
380427
self,
381428
placement_group: PlacementGroup,
@@ -478,7 +525,7 @@ def create_gateway(
478525
allocate_public_ip=configuration.public_ip,
479526
)
480527
try:
481-
response = ec2_resource.create_instances(**instance_struct)
528+
response = ec2_resource.create_instances(**instance_struct) # pyright: ignore[reportAttributeAccessIssue]
482529
except botocore.exceptions.ClientError as e:
483530
msg = f"AWS Error: {e.response['Error']['Code']}"
484531
if e.response["Error"].get("Message"):

src/dstack/_internal/core/backends/base/configurator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, ClassVar, Generic, List, Optional, TypeVar
2+
from typing import Any, ClassVar, Generic, List, NoReturn, Optional, TypeVar
33
from uuid import UUID
44

55
from dstack._internal.core.backends.base.backend import Backend
@@ -110,7 +110,7 @@ def get_backend(self, record: StoredBackendRecord) -> Backend:
110110

111111
def raise_invalid_credentials_error(
112112
fields: Optional[List[List[str]]] = None, details: Optional[Any] = None
113-
):
113+
) -> NoReturn:
114114
msg = BackendInvalidCredentialsError.msg
115115
if details:
116116
msg += f": {details}"

src/dstack/_internal/server/services/backends/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_filtered_offers_with_backends(
361361
if not exclude_not_available or offer.availability.is_available():
362362
yield (backend, offer)
363363

364-
logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
364+
logger.debug("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
365365
tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends]
366366
offers_by_backend = []
367367
for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)):

0 commit comments

Comments
 (0)