Skip to content

Commit 2db1752

Browse files
6erun and Slonegg authored
Add CloudRift backend (#2771)
Co-authored-by: Dmitry Trifonov <slonegg@gmail.com>
1 parent 026ba42 commit 2db1752

File tree

14 files changed

+560
-0
lines changed

14 files changed

+560
-0
lines changed

docs/docs/concepts/backends.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,28 @@ projects:
913913

914914
</div>
915915

916+
### CloudRift
917+
918+
Log in to your [CloudRift :material-arrow-top-right-thin:{ .external }](https://console.cloudrift.ai/) console, click `API Keys` in the sidebar, and then click the button to create a new API key.
919+
920+
Ensure you've created a project with CloudRift.
921+
922+
Then proceed to configuring the backend.
923+
924+
<div editor-title="~/.dstack/server/config.yml">
925+
926+
```yaml
927+
projects:
928+
- name: main
929+
backends:
930+
- type: cloudrift
931+
creds:
932+
type: api_key
933+
api_key: rift_2prgY1d0laOrf2BblTwx2B2d1zcf1zIp4tZYpj5j88qmNgz38pxNlpX3vAo
934+
```
935+
936+
</div>
937+
916938
## On-prem servers
917939

918940
### SSH fleets

docs/docs/reference/server/config.yml.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,23 @@ to configure [backends](../../concepts/backends.md) and other [server-level setti
315315
type:
316316
required: true
317317

318+
##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }
319+
320+
#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds
321+
overrides:
322+
show_root_heading: false
323+
type:
324+
required: true
325+
item_id_prefix: cloudrift-
326+
327+
###### `projects[n].backends[type=cloudrift].creds` { #cloudrift-creds data-toc-label="creds" }
328+
329+
#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftAPIKeyCreds
330+
overrides:
331+
show_root_heading: false
332+
type:
333+
required: true
334+
318335
### `encryption` { #encryption data-toc-label="encryption" }
319336

320337
#SCHEMA# dstack._internal.server.services.config.EncryptionConfig

src/dstack/_internal/core/backends/cloudrift/__init__.py

Whitespace-only changes.
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import os
2+
import re
3+
from typing import Any, Dict, List, Mapping, Optional, Union
4+
5+
import requests
6+
from packaging import version
7+
from requests import Response
8+
9+
from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError
10+
from dstack._internal.utils.logging import get_logger
11+
12+
# Module-level logger named after this module.
logger = get_logger(__name__)
13+
14+
15+
# Base URL of the CloudRift API server; RiftClient appends "api/v1" to it.
CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai"
16+
# Sent as the "version" field of every request body and compared with the version in responses.
CLOUDRIFT_API_VERSION = "2025-05-29"
17+
18+
19+
class RiftClient:
    """Thin HTTP client for the CloudRift public API (``<server>/api/v1``).

    Every call goes through :meth:`_make_request`, which wraps the payload in a
    ``{"version": ..., "data": ...}`` envelope and, when an API key is set,
    sends it in the ``X-API-Key`` header.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Store the API key and precompute the API root URL.

        Args:
            api_key: CloudRift API key; when ``None`` no auth header is sent.
        """
        # URLs always use "/" as the separator. os.path.join would insert "\\"
        # on Windows and produce an invalid URL, so join explicitly.
        self.public_api_root = f"{CLOUDRIFT_SERVER_ADDRESS.rstrip('/')}/api/v1"
        self.api_key = api_key

    def validate_api_key(self) -> bool:
        """
        Validates the API key by making a request to the server.
        Returns True if the API key is valid, False otherwise.

        Any failure (invalid credentials, network error, unexpected payload)
        is reported as ``False``; non-credential errors are logged.
        """
        try:
            response = self._make_request("auth/me")
        except BackendInvalidCredentialsError:
            return False
        except Exception as e:
            # Deliberate best-effort: validation must never raise.
            logger.error("Error validating API key: %s", e)
            return False
        # A valid key yields a mapping that includes the account's "email".
        return isinstance(response, dict) and "email" in response

    def get_instance_types(self) -> List[Dict]:
        """Return the instance types offered for the ``vm`` service.

        Returns an empty list when the response is not a mapping or has no
        ``instance_types`` key.
        """
        request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}}
        response_data = self._make_request("instance-types/list", request_data)
        if isinstance(response_data, dict):
            return response_data.get("instance_types", [])
        return []

    def list_recipes(self) -> List[Dict]:
        """Return the recipe groups (``groups`` key) from ``recipes/list``.

        Returns an empty list when the response is not a mapping or has no
        ``groups`` key.
        """
        response_data = self._make_request("recipes/list", {})
        if isinstance(response_data, dict):
            return response_data.get("groups", [])
        return []

    def get_vm_recipies(self) -> List[Dict]:
        """
        Retrieves a list of VM recipes from the CloudRift API.
        Returns a list of dictionaries containing recipe information.

        Only groups named "linux" (case-insensitive) that carry a "vm" tag
        (case-insensitive) are considered, and within them only recipes whose
        ``details`` contain a truthy ``VirtualMachine`` entry.
        """
        vm_recipes = []
        for group in self.list_recipes():
            is_linux = group.get("name", "").lower() == "linux"
            has_vm_tag = "vm" in map(str.lower, group.get("tags", []))
            if not (is_linux and has_vm_tag):
                continue
            vm_recipes.extend(
                recipe
                for recipe in group.get("recipes", [])
                if recipe.get("details", {}).get("VirtualMachine", False)
            )
        return vm_recipes

    def get_vm_image_url(self) -> Optional[str]:
        """Return the image URL of the newest Ubuntu VM recipe with NVIDIA drivers.

        A recipe qualifies when it has the ``nvidia-driver`` tag, its name
        contains "Ubuntu" and a ``<major>.<minor>`` version, and it has a
        non-empty ``image_url``. Returns ``None`` when nothing qualifies.
        """
        ubuntu_images = []
        for recipe in self.get_vm_recipies():
            if "nvidia-driver" not in recipe.get("tags", []):
                continue

            recipe_name = recipe.get("name", "")
            if "Ubuntu" not in recipe_name:
                continue

            url = recipe["details"].get("VirtualMachine", {}).get("image_url", None)
            # Greedy ".* " makes the regex pick the last "<digits>.<digits>"
            # that follows a space in the recipe name.
            version_match = re.search(r".* (\d+\.\d+)", recipe_name)
            if url and version_match and version_match.group(1):
                ubuntu_version = version.parse(version_match.group(1))
                ubuntu_images.append((ubuntu_version, url))

        if not ubuntu_images:
            return None
        # Stable sort by parsed version; the last entry is the newest release.
        ubuntu_images.sort(key=lambda image: image[0])
        return ubuntu_images[-1][1]

    def deploy_instance(
        self, instance_type: str, region: str, ssh_keys: List[str], cmd: str
    ) -> List[str]:
        """Rent a VM with a public IP and return the resulting instance IDs.

        Args:
            instance_type: Instance type name passed to the selector.
            region: Datacenter passed as the single entry of ``datacenters``.
            ssh_keys: Public SSH keys to install on the VM.
            cmd: Value sent as ``cloudinit_commands``.

        Returns:
            The ``instance_ids`` list from the response, or an empty list when
            the response is not a mapping or lacks that key.

        Raises:
            BackendError: If no suitable VM image could be found.
        """
        image_url = self.get_vm_image_url()
        if not image_url:
            raise BackendError("No suitable VM image found.")

        request_data = {
            "config": {
                "VirtualMachine": {
                    "cloudinit_commands": cmd,
                    "image_url": image_url,
                    "ssh_key": {"PublicKeys": ssh_keys},
                }
            },
            "selector": {
                "ByInstanceTypeAndLocation": {
                    "datacenters": [region],
                    "instance_type": instance_type,
                }
            },
            "with_public_ip": True,
        }
        logger.debug("Deploying instance with request data: %s", request_data)

        response_data = self._make_request("instances/rent", request_data)
        if isinstance(response_data, dict):
            return response_data.get("instance_ids", [])
        return []

    def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
        """Return instances whose status is Initializing, Active or Deactivating.

        NOTE(review): ``instance_ids`` is accepted but not used; the request
        always selects by status only. Confirm with callers whether filtering
        by ID is expected before changing this.
        """
        request_data = {
            "selector": {
                "ByStatus": ["Initializing", "Active", "Deactivating"],
            }
        }
        logger.debug("Listing instances with request data: %s", request_data)
        response_data = self._make_request("instances/list", request_data)
        if isinstance(response_data, dict):
            return response_data.get("instances", [])
        return []

    def get_instance_by_id(self, instance_id: str) -> Optional[Dict]:
        """Return the first instance matching ``instance_id``, or ``None`` if none is returned."""
        request_data = {"selector": {"ById": [instance_id]}}
        logger.debug("Getting instance with request data: %s", request_data)
        response_data = self._make_request("instances/list", request_data)
        if isinstance(response_data, dict):
            instances = response_data.get("instances", [])
            if isinstance(instances, list) and instances:
                return instances[0]
        return None

    def terminate_instance(self, instance_id: str) -> bool:
        """Request termination of ``instance_id``.

        Returns ``True`` when the response's ``terminated`` list is non-empty,
        ``False`` otherwise.
        """
        request_data = {"selector": {"ById": [instance_id]}}
        logger.debug("Terminating instance with request data: %s", request_data)
        response_data = self._make_request("instances/terminate", request_data)
        if isinstance(response_data, dict):
            return len(response_data.get("terminated", [])) > 0
        return False

    def _make_request(
        self,
        endpoint: str,
        data: Optional[Mapping[str, Any]] = None,
        method: str = "POST",
        **kwargs,
    ) -> Union[Mapping[str, Any], str, Response]:
        """Send a request to ``<api root>/<endpoint>`` and unwrap the response.

        Args:
            endpoint: Path relative to the API root, without a leading slash.
            data: Payload placed under the ``data`` key of the JSON body.
            method: HTTP method; defaults to ``POST``.
            **kwargs: Extra keyword arguments forwarded to ``requests.request``.

        Returns:
            The ``data`` field of the JSON response; the JSON value itself when
            it is a bare string; or the raw ``Response`` when the body is not
            valid JSON.

        Raises:
            BackendInvalidCredentialsError: On HTTP 401 or 403.
            requests.HTTPError: On any other non-success HTTP status.
        """
        headers = {}
        if self.api_key is not None:
            headers["X-API-Key"] = self.api_key

        # Named api_version so the imported `packaging.version` module is not shadowed.
        api_version = CLOUDRIFT_API_VERSION
        full_url = f"{self.public_api_root}/{endpoint}"

        try:
            response = requests.request(
                method,
                full_url,
                headers=headers,
                json={"version": api_version, "data": data},
                timeout=15,
                **kwargs,
            )
            # No-op for successful responses, raises HTTPError otherwise.
            response.raise_for_status()
        except requests.HTTPError as e:
            if e.response is not None and e.response.status_code in (
                requests.codes.forbidden,
                requests.codes.unauthorized,
            ):
                # Chain the cause so the original HTTP error stays in the traceback.
                raise BackendInvalidCredentialsError(e.response.text) from e
            raise

        try:
            response_json = response.json()
        except requests.exceptions.JSONDecodeError:
            return response

        if isinstance(response_json, str):
            return response_json
        # Versions are ISO dates, so plain string comparison orders them correctly.
        if api_version is not None and api_version < response_json["version"]:
            logger.warning(
                "The API version %s is lower than the server version %s. ",
                api_version,
                response_json["version"],
            )
        return response_json["data"]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from dstack._internal.core.backends.base.backend import Backend
2+
from dstack._internal.core.backends.cloudrift.compute import CloudRiftCompute
3+
from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig
4+
from dstack._internal.core.models.backends.base import BackendType
5+
6+
7+
class CloudRiftBackend(Backend):
    """Backend implementation registered under ``BackendType.CLOUDRIFT``.

    Holds the backend config and a single ``CloudRiftCompute`` built from it.
    """

    TYPE = BackendType.CLOUDRIFT
    COMPUTE_CLASS = CloudRiftCompute

    def __init__(self, config: CloudRiftConfig):
        """Keep ``config`` and create the compute object from it."""
        self.config = config
        self._compute = CloudRiftCompute(config)

    def compute(self) -> CloudRiftCompute:
        """Return the compute object created at construction time."""
        return self._compute

0 commit comments

Comments
 (0)