Skip to content

Commit a3562cd

Browse files
Add tpu option to GCP backend to gate TPU provisioning
TPU offers are now excluded by default for GCP backends. To enable TPU provisioning, set `tpu: true` in the backend config. This follows the same pattern as RunPod's `community_cloud` option. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent f934bb4 commit a3562cd

4 files changed

Lines changed: 130 additions & 4 deletions

File tree

docs/docs/concepts/backends.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,29 @@ gcloud projects list --format="json(projectId)"
631631
Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets.
632632
Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances.
633633

634+
??? info "TPU"
635+
By default, `dstack` does not include TPU offers.
636+
To enable TPU provisioning, set `tpu` to `true` in the backend settings.
637+
638+
<div editor-title="~/.dstack/server/config.yml">
639+
640+
```yaml
641+
projects:
642+
- name: main
643+
backends:
644+
- type: gcp
645+
project_id: gcp-project-id
646+
creds:
647+
type: default
648+
649+
tpu: true
650+
```
651+
652+
</div>
653+
654+
Make sure the required TPU permissions and the `serviceAccountUser` role are granted
655+
(see "Required permissions" above).
656+
634657
### Lambda
635658

636659
Log into your [Lambda Cloud](https://lambdalabs.com/service/gpu-cloud) account, click API keys in the sidebar, and then click the `Generate API key`

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
135135
regions = get_or_error(self.config.regions)
136136
offers = get_catalog_offers(
137137
backend=BackendType.GCP,
138-
extra_filter=_supported_instances_and_zones(regions),
138+
extra_filter=_supported_instances_and_zones(regions, tpu=self.config.allow_tpu),
139139
)
140140
quotas: Dict[str, Dict[str, float]] = defaultdict(dict)
141141
for region in self.regions_client.list(project=self.config.project_id):
@@ -989,14 +989,17 @@ def _find_reservation(self, configured_name: str) -> dict[str, compute_v1.Reserv
989989

990990
def _supported_instances_and_zones(
991991
regions: List[str],
992+
tpu: bool = False,
992993
) -> Optional[Callable[[InstanceOffer], bool]]:
993994
def _filter(offer: InstanceOffer) -> bool:
994995
# strip zone
995996
if offer.region[:-2] not in regions:
996997
return False
997-
# remove multi-host TPUs for initial release
998-
if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
999-
return False
998+
if _is_tpu(offer.instance.name):
999+
if not tpu:
1000+
return False
1001+
if not _is_single_host_tpu(offer.instance.name):
1002+
return False
10001003
for family in [
10011004
"m4-",
10021005
"c4-",

src/dstack/_internal/core/backends/gcp/models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dstack._internal.core.backends.base.models import fill_data
66
from dstack._internal.core.models.common import CoreModel
77

8+
GCP_TPU_DEFAULT = False
9+
810

911
class GCPServiceAccountCreds(CoreModel):
1012
type: Annotated[Literal["service_account"], Field(description="The type of credentials")] = (
@@ -89,6 +91,15 @@ class GCPBackendConfig(CoreModel):
8991
description="The tags (labels) that will be assigned to resources created by `dstack`"
9092
),
9193
] = None
94+
tpu: Annotated[
95+
Optional[bool],
96+
Field(
97+
description=(
98+
"Whether TPU offers can be used for provisioning."
99+
f" Defaults to `{str(GCP_TPU_DEFAULT).lower()}`"
100+
)
101+
),
102+
] = None
92103
preview_features: Annotated[
93104
Optional[List[Literal["g4"]]],
94105
Field(
@@ -143,6 +154,12 @@ class GCPStoredConfig(GCPBackendConfig):
143154
class GCPConfig(GCPStoredConfig):
144155
creds: AnyGCPCreds
145156

157+
@property
158+
def allow_tpu(self) -> bool:
159+
if self.tpu is not None:
160+
return self.tpu
161+
return GCP_TPU_DEFAULT
162+
146163
@property
147164
def allocate_public_ips(self) -> bool:
148165
if self.public_ips is not None:
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from dstack._internal.core.backends.gcp.compute import _supported_instances_and_zones
2+
from dstack._internal.core.backends.gcp.models import GCPConfig, GCPDefaultCreds
3+
from dstack._internal.core.models.backends.base import BackendType
4+
from dstack._internal.core.models.instances import (
5+
Gpu,
6+
InstanceOffer,
7+
InstanceType,
8+
Resources,
9+
)
10+
11+
12+
def _make_offer(instance_name: str, region: str = "us-central1-a", gpus=None) -> InstanceOffer:
13+
if gpus is None:
14+
gpus = []
15+
return InstanceOffer(
16+
backend=BackendType.GCP,
17+
instance=InstanceType(
18+
name=instance_name,
19+
resources=Resources(
20+
cpus=8,
21+
memory_mib=32768,
22+
gpus=gpus,
23+
spot=False,
24+
),
25+
),
26+
region=region,
27+
price=1.0,
28+
)
29+
30+
31+
class TestSupportedInstancesAndZones:
32+
def test_filters_tpu_when_disabled(self):
33+
f = _supported_instances_and_zones(["us-central1"], tpu=False)
34+
offer = _make_offer(
35+
"v5litepod-8",
36+
region="us-central1-b",
37+
gpus=[Gpu(name="v5litepod", memory_mib=16384)],
38+
)
39+
assert f(offer) is False
40+
41+
def test_allows_single_host_tpu_when_enabled(self):
42+
f = _supported_instances_and_zones(["us-central1"], tpu=True)
43+
offer = _make_offer(
44+
"v5litepod-8",
45+
region="us-central1-b",
46+
gpus=[Gpu(name="v5litepod", memory_mib=16384)],
47+
)
48+
assert f(offer) is True
49+
50+
def test_filters_multi_host_tpu_when_enabled(self):
51+
f = _supported_instances_and_zones(["us-central1"], tpu=True)
52+
offer = _make_offer(
53+
"v5litepod-16",
54+
region="us-central1-b",
55+
gpus=[Gpu(name="v5litepod", memory_mib=16384)],
56+
)
57+
assert f(offer) is False
58+
59+
def test_allows_gpu_instances_regardless_of_tpu_flag(self):
60+
f = _supported_instances_and_zones(["us-central1"], tpu=False)
61+
offer = _make_offer(
62+
"a2-highgpu-1g",
63+
region="us-central1-b",
64+
gpus=[Gpu(name="A100", memory_mib=40960)],
65+
)
66+
assert f(offer) is True
67+
68+
69+
class TestGCPConfigAllowTpu:
70+
def _make_config(self, tpu=None) -> GCPConfig:
71+
return GCPConfig(
72+
project_id="test-project",
73+
creds=GCPDefaultCreds(),
74+
tpu=tpu,
75+
)
76+
77+
def test_default(self):
78+
config = self._make_config(tpu=None)
79+
assert config.allow_tpu is False
80+
81+
def test_explicit_true(self):
82+
config = self._make_config(tpu=True)
83+
assert config.allow_tpu is True

0 commit comments

Comments
 (0)