Skip to content

Commit a61cc26

Browse files
peterschmidt85 and Andrey Cheptsov authored
[Azure] Add support for H100 NVL and H200 VM series; refactor instance creation methods to cleanup failed instances (#3699)
Co-authored-by: Andrey Cheptsov <andrey.cheptsov@github.com>
1 parent 20b5296 commit a61cc26

File tree

1 file changed

+86
-8
lines changed
  • src/dstack/_internal/core/backends/azure

1 file changed

+86
-8
lines changed

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 86 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from typing import Dict, List, Optional, Tuple
77

88
from azure.core.credentials import TokenCredential
9-
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
9+
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
1010
from azure.mgmt import compute as compute_mgmt
1111
from azure.mgmt import network as network_mgmt
1212
from azure.mgmt.compute.models import (
@@ -168,7 +168,7 @@ def create_instance(
168168

169169
# TODO: Support custom availability_zones.
170170
# Currently, VMs are regional, which means they don't have zone info.
171-
vm = _launch_instance(
171+
vm = _create_instance_and_wait(
172172
compute_client=self._compute_client,
173173
subscription_id=self.config.subscription_id,
174174
location=location,
@@ -272,7 +272,7 @@ def create_gateway(
272272
)
273273
tags = azure_resources.filter_invalid_tags(tags)
274274

275-
vm = _launch_instance(
275+
vm = _create_instance_and_wait(
276276
compute_client=self._compute_client,
277277
subscription_id=self.config.subscription_id,
278278
location=configuration.region,
@@ -426,8 +426,10 @@ def get_image_name(self) -> str:
426426
r"ND(\d+)rs_v2", # NDv2-series [8xV100 32GB]
427427
r"NV(\d+)adm?s_A10_v5", # NVadsA10 v5-series [A10]
428428
r"NC(\d+)ads_A100_v4", # NC A100 v4-series [A100 80GB]
429+
r"NC(\d+)adi?s_H100_v5", # NC H100 v5-series [H100 NVL 94GB]
429430
r"ND(\d+)asr_v4", # ND A100 v4-series [8xA100 40GB]
430431
r"ND(\d+)amsr_A100_v4", # NDm A100 v4-series [8xA100 80GB]
432+
r"ND(\d+)isr_H200_v5", # ND H200 v5-series [8xH200 141GB]
431433
]
432434
_SUPPORTED_VM_SERIES_PATTERN = (
433435
"^Standard_(" + "|".join(f"({s})" for s in _SUPPORTED_VM_SERIES_PATTERNS) + ")$"
@@ -508,7 +510,7 @@ def _get_gateway_image_ref() -> ImageReference:
508510
)
509511

510512

511-
def _launch_instance(
513+
def _begin_create_instance(
512514
compute_client: compute_mgmt.ComputeManagementClient,
513515
subscription_id: str,
514516
location: str,
@@ -529,7 +531,8 @@ def _launch_instance(
529531
allocate_public_ip: bool = True,
530532
network_resource_group: Optional[str] = None,
531533
tags: Optional[Dict[str, str]] = None,
532-
) -> VirtualMachine:
534+
):
535+
"""Starts VM creation and returns immediately. The VM is created asynchronously."""
533536
if tags is None:
534537
tags = {}
535538
if network_resource_group is None:
@@ -628,18 +631,93 @@ def _launch_instance(
628631
message = e.error.message if e.error.message is not None else ""
629632
raise NoCapacityError(message)
630633
raise e
631-
vm = poller.result(timeout=600)
634+
return poller
635+
636+
637+
def _create_instance_and_wait(
638+
compute_client: compute_mgmt.ComputeManagementClient,
639+
subscription_id: str,
640+
location: str,
641+
resource_group: str,
642+
network_security_group: str,
643+
network: str,
644+
subnet: str,
645+
managed_identity_name: Optional[str],
646+
managed_identity_resource_group: Optional[str],
647+
image_reference: ImageReference,
648+
vm_size: str,
649+
instance_name: str,
650+
user_data: str,
651+
ssh_pub_keys: List[str],
652+
spot: bool,
653+
disk_size: int,
654+
computer_name: str,
655+
allocate_public_ip: bool = True,
656+
network_resource_group: Optional[str] = None,
657+
tags: Optional[Dict[str, str]] = None,
658+
) -> VirtualMachine:
659+
"""Blocking version used for gateway provisioning where IP is needed immediately."""
660+
poller = _begin_create_instance(
661+
compute_client=compute_client,
662+
subscription_id=subscription_id,
663+
location=location,
664+
resource_group=resource_group,
665+
network_security_group=network_security_group,
666+
network=network,
667+
subnet=subnet,
668+
managed_identity_name=managed_identity_name,
669+
managed_identity_resource_group=managed_identity_resource_group,
670+
image_reference=image_reference,
671+
vm_size=vm_size,
672+
instance_name=instance_name,
673+
user_data=user_data,
674+
ssh_pub_keys=ssh_pub_keys,
675+
spot=spot,
676+
disk_size=disk_size,
677+
computer_name=computer_name,
678+
allocate_public_ip=allocate_public_ip,
679+
network_resource_group=network_resource_group,
680+
tags=tags,
681+
)
682+
try:
683+
vm = poller.result(timeout=600)
684+
except HttpResponseError as e:
685+
# Azure may create a VM resource even when provisioning fails (e.g., AllocationFailed).
686+
# Clean it up to avoid orphan VMs.
687+
logger.warning(
688+
"Instance %s provisioning failed: %s. Cleaning up.",
689+
instance_name,
690+
repr(e),
691+
)
692+
_terminate_instance(
693+
compute_client=compute_client,
694+
resource_group=resource_group,
695+
instance_name=instance_name,
696+
)
697+
if e.error is not None and e.error.code in (
698+
"AllocationFailed",
699+
"OverconstrainedAllocationRequest",
700+
):
701+
raise NoCapacityError(e.error.message or str(e))
702+
raise
632703
if not poller.done():
633704
logger.error(
634-
"Timed out waiting for instance {instance_name} launch. "
635-
"The instance will be terminated."
705+
"Timed out waiting for instance %s launch. The instance will be terminated.",
706+
instance_name,
636707
)
637708
_terminate_instance(
638709
compute_client=compute_client,
639710
resource_group=resource_group,
640711
instance_name=instance_name,
641712
)
642713
raise ComputeError(f"Timed out waiting for instance {instance_name} launch")
714+
if (vm.provisioning_state or "").lower() == "failed":
715+
_terminate_instance(
716+
compute_client=compute_client,
717+
resource_group=resource_group,
718+
instance_name=instance_name,
719+
)
720+
raise NoCapacityError(f"VM {instance_name} provisioning failed")
643721
return vm
644722

645723

0 commit comments

Comments
 (0)