Skip to content

Commit 41f5ec3

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Add DigitalOcean base class with DigitalOceanCloud and AMDDevCloud sub classes
1 parent 8fc6816 commit 41f5ec3

19 files changed

Lines changed: 376 additions & 314 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This package contains the implementation for the AMDDevCloud backend.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from dstack._internal.core.backends.amddevcloud.compute import AMDDevCloudCompute
2+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
3+
from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
4+
from dstack._internal.core.models.backends.base import BackendType
5+
6+
7+
class AMDDevCloudBackend(BaseDigitalOceanBackend):
8+
TYPE = BackendType.AMDDEVCLOUD
9+
COMPUTE_CLASS = AMDDevCloudCompute
10+
11+
def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
12+
self.config = config
13+
self._compute = AMDDevCloudCompute(self.config, api_url=api_url, type=self.TYPE)
14+
15+
def compute(self) -> AMDDevCloudCompute:
16+
return self._compute
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from dstack._internal.core.backends.digitalocean_base.compute import BaseDigitalOceanCompute
2+
3+
4+
class AMDDevCloudCompute(BaseDigitalOceanCompute):
5+
pass
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from dstack._internal.core.backends.amddevcloud.backend import AMDDevCloudBackend
2+
from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
3+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
4+
from dstack._internal.core.backends.digitalocean_base.configurator import (
5+
BaseDigitalOceanConfigurator,
6+
)
7+
from dstack._internal.core.backends.digitalocean_base.models import AnyBaseDigitalOceanCreds
8+
from dstack._internal.core.models.backends.base import (
9+
BackendType,
10+
)
11+
12+
13+
class AMDDevCloudConfigurator(BaseDigitalOceanConfigurator):
14+
TYPE = BackendType.AMDDEVCLOUD
15+
BACKEND_CLASS = AMDDevCloudBackend
16+
API_URL = "https://api-amd.digitalocean.com"
17+
18+
def get_backend(self, record) -> BaseDigitalOceanBackend:
19+
config = self._get_config(record)
20+
return AMDDevCloudBackend(config=config, api_url=self.API_URL)
21+
22+
def _validate_creds(self, creds: AnyBaseDigitalOceanCreds):
23+
api_client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
24+
api_client.validate_api_key()

src/dstack/_internal/core/backends/base/offers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def get_catalog_offers(
3434
provider = backend.value
3535
if backend == BackendType.LAMBDA:
3636
provider = "lambdalabs"
37+
if backend == BackendType.AMDDEVCLOUD:
38+
provider = "digitalocean"
3739
q = requirements_to_query_filter(requirements)
3840
q.provider = [provider]
3941
offers = []

src/dstack/_internal/core/backends/configurators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55

66
_CONFIGURATOR_CLASSES: List[Type[Configurator]] = []
77

8+
try:
9+
from dstack._internal.core.backends.amddevcloud.configurator import AMDDevCloudConfigurator
10+
11+
_CONFIGURATOR_CLASSES.append(AMDDevCloudConfigurator)
12+
except ImportError:
13+
pass
814

915
try:
1016
from dstack._internal.core.backends.aws.configurator import AWSConfigurator
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
from dstack._internal.core.backends.base.backend import Backend
21
from dstack._internal.core.backends.digitalocean.compute import DigitalOceanCompute
3-
from dstack._internal.core.backends.digitalocean.models import DigitalOceanConfig
2+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
3+
from dstack._internal.core.backends.digitalocean_base.models import BaseDigitalOceanConfig
44
from dstack._internal.core.models.backends.base import BackendType
55

66

7-
class DigitalOceanBackend(Backend):
7+
class DigitalOceanBackend(BaseDigitalOceanBackend):
88
TYPE = BackendType.DIGITALOCEAN
99
COMPUTE_CLASS = DigitalOceanCompute
1010

11-
def __init__(self, config: DigitalOceanConfig):
11+
def __init__(self, config: BaseDigitalOceanConfig, api_url: str):
1212
self.config = config
13-
self._compute = DigitalOceanCompute(self.config)
13+
self._compute = DigitalOceanCompute(self.config, api_url=api_url, type=self.TYPE)
1414

1515
def compute(self) -> DigitalOceanCompute:
1616
return self._compute
Lines changed: 3 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -1,178 +1,5 @@
1-
from typing import List, Optional
1+
from ..digitalocean_base.compute import BaseDigitalOceanCompute
22

3-
import gpuhunt
4-
from gpuhunt.providers.digitalocean import DigitalOceanProvider
53

6-
from dstack._internal.core.backends.base.backend import Compute
7-
from dstack._internal.core.backends.base.compute import (
8-
ComputeWithCreateInstanceSupport,
9-
generate_unique_instance_name,
10-
get_user_data,
11-
)
12-
from dstack._internal.core.backends.base.offers import get_catalog_offers
13-
from dstack._internal.core.backends.digitalocean.api_client import DigitalOceanAPIClient
14-
from dstack._internal.core.backends.digitalocean.models import DigitalOceanConfig
15-
from dstack._internal.core.models.backends.base import BackendType
16-
from dstack._internal.core.models.instances import (
17-
InstanceAvailability,
18-
InstanceConfiguration,
19-
InstanceOfferWithAvailability,
20-
)
21-
from dstack._internal.core.models.placement import PlacementGroup
22-
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
23-
from dstack._internal.utils.logging import get_logger
24-
25-
logger = get_logger(__name__)
26-
27-
MAX_INSTANCE_NAME_LEN = 60
28-
29-
# Setup commands for DigitalOcean instances
30-
SETUP_COMMANDS = [
31-
"sudo ufw delete limit ssh",
32-
"sudo ufw allow ssh",
33-
]
34-
35-
DOCKER_INSTALL_COMMANDS = [
36-
"export DEBIAN_FRONTEND=noninteractive",
37-
"mkdir -p /etc/apt/keyrings",
38-
"curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg",
39-
'echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null',
40-
"apt-get update",
41-
"apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin",
42-
]
43-
44-
45-
class DigitalOceanCompute(
46-
ComputeWithCreateInstanceSupport,
47-
Compute,
48-
):
49-
def __init__(self, config: DigitalOceanConfig):
50-
super().__init__()
51-
self.config = config
52-
self.api_client = DigitalOceanAPIClient(config.creds.api_key, config.flavor or "standard")
53-
self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
54-
self.catalog.add_provider(
55-
DigitalOceanProvider(token=config.creds.api_key, flavor=config.flavor or "standard")
56-
)
57-
# self.catalog.add_provider(
58-
# DigitalOceanProvider(token=config.creds.api_key, flavor="standard")
59-
# )
60-
61-
def get_offers(
62-
self, requirements: Optional[Requirements] = None
63-
) -> List[InstanceOfferWithAvailability]:
64-
offers = get_catalog_offers(
65-
backend=BackendType.DIGITALOCEAN,
66-
locations=self.config.regions,
67-
requirements=requirements,
68-
catalog=self.catalog,
69-
)
70-
return [
71-
InstanceOfferWithAvailability(
72-
**offer.dict(),
73-
availability=InstanceAvailability.AVAILABLE,
74-
)
75-
for offer in offers
76-
]
77-
78-
def create_instance(
79-
self,
80-
instance_offer: InstanceOfferWithAvailability,
81-
instance_config: InstanceConfiguration,
82-
placement_group: Optional[PlacementGroup],
83-
) -> JobProvisioningData:
84-
instance_name = generate_unique_instance_name(
85-
instance_config, max_length=MAX_INSTANCE_NAME_LEN
86-
)
87-
88-
project_ssh_key = instance_config.ssh_keys[0]
89-
ssh_key_id = self.api_client.get_or_create_ssh_key(
90-
name=f"dstack-{instance_config.project_name}",
91-
public_key=project_ssh_key.public,
92-
)
93-
94-
# Use the instance name directly from the offer (gpuhunt handles flavor-specific naming)
95-
size_slug = instance_offer.instance.name
96-
97-
if not instance_offer.instance.resources.gpus:
98-
backend_specific_commands = SETUP_COMMANDS + DOCKER_INSTALL_COMMANDS
99-
else:
100-
backend_specific_commands = SETUP_COMMANDS
101-
102-
# Prepare droplet configuration
103-
droplet_config = {
104-
"name": instance_name,
105-
"region": instance_offer.region,
106-
"size": size_slug,
107-
"image": self._get_image_for_instance(instance_offer),
108-
"ssh_keys": [ssh_key_id],
109-
"backups": False,
110-
"ipv6": False,
111-
"monitoring": False,
112-
"tags": [],
113-
"user_data": get_user_data(
114-
authorized_keys=instance_config.get_public_keys(),
115-
backend_specific_commands=backend_specific_commands,
116-
),
117-
}
118-
119-
droplet = self.api_client.create_droplet(droplet_config)
120-
121-
return JobProvisioningData(
122-
backend=instance_offer.backend,
123-
instance_type=instance_offer.instance,
124-
instance_id=str(droplet["id"]),
125-
hostname=None, # Will be set when droplet is active
126-
internal_ip=None,
127-
region=instance_offer.region,
128-
price=instance_offer.price,
129-
username="root",
130-
ssh_port=22,
131-
dockerized=True,
132-
ssh_proxy=None,
133-
backend_data=None,
134-
)
135-
136-
def update_provisioning_data(
137-
self,
138-
provisioning_data: JobProvisioningData,
139-
project_ssh_public_key: str,
140-
project_ssh_private_key: str,
141-
):
142-
droplet = self.api_client.get_droplet(provisioning_data.instance_id)
143-
if droplet["status"] == "active":
144-
for network in droplet["networks"]["v4"]:
145-
if network["type"] == "public":
146-
provisioning_data.hostname = network["ip_address"]
147-
break
148-
149-
def terminate_instance(
150-
self, instance_id: str, region: str, backend_data: Optional[str] = None
151-
):
152-
self.api_client.delete_droplet(instance_id)
153-
154-
def _get_image_for_instance(self, instance_offer: InstanceOfferWithAvailability) -> str:
155-
if not instance_offer.instance.resources.gpus:
156-
# No GPUs, use CPU image
157-
return "ubuntu-24-04-x64"
158-
159-
gpu_count = len(instance_offer.instance.resources.gpus)
160-
gpu_name = instance_offer.instance.resources.gpus[0].name
161-
162-
if gpu_name == "MI300X":
163-
# AMD GPU
164-
return "digitaloceanai-rocmjupyter"
165-
else:
166-
# NVIDIA GPUs - DO only supports 1 and 8 GPU configurations.
167-
# DO says for single GPU plans using GPUs other than H100s use "gpu-h100x1-base". But for x8 assuming same.
168-
# See (https://docs.digitalocean.com/products/droplets/getting-started/recommended-gpu-setup/#aiml-ready-image)
169-
if gpu_count == 8:
170-
return "gpu-h100x8-base"
171-
elif gpu_count == 1:
172-
return "gpu-h100x1-base"
173-
else:
174-
# For Unsupported GPU count - use single GPU image and log warning
175-
logger.warning(
176-
f"Unsupported NVIDIA GPU count: {gpu_count}, using single GPU image"
177-
)
178-
return "gpu-h100x1-base"
4+
class DigitalOceanCompute(BaseDigitalOceanCompute):
5+
pass
Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,27 @@
1-
import json
2-
3-
from dstack._internal.core.backends.base.configurator import (
4-
BackendRecord,
5-
Configurator,
6-
)
7-
from dstack._internal.core.backends.digitalocean.api_client import DigitalOceanAPIClient
1+
from dstack._internal.core.backends.base.configurator import BackendRecord
82
from dstack._internal.core.backends.digitalocean.backend import DigitalOceanBackend
9-
from dstack._internal.core.backends.digitalocean.models import (
10-
AnyDigitalOceanBackendConfig,
11-
AnyDigitalOceanCreds,
12-
DigitalOceanBackendConfig,
13-
DigitalOceanBackendConfigWithCreds,
14-
DigitalOceanConfig,
15-
DigitalOceanCreds,
16-
DigitalOceanStoredConfig,
3+
from dstack._internal.core.backends.digitalocean_base.api_client import DigitalOceanAPIClient
4+
from dstack._internal.core.backends.digitalocean_base.backend import BaseDigitalOceanBackend
5+
from dstack._internal.core.backends.digitalocean_base.configurator import (
6+
BaseDigitalOceanConfigurator,
7+
)
8+
from dstack._internal.core.backends.digitalocean_base.models import (
9+
AnyBaseDigitalOceanCreds,
1710
)
1811
from dstack._internal.core.models.backends.base import (
1912
BackendType,
2013
)
2114

2215

23-
class DigitalOceanConfigurator(Configurator):
16+
class DigitalOceanConfigurator(BaseDigitalOceanConfigurator):
2417
TYPE = BackendType.DIGITALOCEAN
2518
BACKEND_CLASS = DigitalOceanBackend
19+
API_URL = "https://api.digitalocean.com"
2620

27-
def validate_config(
28-
self, config: DigitalOceanBackendConfigWithCreds, default_creds_enabled: bool
29-
):
30-
self._validate_creds(config.creds, config.flavor or "standard")
31-
32-
def create_backend(
33-
self, project_name: str, config: DigitalOceanBackendConfigWithCreds
34-
) -> BackendRecord:
35-
return BackendRecord(
36-
config=DigitalOceanStoredConfig(
37-
**DigitalOceanBackendConfig.__response__.parse_obj(config).dict()
38-
).json(),
39-
auth=DigitalOceanCreds.parse_obj(config.creds).json(),
40-
)
41-
42-
def get_backend_config(
43-
self, record: BackendRecord, include_creds: bool
44-
) -> AnyDigitalOceanBackendConfig:
45-
config = self._get_config(record)
46-
if include_creds:
47-
return DigitalOceanBackendConfigWithCreds.__response__.parse_obj(config)
48-
return DigitalOceanBackendConfig.__response__.parse_obj(config)
49-
50-
def get_backend(self, record: BackendRecord) -> DigitalOceanBackend:
21+
def get_backend(self, record: BackendRecord) -> BaseDigitalOceanBackend:
5122
config = self._get_config(record)
52-
return DigitalOceanBackend(config=config)
53-
54-
def _get_config(self, record: BackendRecord) -> DigitalOceanConfig:
55-
return DigitalOceanConfig.__response__(
56-
**json.loads(record.config),
57-
creds=DigitalOceanCreds.parse_raw(record.auth),
58-
)
23+
return DigitalOceanBackend(config=config, api_url=self.API_URL)
5924

60-
def _validate_creds(self, creds: AnyDigitalOceanCreds, flavor: str):
61-
api_client = DigitalOceanAPIClient(creds.api_key, flavor)
25+
def _validate_creds(self, creds: AnyBaseDigitalOceanCreds):
26+
api_client = DigitalOceanAPIClient(creds.api_key, self.API_URL)
6227
api_client.validate_api_key()

0 commit comments

Comments
 (0)