Skip to content

Commit f7ef485

Browse files
authored
Support GCP A4 clusters (#3142)
This commit implements provisioning GCP A4 clusters with high-performance RoCE networking.

```shell
> dstack fleet
 FLEET  INSTANCE  BACKEND         RESOURCES                                            PRICE    STATUS  CREATED
 gpu    0         gcp (us-west2)  cpu=224 mem=3968GB disk=100GB B200:180GB:8 (spot)    $51.552  idle    21 mins ago
        1         gcp (us-west2)  cpu=224 mem=3968GB disk=100GB B200:180GB:8 (spot)    $51.552  idle    17 mins ago
```

To enable high-performance networking, users need to create the [appropriate networks](https://cloud.google.com/ai-hypercomputer/docs/create/create-vm#setup-network) and configure them in the backend settings.

```yaml
projects:
- name: main
  backends:
  - type: gcp
    project_id: my-project
    creds:
      type: default
    vpc_name: my-vpc-0  # regular, 1 subnet
    extra_vpcs:
    - my-vpc-1  # regular, 1 subnet
    roce_vpcs:
    - my-vpc-mrdma  # RoCE profile, 8 subnets
```

Then apply a fleet configuration.

```yaml
type: fleet
nodes: 2
placement: cluster
availability_zones: [us-west2-c]
backends: [gcp]
resources:
  gpu: 8:b200
```

Each instance in the cluster will then have 10 network interfaces:

- 1 regular interface in the main VPC (`default` or the one configured in `vpc_name`).
- 1 regular interface in a VPC configured in `extra_vpcs`.
- 8 RDMA interfaces in the VPC configured in `roce_vpcs`.

Additionally, this commit optimizes the fetching and caching of subnets, so that they are fetched from the API only once, and not separately for each item in `extra_vpcs`. For some instance types, this reduces the number of API requests from 9 to 1, which cuts about 16 seconds from each offer provisioning attempt.
1 parent 85faee6 commit f7ef485

File tree

4 files changed

+118
-46
lines changed

4 files changed

+118
-46
lines changed

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def __init__(self, config: GCPConfig):
111111
self.resource_policies_client = compute_v1.ResourcePoliciesClient(
112112
credentials=self.credentials
113113
)
114-
self._extra_subnets_cache_lock = threading.Lock()
115-
self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60)
114+
self._usable_subnets_cache_lock = threading.Lock()
115+
self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)
116116

117117
def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
118118
regions = get_or_error(self.config.regions)
@@ -203,12 +203,12 @@ def create_instance(
203203
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
204204
# Choose any usable subnet in a VPC.
205205
# Configuring a specific subnet per region is not supported yet.
206-
subnetwork = _get_vpc_subnet(
207-
subnetworks_client=self.subnetworks_client,
208-
config=self.config,
206+
subnetwork = self._get_vpc_subnet(instance_offer.region)
207+
extra_subnets = self._get_extra_subnets(
209208
region=instance_offer.region,
209+
instance_type_name=instance_offer.instance.name,
210210
)
211-
extra_subnets = self._get_extra_subnets(
211+
roce_subnets = self._get_roce_subnets(
212212
region=instance_offer.region,
213213
instance_type_name=instance_offer.instance.name,
214214
)
@@ -330,6 +330,7 @@ def create_instance(
330330
network=self.config.vpc_resource_name,
331331
subnetwork=subnetwork,
332332
extra_subnetworks=extra_subnets,
333+
roce_subnetworks=roce_subnets,
333334
allocate_public_ip=allocate_public_ip,
334335
placement_policy=placement_policy,
335336
)
@@ -339,6 +340,13 @@ def create_instance(
339340
# If the request succeeds, we'll probably timeout and update_provisioning_data() will get hostname.
340341
operation = self.instances_client.insert(request=request)
341342
gcp_resources.wait_for_extended_operation(operation, timeout=30)
343+
except google.api_core.exceptions.BadRequest as e:
344+
if "Network profile only allows resource creation in location" in e.message:
345+
# A hack to find the correct RoCE VPC zone by trial and error.
346+
# Could be better to find it via the API.
347+
logger.debug("Got GCP error when provisioning a VM: %s", e)
348+
continue
349+
raise
342350
except (
343351
google.api_core.exceptions.ServiceUnavailable,
344352
google.api_core.exceptions.NotFound,
@@ -487,11 +495,7 @@ def create_gateway(
487495
)
488496
# Choose any usable subnet in a VPC.
489497
# Configuring a specific subnet per region is not supported yet.
490-
subnetwork = _get_vpc_subnet(
491-
subnetworks_client=self.subnetworks_client,
492-
config=self.config,
493-
region=configuration.region,
494-
)
498+
subnetwork = self._get_vpc_subnet(configuration.region)
495499

496500
labels = {
497501
"owner": "dstack",
@@ -793,10 +797,6 @@ def detach_volume(
793797
instance_id,
794798
)
795799

796-
@cachedmethod(
797-
cache=lambda self: self._extra_subnets_cache,
798-
lock=lambda self: self._extra_subnets_cache_lock,
799-
)
800800
def _get_extra_subnets(
801801
self,
802802
region: str,
@@ -808,15 +808,16 @@ def _get_extra_subnets(
808808
subnets_num = 8
809809
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
810810
subnets_num = 4
811+
elif instance_type_name == "a4-highgpu-8g":
812+
subnets_num = 1 # 1 main + 1 extra + 8 RoCE
811813
else:
812814
return []
813815
extra_subnets = []
814816
for vpc_name in self.config.extra_vpcs[:subnets_num]:
815817
subnet = gcp_resources.get_vpc_subnet_or_error(
816-
subnetworks_client=self.subnetworks_client,
817-
vpc_project_id=self.config.vpc_project_id or self.config.project_id,
818818
vpc_name=vpc_name,
819819
region=region,
820+
usable_subnets=self._list_usable_subnets(),
820821
)
821822
vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
822823
project_id=self.config.vpc_project_id or self.config.project_id,
@@ -825,6 +826,58 @@ def _get_extra_subnets(
825826
extra_subnets.append((vpc_resource_name, subnet))
826827
return extra_subnets
827828

829+
def _get_roce_subnets(
830+
self,
831+
region: str,
832+
instance_type_name: str,
833+
) -> List[Tuple[str, str]]:
834+
if not self.config.roce_vpcs:
835+
return []
836+
if instance_type_name == "a4-highgpu-8g":
837+
nics_num = 8
838+
else:
839+
return []
840+
roce_vpc = self.config.roce_vpcs[0] # roce_vpcs is validated to have at most 1 item
841+
subnets = gcp_resources.get_vpc_subnets(
842+
vpc_name=roce_vpc,
843+
region=region,
844+
usable_subnets=self._list_usable_subnets(),
845+
)
846+
if len(subnets) < nics_num:
847+
raise ComputeError(
848+
f"{instance_type_name} requires {nics_num} RoCE subnets,"
849+
f" but only {len(subnets)} are available in VPC {roce_vpc}"
850+
)
851+
vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
852+
project_id=self.config.vpc_project_id or self.config.project_id,
853+
vpc_name=roce_vpc,
854+
)
855+
nic_subnets = []
856+
for subnet in subnets[:nics_num]:
857+
nic_subnets.append((vpc_resource_name, subnet))
858+
return nic_subnets
859+
860+
@cachedmethod(
861+
cache=lambda self: self._usable_subnets_cache,
862+
lock=lambda self: self._usable_subnets_cache_lock,
863+
)
864+
def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
865+
# To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
866+
# at once and cache them
867+
return gcp_resources.list_project_usable_subnets(
868+
subnetworks_client=self.subnetworks_client,
869+
project_id=self.config.vpc_project_id or self.config.project_id,
870+
)
871+
872+
def _get_vpc_subnet(self, region: str) -> Optional[str]:
873+
if self.config.vpc_name is None:
874+
return None
875+
return gcp_resources.get_vpc_subnet_or_error(
876+
vpc_name=self.config.vpc_name,
877+
region=region,
878+
usable_subnets=self._list_usable_subnets(),
879+
)
880+
828881

829882
def _supported_instances_and_zones(
830883
regions: List[str],
@@ -889,21 +942,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
889942
return f"{name}-{gpu.name}-{gpu.memory_mib}"
890943

891944

892-
def _get_vpc_subnet(
893-
subnetworks_client: compute_v1.SubnetworksClient,
894-
config: GCPConfig,
895-
region: str,
896-
) -> Optional[str]:
897-
if config.vpc_name is None:
898-
return None
899-
return gcp_resources.get_vpc_subnet_or_error(
900-
subnetworks_client=subnetworks_client,
901-
vpc_project_id=config.vpc_project_id or config.project_id,
902-
vpc_name=config.vpc_name,
903-
region=region,
904-
)
905-
906-
907945
@dataclass
908946
class GCPImage:
909947
id: str

src/dstack/_internal/core/backends/gcp/configurator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,5 @@ def _check_config_vpc(
202202
)
203203
except BackendError as e:
204204
raise ServerClientError(e.args[0])
205-
# Not checking config.extra_vpc so that users are not required to configure subnets for all regions
205+
# Not checking config.extra_vpcs and config.roce_vpcs so that users are not required to configure subnets for all regions
206206
# but only for regions they intend to use. Validation will be done on provisioning.

src/dstack/_internal/core/backends/gcp/models.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,24 @@ class GCPBackendConfig(CoreModel):
4141
Optional[List[str]],
4242
Field(
4343
description=(
44-
"The names of additional VPCs used for GPUDirect. Specify eight VPCs to maximize bandwidth."
44+
"The names of additional VPCs used for multi-NIC instances, such as those that support GPUDirect."
45+
" Specify eight VPCs to maximize bandwidth in clusters with eight-GPU instances."
4546
" Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets"
4647
)
4748
),
4849
] = None
50+
roce_vpcs: Annotated[
51+
Optional[List[str]],
52+
Field(
53+
description=(
54+
"The names of additional VPCs with the RoCE network profile."
55+
" Used for RDMA on GPU instances that support the MRDMA interface type."
56+
" A VPC should have eight subnets to maximize the bandwidth in clusters"
57+
" with eight-GPU instances."
58+
),
59+
max_items=1, # The currently supported instance types only need one VPC with eight subnets.
60+
),
61+
] = None
4962
vpc_project_id: Annotated[
5063
Optional[str],
5164
Field(description="The shared VPC hosted project ID. Required for shared VPC only"),

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def check_vpc(
5959
)
6060
for region in regions:
6161
get_vpc_subnet_or_error(
62-
subnetworks_client=subnetworks_client,
63-
vpc_project_id=vpc_project_id,
6462
vpc_name=vpc_name,
6563
region=region,
6664
usable_subnets=usable_subnets,
@@ -122,6 +120,7 @@ def create_instance_struct(
122120
network: str = "global/networks/default",
123121
subnetwork: Optional[str] = None,
124122
extra_subnetworks: Optional[List[Tuple[str, str]]] = None,
123+
roce_subnetworks: Optional[List[Tuple[str, str]]] = None,
125124
allocate_public_ip: bool = True,
126125
placement_policy: Optional[str] = None,
127126
) -> compute_v1.Instance:
@@ -133,6 +132,7 @@ def create_instance_struct(
133132
subnetwork=subnetwork,
134133
allocate_public_ip=allocate_public_ip,
135134
extra_subnetworks=extra_subnetworks,
135+
roce_subnetworks=roce_subnetworks,
136136
)
137137

138138
disk = compute_v1.AttachedDisk()
@@ -195,6 +195,7 @@ def _get_network_interfaces(
195195
subnetwork: Optional[str],
196196
allocate_public_ip: bool,
197197
extra_subnetworks: Optional[List[Tuple[str, str]]],
198+
roce_subnetworks: Optional[List[Tuple[str, str]]],
198199
) -> List[compute_v1.NetworkInterface]:
199200
network_interface = compute_v1.NetworkInterface()
200201
network_interface.network = network
@@ -222,6 +223,14 @@ def _get_network_interfaces(
222223
nic_type=compute_v1.NetworkInterface.NicType.GVNIC.name,
223224
)
224225
)
226+
for network, subnetwork in roce_subnetworks or []:
227+
network_interfaces.append(
228+
compute_v1.NetworkInterface(
229+
network=network,
230+
subnetwork=subnetwork,
231+
nic_type=compute_v1.NetworkInterface.NicType.MRDMA.name,
232+
)
233+
)
225234
return network_interfaces
226235

227236

@@ -234,29 +243,41 @@ def list_project_usable_subnets(
234243

235244

236245
def get_vpc_subnet_or_error(
237-
subnetworks_client: compute_v1.SubnetworksClient,
238-
vpc_project_id: str,
239246
vpc_name: str,
240247
region: str,
241-
usable_subnets: Optional[List[compute_v1.UsableSubnetwork]] = None,
248+
usable_subnets: list[compute_v1.UsableSubnetwork],
242249
) -> str:
243250
"""
244251
Returns resource name of any usable subnet in a given VPC
245252
(e.g. "projects/example-project/regions/europe-west4/subnetworks/example-subnet")
246253
"""
247-
if usable_subnets is None:
248-
usable_subnets = list_project_usable_subnets(subnetworks_client, vpc_project_id)
254+
vpc_subnets = get_vpc_subnets(vpc_name, region, usable_subnets)
255+
if vpc_subnets:
256+
return vpc_subnets[0]
257+
raise ComputeError(
258+
f"No usable subnetwork found in region {region} for VPC {vpc_name}."
259+
f" Ensure that VPC {vpc_name} exists and has usable subnetworks."
260+
)
261+
262+
263+
def get_vpc_subnets(
264+
vpc_name: str,
265+
region: str,
266+
usable_subnets: list[compute_v1.UsableSubnetwork],
267+
) -> list[str]:
268+
"""
269+
Returns resource names of all usable subnets in a given VPC
270+
(e.g. ["projects/example-project/regions/europe-west4/subnetworks/example-subnet"])
271+
"""
272+
result = []
249273
for subnet in usable_subnets:
250274
network_name = subnet.network.split("/")[-1]
251275
subnet_url = subnet.subnetwork
252276
subnet_resource_name = remove_prefix(subnet_url, "https://www.googleapis.com/compute/v1/")
253277
subnet_region = subnet_resource_name.split("/")[3]
254278
if network_name == vpc_name and subnet_region == region:
255-
return subnet_resource_name
256-
raise ComputeError(
257-
f"No usable subnetwork found in region {region} for VPC {vpc_name} in project {vpc_project_id}."
258-
f" Ensure that VPC {vpc_name} exists and has usable subnetworks."
259-
)
279+
result.append(subnet_resource_name)
280+
return result
260281

261282

262283
def create_runner_firewall_rules(

0 commit comments

Comments
 (0)