Skip to content

Commit 8fc6816

Browse files
Bihan  RanaBihan  Rana
authored andcommitted
Add DigitalOcean Backend
1 parent 38e66bc commit 8fc6816

10 files changed

Lines changed: 428 additions & 1 deletion

File tree

src/dstack/_internal/core/backends/configurators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@
4747
except ImportError:
4848
pass
4949

50+
try:
51+
from dstack._internal.core.backends.digitalocean.configurator import (
52+
DigitalOceanConfigurator,
53+
)
54+
55+
_CONFIGURATOR_CLASSES.append(DigitalOceanConfigurator)
56+
except ImportError:
57+
pass
58+
5059
try:
5160
from dstack._internal.core.backends.gcp.configurator import GCPConfigurator
5261

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# DigitalOcean backend for dstack
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from typing import Any, Dict, List, Optional
2+
3+
import requests
4+
5+
from dstack._internal.core.backends.base.configurator import raise_invalid_credentials_error
6+
from dstack._internal.core.errors import BackendRateLimitExceededError
7+
from dstack._internal.utils.logging import get_logger
8+
9+
logger = get_logger(__name__)
10+
11+
# DigitalOcean API endpoints
12+
STANDARD_CLOUD_API_URL = "https://api.digitalocean.com/v2"
13+
AMD_CLOUD_API_URL = "https://api-amd.digitalocean.com/v2"
14+
15+
16+
class DigitalOceanAPIClient:
17+
def __init__(self, api_key: str, flavor: str = "standard"):
18+
self.api_key = api_key
19+
self.flavor = flavor
20+
self.base_url = self._get_base_url()
21+
22+
def _get_base_url(self) -> str:
23+
if self.flavor == "amd":
24+
return AMD_CLOUD_API_URL
25+
return STANDARD_CLOUD_API_URL
26+
27+
def validate_api_key(self) -> bool:
28+
try:
29+
response = self._make_request("GET", "/account")
30+
response.raise_for_status()
31+
return True
32+
except requests.HTTPError as e:
33+
status = e.response.status_code
34+
if status == 401:
35+
raise_invalid_credentials_error(
36+
fields=[["creds", "api_key"]], details="Invaild API key"
37+
)
38+
raise e
39+
40+
def list_ssh_keys(self) -> List[Dict[str, Any]]:
41+
response = self._make_request("GET", "/account/keys")
42+
response.raise_for_status()
43+
return response.json()["ssh_keys"]
44+
45+
def create_ssh_key(self, name: str, public_key: str) -> Dict[str, Any]:
46+
payload = {"name": name, "public_key": public_key}
47+
response = self._make_request("POST", "/account/keys", json=payload)
48+
response.raise_for_status()
49+
return response.json()["ssh_key"]
50+
51+
def get_or_create_ssh_key(self, name: str, public_key: str) -> int:
52+
ssh_keys = self.list_ssh_keys()
53+
for ssh_key in ssh_keys:
54+
if ssh_key["public_key"].strip() == public_key.strip():
55+
return ssh_key["id"]
56+
57+
ssh_key = self.create_ssh_key(name, public_key)
58+
return ssh_key["id"]
59+
60+
def create_droplet(self, droplet_config: Dict[str, Any]) -> Dict[str, Any]:
61+
response = self._make_request("POST", "/droplets", json=droplet_config)
62+
response.raise_for_status()
63+
return response.json()["droplet"]
64+
65+
def get_droplet(self, droplet_id: str) -> Dict[str, Any]:
66+
response = self._make_request("GET", f"/droplets/{droplet_id}")
67+
response.raise_for_status()
68+
return response.json()["droplet"]
69+
70+
def delete_droplet(self, droplet_id: str) -> None:
71+
response = self._make_request("DELETE", f"/droplets/{droplet_id}")
72+
if response.status_code == 404:
73+
logger.debug("DigitalOcean droplet %s not found", droplet_id)
74+
return
75+
response.raise_for_status()
76+
77+
def _make_request(
78+
self, method: str, endpoint: str, json: Optional[Dict[str, Any]] = None, timeout: int = 30
79+
) -> requests.Response:
80+
url = f"{self.base_url}{endpoint}"
81+
headers = {
82+
"Content-Type": "application/json",
83+
"Authorization": f"Bearer {self.api_key}",
84+
}
85+
86+
response = requests.request(
87+
method=method,
88+
url=url,
89+
headers=headers,
90+
json=json,
91+
timeout=timeout,
92+
)
93+
94+
if response.status_code == 429:
95+
raise BackendRateLimitExceededError("API rate limit exceeded.")
96+
97+
return response
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from dstack._internal.core.backends.base.backend import Backend
2+
from dstack._internal.core.backends.digitalocean.compute import DigitalOceanCompute
3+
from dstack._internal.core.backends.digitalocean.models import DigitalOceanConfig
4+
from dstack._internal.core.models.backends.base import BackendType
5+
6+
7+
class DigitalOceanBackend(Backend):
8+
TYPE = BackendType.DIGITALOCEAN
9+
COMPUTE_CLASS = DigitalOceanCompute
10+
11+
def __init__(self, config: DigitalOceanConfig):
12+
self.config = config
13+
self._compute = DigitalOceanCompute(self.config)
14+
15+
def compute(self) -> DigitalOceanCompute:
16+
return self._compute
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing import List, Optional
2+
3+
import gpuhunt
4+
from gpuhunt.providers.digitalocean import DigitalOceanProvider
5+
6+
from dstack._internal.core.backends.base.backend import Compute
7+
from dstack._internal.core.backends.base.compute import (
8+
ComputeWithCreateInstanceSupport,
9+
generate_unique_instance_name,
10+
get_user_data,
11+
)
12+
from dstack._internal.core.backends.base.offers import get_catalog_offers
13+
from dstack._internal.core.backends.digitalocean.api_client import DigitalOceanAPIClient
14+
from dstack._internal.core.backends.digitalocean.models import DigitalOceanConfig
15+
from dstack._internal.core.models.backends.base import BackendType
16+
from dstack._internal.core.models.instances import (
17+
InstanceAvailability,
18+
InstanceConfiguration,
19+
InstanceOfferWithAvailability,
20+
)
21+
from dstack._internal.core.models.placement import PlacementGroup
22+
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
23+
from dstack._internal.utils.logging import get_logger
24+
25+
logger = get_logger(__name__)
26+
27+
MAX_INSTANCE_NAME_LEN = 60
28+
29+
# Setup commands for DigitalOcean instances
30+
SETUP_COMMANDS = [
31+
"sudo ufw delete limit ssh",
32+
"sudo ufw allow ssh",
33+
]
34+
35+
DOCKER_INSTALL_COMMANDS = [
36+
"export DEBIAN_FRONTEND=noninteractive",
37+
"mkdir -p /etc/apt/keyrings",
38+
"curl --max-time 60 -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg",
39+
'echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null',
40+
"apt-get update",
41+
"apt-get --assume-yes install docker-ce docker-ce-cli containerd.io docker-compose-plugin",
42+
]
43+
44+
45+
class DigitalOceanCompute(
46+
ComputeWithCreateInstanceSupport,
47+
Compute,
48+
):
49+
def __init__(self, config: DigitalOceanConfig):
50+
super().__init__()
51+
self.config = config
52+
self.api_client = DigitalOceanAPIClient(config.creds.api_key, config.flavor or "standard")
53+
self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
54+
self.catalog.add_provider(
55+
DigitalOceanProvider(token=config.creds.api_key, flavor=config.flavor or "standard")
56+
)
57+
# self.catalog.add_provider(
58+
# DigitalOceanProvider(token=config.creds.api_key, flavor="standard")
59+
# )
60+
61+
def get_offers(
62+
self, requirements: Optional[Requirements] = None
63+
) -> List[InstanceOfferWithAvailability]:
64+
offers = get_catalog_offers(
65+
backend=BackendType.DIGITALOCEAN,
66+
locations=self.config.regions,
67+
requirements=requirements,
68+
catalog=self.catalog,
69+
)
70+
return [
71+
InstanceOfferWithAvailability(
72+
**offer.dict(),
73+
availability=InstanceAvailability.AVAILABLE,
74+
)
75+
for offer in offers
76+
]
77+
78+
def create_instance(
79+
self,
80+
instance_offer: InstanceOfferWithAvailability,
81+
instance_config: InstanceConfiguration,
82+
placement_group: Optional[PlacementGroup],
83+
) -> JobProvisioningData:
84+
instance_name = generate_unique_instance_name(
85+
instance_config, max_length=MAX_INSTANCE_NAME_LEN
86+
)
87+
88+
project_ssh_key = instance_config.ssh_keys[0]
89+
ssh_key_id = self.api_client.get_or_create_ssh_key(
90+
name=f"dstack-{instance_config.project_name}",
91+
public_key=project_ssh_key.public,
92+
)
93+
94+
# Use the instance name directly from the offer (gpuhunt handles flavor-specific naming)
95+
size_slug = instance_offer.instance.name
96+
97+
if not instance_offer.instance.resources.gpus:
98+
backend_specific_commands = SETUP_COMMANDS + DOCKER_INSTALL_COMMANDS
99+
else:
100+
backend_specific_commands = SETUP_COMMANDS
101+
102+
# Prepare droplet configuration
103+
droplet_config = {
104+
"name": instance_name,
105+
"region": instance_offer.region,
106+
"size": size_slug,
107+
"image": self._get_image_for_instance(instance_offer),
108+
"ssh_keys": [ssh_key_id],
109+
"backups": False,
110+
"ipv6": False,
111+
"monitoring": False,
112+
"tags": [],
113+
"user_data": get_user_data(
114+
authorized_keys=instance_config.get_public_keys(),
115+
backend_specific_commands=backend_specific_commands,
116+
),
117+
}
118+
119+
droplet = self.api_client.create_droplet(droplet_config)
120+
121+
return JobProvisioningData(
122+
backend=instance_offer.backend,
123+
instance_type=instance_offer.instance,
124+
instance_id=str(droplet["id"]),
125+
hostname=None, # Will be set when droplet is active
126+
internal_ip=None,
127+
region=instance_offer.region,
128+
price=instance_offer.price,
129+
username="root",
130+
ssh_port=22,
131+
dockerized=True,
132+
ssh_proxy=None,
133+
backend_data=None,
134+
)
135+
136+
def update_provisioning_data(
137+
self,
138+
provisioning_data: JobProvisioningData,
139+
project_ssh_public_key: str,
140+
project_ssh_private_key: str,
141+
):
142+
droplet = self.api_client.get_droplet(provisioning_data.instance_id)
143+
if droplet["status"] == "active":
144+
for network in droplet["networks"]["v4"]:
145+
if network["type"] == "public":
146+
provisioning_data.hostname = network["ip_address"]
147+
break
148+
149+
def terminate_instance(
150+
self, instance_id: str, region: str, backend_data: Optional[str] = None
151+
):
152+
self.api_client.delete_droplet(instance_id)
153+
154+
def _get_image_for_instance(self, instance_offer: InstanceOfferWithAvailability) -> str:
155+
if not instance_offer.instance.resources.gpus:
156+
# No GPUs, use CPU image
157+
return "ubuntu-24-04-x64"
158+
159+
gpu_count = len(instance_offer.instance.resources.gpus)
160+
gpu_name = instance_offer.instance.resources.gpus[0].name
161+
162+
if gpu_name == "MI300X":
163+
# AMD GPU
164+
return "digitaloceanai-rocmjupyter"
165+
else:
166+
# NVIDIA GPUs - DO only supports 1 and 8 GPU configurations.
167+
# DO says for single GPU plans using GPUs other than H100s use "gpu-h100x1-base". But for x8 assuming same.
168+
# See (https://docs.digitalocean.com/products/droplets/getting-started/recommended-gpu-setup/#aiml-ready-image)
169+
if gpu_count == 8:
170+
return "gpu-h100x8-base"
171+
elif gpu_count == 1:
172+
return "gpu-h100x1-base"
173+
else:
174+
# For Unsupported GPU count - use single GPU image and log warning
175+
logger.warning(
176+
f"Unsupported NVIDIA GPU count: {gpu_count}, using single GPU image"
177+
)
178+
return "gpu-h100x1-base"
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import json
2+
3+
from dstack._internal.core.backends.base.configurator import (
4+
BackendRecord,
5+
Configurator,
6+
)
7+
from dstack._internal.core.backends.digitalocean.api_client import DigitalOceanAPIClient
8+
from dstack._internal.core.backends.digitalocean.backend import DigitalOceanBackend
9+
from dstack._internal.core.backends.digitalocean.models import (
10+
AnyDigitalOceanBackendConfig,
11+
AnyDigitalOceanCreds,
12+
DigitalOceanBackendConfig,
13+
DigitalOceanBackendConfigWithCreds,
14+
DigitalOceanConfig,
15+
DigitalOceanCreds,
16+
DigitalOceanStoredConfig,
17+
)
18+
from dstack._internal.core.models.backends.base import (
19+
BackendType,
20+
)
21+
22+
23+
class DigitalOceanConfigurator(Configurator):
24+
TYPE = BackendType.DIGITALOCEAN
25+
BACKEND_CLASS = DigitalOceanBackend
26+
27+
def validate_config(
28+
self, config: DigitalOceanBackendConfigWithCreds, default_creds_enabled: bool
29+
):
30+
self._validate_creds(config.creds, config.flavor or "standard")
31+
32+
def create_backend(
33+
self, project_name: str, config: DigitalOceanBackendConfigWithCreds
34+
) -> BackendRecord:
35+
return BackendRecord(
36+
config=DigitalOceanStoredConfig(
37+
**DigitalOceanBackendConfig.__response__.parse_obj(config).dict()
38+
).json(),
39+
auth=DigitalOceanCreds.parse_obj(config.creds).json(),
40+
)
41+
42+
def get_backend_config(
43+
self, record: BackendRecord, include_creds: bool
44+
) -> AnyDigitalOceanBackendConfig:
45+
config = self._get_config(record)
46+
if include_creds:
47+
return DigitalOceanBackendConfigWithCreds.__response__.parse_obj(config)
48+
return DigitalOceanBackendConfig.__response__.parse_obj(config)
49+
50+
def get_backend(self, record: BackendRecord) -> DigitalOceanBackend:
51+
config = self._get_config(record)
52+
return DigitalOceanBackend(config=config)
53+
54+
def _get_config(self, record: BackendRecord) -> DigitalOceanConfig:
55+
return DigitalOceanConfig.__response__(
56+
**json.loads(record.config),
57+
creds=DigitalOceanCreds.parse_raw(record.auth),
58+
)
59+
60+
def _validate_creds(self, creds: AnyDigitalOceanCreds, flavor: str):
61+
api_client = DigitalOceanAPIClient(creds.api_key, flavor)
62+
api_client.validate_api_key()

0 commit comments

Comments
 (0)