|
| 1 | +from dstack._internal.core.backends.gcp.compute import _supported_instances_and_zones |
| 2 | +from dstack._internal.core.backends.gcp.models import GCPConfig, GCPDefaultCreds |
| 3 | +from dstack._internal.core.models.backends.base import BackendType |
| 4 | +from dstack._internal.core.models.instances import ( |
| 5 | + Gpu, |
| 6 | + InstanceOffer, |
| 7 | + InstanceType, |
| 8 | + Resources, |
| 9 | +) |
| 10 | + |
| 11 | + |
| 12 | +def _make_offer(instance_name: str, region: str = "us-central1-a", gpus=None) -> InstanceOffer: |
| 13 | + if gpus is None: |
| 14 | + gpus = [] |
| 15 | + return InstanceOffer( |
| 16 | + backend=BackendType.GCP, |
| 17 | + instance=InstanceType( |
| 18 | + name=instance_name, |
| 19 | + resources=Resources( |
| 20 | + cpus=8, |
| 21 | + memory_mib=32768, |
| 22 | + gpus=gpus, |
| 23 | + spot=False, |
| 24 | + ), |
| 25 | + ), |
| 26 | + region=region, |
| 27 | + price=1.0, |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +class TestSupportedInstancesAndZones: |
| 32 | + def test_filters_tpu_when_disabled(self): |
| 33 | + f = _supported_instances_and_zones(["us-central1"], tpu=False) |
| 34 | + offer = _make_offer( |
| 35 | + "v5litepod-8", |
| 36 | + region="us-central1-b", |
| 37 | + gpus=[Gpu(name="v5litepod", memory_mib=16384)], |
| 38 | + ) |
| 39 | + assert f(offer) is False |
| 40 | + |
| 41 | + def test_allows_single_host_tpu_when_enabled(self): |
| 42 | + f = _supported_instances_and_zones(["us-central1"], tpu=True) |
| 43 | + offer = _make_offer( |
| 44 | + "v5litepod-8", |
| 45 | + region="us-central1-b", |
| 46 | + gpus=[Gpu(name="v5litepod", memory_mib=16384)], |
| 47 | + ) |
| 48 | + assert f(offer) is True |
| 49 | + |
| 50 | + def test_filters_multi_host_tpu_when_enabled(self): |
| 51 | + f = _supported_instances_and_zones(["us-central1"], tpu=True) |
| 52 | + offer = _make_offer( |
| 53 | + "v5litepod-16", |
| 54 | + region="us-central1-b", |
| 55 | + gpus=[Gpu(name="v5litepod", memory_mib=16384)], |
| 56 | + ) |
| 57 | + assert f(offer) is False |
| 58 | + |
| 59 | + def test_allows_gpu_instances_regardless_of_tpu_flag(self): |
| 60 | + f = _supported_instances_and_zones(["us-central1"], tpu=False) |
| 61 | + offer = _make_offer( |
| 62 | + "a2-highgpu-1g", |
| 63 | + region="us-central1-b", |
| 64 | + gpus=[Gpu(name="A100", memory_mib=40960)], |
| 65 | + ) |
| 66 | + assert f(offer) is True |
| 67 | + |
| 68 | + |
| 69 | +class TestGCPConfigAllowTpu: |
| 70 | + def _make_config(self, tpu=None) -> GCPConfig: |
| 71 | + return GCPConfig( |
| 72 | + project_id="test-project", |
| 73 | + creds=GCPDefaultCreds(), |
| 74 | + tpu=tpu, |
| 75 | + ) |
| 76 | + |
| 77 | + def test_default(self): |
| 78 | + config = self._make_config(tpu=None) |
| 79 | + assert config.allow_tpu is False |
| 80 | + |
| 81 | + def test_explicit_true(self): |
| 82 | + config = self._make_config(tpu=True) |
| 83 | + assert config.allow_tpu is True |
0 commit comments