Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/docs/concepts/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,28 @@ projects:

</div>

### CloudRift

Log into your [CloudRift :material-arrow-top-right-thin:{ .external }](https://www.cloudrift.ai/console/) console, click `API Keys` in the sidebar and click the button to create a new API key.
Comment thread
jvstme marked this conversation as resolved.
Outdated

Ensure you've created a project with CloudRift.

Comment thread
jvstme marked this conversation as resolved.
Outdated
Then proceed to configuring the backend.

<div editor-title="~/.dstack/server/config.yml">

```yaml
projects:
  - name: main
    backends:
      - type: cloudrift
        creds:
          type: api_key
          api_key: rift_2prgY1d0laOrf2BblTwx2B2d1zcf1zIp4tZYpj5j88qmNgz38pxNlpX3vAo
```

</div>

## On-prem servers

### SSH fleets
Expand Down
17 changes: 17 additions & 0 deletions docs/docs/reference/server/config.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,23 @@ to configure [backends](../../concepts/backends.md) and other [sever-level setti
type:
required: true

##### `projects[n].backends[type=cloudrift]` { #cloudrift data-toc-label="cloudrift" }

#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftBackendConfigWithCreds
    overrides:
        show_root_heading: false
        type:
            required: true
    item_id_prefix: cloudrift-

###### `projects[n].backends[type=cloudrift].creds` { #cloudrift-creds data-toc-label="creds" }

#SCHEMA# dstack._internal.core.backends.cloudrift.models.CloudRiftAPIKeyCreds
    overrides:
        show_root_heading: false
        type:
            required: true

### `encryption` { #encryption data-toc-label="encryption" }

#SCHEMA# dstack._internal.server.services.config.EncryptionConfig
Expand Down
Empty file.
228 changes: 228 additions & 0 deletions src/dstack/_internal/core/backends/cloudrift/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import os
import re
from typing import Any, Dict, List, Mapping, Optional, Union

import requests
from packaging import version
from requests import Response

from dstack._internal.core.errors import BackendError, BackendInvalidCredentialsError
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


# Base URL of the CloudRift server; RiftClient derives its API roots from it.
CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai"
# API version string sent under the "version" key of every request body and
# compared against the version reported in the server's response.
CLOUDRIFT_API_VERSION = "2025-05-29"


class RiftClient:
def __init__(self, api_key: Optional[str] = None):
    """
    Creates a client for the CloudRift API.

    Args:
        api_key: API key sent in the `X-API-Key` header of every request.
            When None, requests are sent without that header.
    """
    self.server_address = CLOUDRIFT_SERVER_ADDRESS
    # URLs always use "/" as the separator. os.path.join uses the platform
    # separator and would produce backslashes on Windows.
    self.public_api_root = f"{CLOUDRIFT_SERVER_ADDRESS}/api/v1"
    self.internal_api_root = f"{CLOUDRIFT_SERVER_ADDRESS}/internal"
    self.api_key = api_key

def validate_api_key(self) -> bool:
    """
    Validates the API key by making a request to the server.

    Returns:
        True if the `auth/me` response is a dictionary with a non-empty
        `email` field, False otherwise. Request failures are never raised
        to the caller: they are treated as an invalid key.
    """
    try:
        response = self._make_request("auth/me")
    except BackendInvalidCredentialsError:
        return False
    except Exception as e:
        logger.error(f"Error validating API key: {e}")
        return False

    if isinstance(response, dict):
        # Coerce to bool so the declared return type holds; the raw value is
        # the email string.
        return bool(response.get("email"))
    return False

def get_instance_types(self) -> List[Dict]:
    """
    Lists the instance types offered for the "vm" service.

    Returns the `instance_types` list from the response, or an empty list
    when the response is not a dictionary or lacks that key.
    """
    payload = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}}
    result = self._make_request("instance-types/list", payload)
    if not isinstance(result, dict):
        return []
    return result.get("instance_types", [])

def list_recipies(self) -> List[Dict]:
    """
    Fetches the recipe groups from the `recipes/list` endpoint.

    Returns the `groups` list from the response, or an empty list when the
    response is not a dictionary or lacks that key.
    """
    # NOTE: the method name is a misspelling of "recipes"; it is kept because
    # get_vm_recipies calls it under this name.
    result = self._make_request("recipes/list", {})
    return result.get("groups", []) if isinstance(result, dict) else []

def get_vm_recipies(self) -> List[Dict]:
    """
    Retrieves a list of VM recipes from the CloudRift API.

    A recipe group is considered when it is named "linux" (case-insensitive)
    or carries the "vm" tag. From those groups, only recipes whose `details`
    contain a truthy `VirtualMachine` entry are returned.

    Returns:
        A list of dictionaries containing recipe information.
    """
    vm_recipes = []
    for group in self.list_recipies():
        # The key is "tags" with no trailing space; a stray space makes the
        # lookup always miss and silently disables the tag-based selection.
        is_vm_group = "vm" in (group.get("tags") or [])
        is_linux_group = (group.get("name") or "").lower() == "linux"
        if not (is_linux_group or is_vm_group):
            continue

        vm_recipes.extend(
            recipe
            for recipe in group.get("recipes") or []
            if (recipe.get("details") or {}).get("VirtualMachine")
        )

    return vm_recipes

def get_vm_image_url(self) -> Optional[str]:
    """
    Picks the image URL of the newest Ubuntu VM recipe with NVIDIA drivers.

    Only recipes tagged "nvidia-driver" whose name contains "Ubuntu" and a
    `<major>.<minor>` version number, and that have an `image_url`, are
    considered.

    Returns:
        The `image_url` of the recipe with the highest Ubuntu version, or
        None if no recipe qualifies.
    """
    ubuntu_images = []
    for recipe in self.get_vm_recipies():
        if "nvidia-driver" not in recipe.get("tags", []):
            continue

        recipe_name = recipe.get("name", "")
        if "Ubuntu" not in recipe_name:
            continue

        url = recipe["details"].get("VirtualMachine", {}).get("image_url", None)
        # The greedy ".* " prefix makes the search pick the last " X.Y"
        # occurrence in the recipe name.
        version_match = re.search(r".* (\d+\.\d+)", recipe_name)
        if url and version_match:
            ubuntu_version = version.parse(version_match.group(1))
            ubuntu_images.append((ubuntu_version, url))

    if not ubuntu_images:
        return None

    # Sort by version only; the sort is stable, so among equal versions the
    # recipe listed last wins.
    ubuntu_images.sort(key=lambda image: image[0])
    return ubuntu_images[-1][1]

def deploy_instance(
    self, instance_type: str, region: str, ssh_keys: List[str], cmd: str
) -> List[str]:
    """
    Rents a virtual machine of the given instance type in the given datacenter.

    Args:
        instance_type: Instance type name placed in the request selector.
        region: Datacenter name; sent as the only entry of `datacenters`.
        ssh_keys: Public SSH keys passed to the VM configuration.
        cmd: Value sent as `cloudinit_commands`.

    Returns:
        The `instance_ids` list from the response, or an empty list when the
        response is not a dictionary or lacks that key.

    Raises:
        BackendError: If no suitable VM image is found.
    """
    image_url = self.get_vm_image_url()
    if not image_url:
        raise BackendError("No suitable VM image found.")

    vm_config = {
        "cloudinit_commands": cmd,
        "image_url": image_url,
        "ssh_key": {"PublicKeys": ssh_keys},
    }
    placement = {"datacenters": [region], "instance_type": instance_type}
    request_data = {
        "config": {"VirtualMachine": vm_config},
        "selector": {"ByInstanceTypeAndLocation": placement},
        "with_public_ip": True,
    }
    logger.debug("Deploying instance with request data: %s", request_data)

    response_data = self._make_request("instances/rent", request_data)
    if not isinstance(response_data, dict):
        return []
    return response_data.get("instance_ids", [])

def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
    """
    Lists instances from the `instances/list` endpoint.

    Args:
        instance_ids: If given, only instances with these IDs are requested,
            using the same `ById` selector as `get_instance_by_id`. If None,
            instances in the "Initializing", "Active" or "Deactivating"
            status are requested.

    Returns:
        The `instances` list from the response, or an empty list when the
        response is not a dictionary, lacks that key, or `instance_ids` is
        empty.
    """
    if instance_ids is None:
        selector = {"ByStatus": ["Initializing", "Active", "Deactivating"]}
    elif not instance_ids:
        # Nothing was asked for, so there is nothing to fetch.
        return []
    else:
        # NOTE(review): assumes the API accepts several IDs in one ById
        # selector (it takes a list) — confirm against the CloudRift API.
        selector = {"ById": list(instance_ids)}

    request_data = {"selector": selector}
    logger.debug("Listing instances with request data: %s", request_data)
    response_data = self._make_request("instances/list", request_data)
    if isinstance(response_data, dict):
        return response_data.get("instances", [])

    return []
Comment on lines +128 to +140
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Unused

Suggested change
def list_instances(self, instance_ids: Optional[List[str]] = None) -> List[Dict]:
request_data = {
"selector": {
"ByStatus": ["Initializing", "Active", "Deactivating"],
}
}
logger.debug("Listing instances with request data: %s", request_data)
response_data = self._make_request("instances/list", request_data)
if isinstance(response_data, dict):
return response_data.get("instances", [])
return []


def get_instance_by_id(self, instance_id: str) -> Optional[Dict]:
    """
    Looks up a single instance by its ID via the `instances/list` endpoint.

    Returns the first entry of the response's `instances` list, or None when
    the response is not a dictionary or the list is missing or empty.
    """
    request_data = {"selector": {"ById": [instance_id]}}
    logger.debug("Getting instance with request data: %s", request_data)
    response_data = self._make_request("instances/list", request_data)
    if not isinstance(response_data, dict):
        return None

    found = response_data.get("instances", [])
    if isinstance(found, list) and found:
        return found[0]
    return None

def is_instance_ready(self, instance_id: str) -> bool:
    """
    Reports whether the instance with the given ID is ready.

    For instances whose `node_mode` is "VirtualMachine", the `ready` field of
    the first virtual machine decides; for any other mode, the instance is
    ready when its `status` is "Active". An unknown instance is not ready.
    """
    info = self.get_instance_by_id(instance_id)
    if not info:
        return False

    if info.get("node_mode", "") != "VirtualMachine":
        return info.get("status", "") == "Active"

    machines = info.get("virtual_machines", [])
    if len(machines) > 0:
        return machines[0].get("ready", False)
    # A VM-mode instance with no virtual machines listed is not ready yet.
    return False
Comment thread
jvstme marked this conversation as resolved.
Outdated

def terminate_instance(self, instance_id: str) -> bool:
    """
    Requests termination of the instance with the given ID.

    Returns True when the response's `terminated` list is non-empty, False
    when it is empty or missing or the response is not a dictionary.
    """
    request_data = {"selector": {"ById": [instance_id]}}
    logger.debug("Terminating instance with request data: %s", request_data)
    response_data = self._make_request("instances/terminate", request_data)
    if not isinstance(response_data, dict):
        return False

    terminated = response_data.get("terminated", [])
    return len(terminated) > 0

def _make_request(
    self,
    endpoint: str,
    data: Optional[Mapping[str, Any]] = None,
    method: str = "POST",
    **kwargs,
) -> Union[Mapping[str, Any], str, Response]:
    """
    Sends a request to the public CloudRift API and unwraps the response.

    Args:
        endpoint: Path relative to the public API root, e.g. "instances/list".
        data: Payload placed under the "data" key of the JSON request body.
        method: HTTP method to use.
        **kwargs: Extra keyword arguments forwarded to `requests.request`.
            A `timeout` given here overrides the default of 120 seconds.

    Returns:
        The "data" field of the JSON response; the decoded value itself when
        the response body is a JSON string; or the raw `Response` when the
        body is not valid JSON.

    Raises:
        BackendInvalidCredentialsError: On an HTTP 401 or 403 response.
        requests.HTTPError: On any other HTTP error status.
    """
    headers = {}
    if self.api_key is not None:
        headers["X-API-Key"] = self.api_key

    # Named api_version so that it does not shadow the `version` module
    # imported from `packaging` at file level.
    api_version = CLOUDRIFT_API_VERSION
    full_url = f"{self.public_api_root}/{endpoint}"
    # setdefault lets callers pass their own timeout; passing it both here and
    # in **kwargs would raise "got multiple values for keyword argument".
    kwargs.setdefault("timeout", 120)

    try:
        response = requests.request(
            method,
            full_url,
            headers=headers,
            json={"version": api_version, "data": data},
            **kwargs,
        )
        # raise_for_status is a no-op for successful responses, so no
        # separate `response.ok` check is needed.
        response.raise_for_status()
    except requests.HTTPError as e:
        if e.response is not None and e.response.status_code in (
            requests.codes.forbidden,
            requests.codes.unauthorized,
        ):
            raise BackendInvalidCredentialsError(e.response.text) from e
        raise

    try:
        response_json = response.json()
    except requests.exceptions.JSONDecodeError:
        return response

    if isinstance(response_json, str):
        return response_json
    server_version = response_json["version"]
    # Plain string comparison; the client version is a YYYY-MM-DD date string.
    if api_version < server_version:
        logger.warning(
            "The API version %s is lower than the server version %s. ",
            api_version,
            server_version,
        )
    return response_json["data"]
16 changes: 16 additions & 0 deletions src/dstack/_internal/core/backends/cloudrift/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.cloudrift.compute import CloudRiftCompute
from dstack._internal.core.backends.cloudrift.models import CloudRiftConfig
from dstack._internal.core.models.backends.base import BackendType


class CloudRiftBackend(Backend):
    """Backend of type `BackendType.CLOUDRIFT` backed by `CloudRiftCompute`."""

    TYPE = BackendType.CLOUDRIFT
    COMPUTE_CLASS = CloudRiftCompute

    def __init__(self, config: CloudRiftConfig):
        """Stores the config and builds the compute object from it."""
        self.config = config
        self._compute = CloudRiftCompute(config)

    def compute(self) -> CloudRiftCompute:
        """Returns the compute object created in the constructor."""
        return self._compute
Loading
Loading