Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion docs/docs/guides/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,14 @@ class ExamplePolicy(ApplyPolicy):

</div>

For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin).
## Built-in Plugins

### REST Plugin
`rest_plugin` is a builtin `dstack` plugin that allows writing your custom plugins as API servers, so you don't need to install plugins as Python packages.

Plugins implemented as API servers have advantages over plugins implemented as Python packages in some cases:
* No dependency conflicts with `dstack`.
* You can use any programming language.
* If you run the `dstack` server via Docker, you don't need to extend the `dstack` server image with plugins or map them via volumes.

To get started, check out the [plugin server example](/examples/plugins/example_plugin_server/README.md).
1 change: 1 addition & 0 deletions examples/plugins/example_plugin_server/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
30 changes: 30 additions & 0 deletions examples/plugins/example_plugin_server/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## Overview

If you wish to hook up your own plugin server through `dstack` builtin `rest_plugin`, here's a basic example on how to do so.

## Steps


1. Install required dependencies for the plugin server:

```bash
uv sync
```

1. Start the plugin server locally:

```bash
fastapi dev app/main.py
```

1. Enable `rest_plugin` in `server/config.yaml`:

```yaml
plugins:
- rest_plugin
```

1. Point the `dstack` server to your plugin server:
```bash
export DSTACK_PLUGIN_SERVICE_URI=http://127.0.0.1:8000
```
Empty file.
56 changes: 56 additions & 0 deletions examples/plugins/example_plugin_server/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging

from fastapi import FastAPI

from app.utils import configure_logging
from dstack.plugins.builtin.rest_plugin import (
FleetSpecRequest,
FleetSpecResponse,
GatewaySpecRequest,
GatewaySpecResponse,
RunSpecRequest,
RunSpecResponse,
VolumeSpecRequest,
VolumeSpecResponse,
)

configure_logging()
logger = logging.getLogger(__name__)

app = FastAPI()


@app.post("/apply_policies/on_run_apply")
async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse:
logger.info(
f"Received run spec request from user {request.user} and project {request.project}"
)
response = RunSpecResponse(spec=request.spec, error=None)
return response


@app.post("/apply_policies/on_fleet_apply")
async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse:
logger.info(
f"Received fleet spec request from user {request.user} and project {request.project}"
)
response = FleetSpecResponse(request.spec, error=None)
return response


@app.post("/apply_policies/on_volume_apply")
async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse:
logger.info(
f"Received volume spec request from user {request.user} and project {request.project}"
)
response = VolumeSpecResponse(request.spec, error=None)
return response


@app.post("/apply_policies/on_gateway_apply")
async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse:
logger.info(
f"Received gateway spec request from user {request.user} and project {request.project}"
)
response = GatewaySpecResponse(request.spec, error=None)
return response
7 changes: 7 additions & 0 deletions examples/plugins/example_plugin_server/app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import logging
import os


def configure_logging():
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
10 changes: 10 additions & 0 deletions examples/plugins/example_plugin_server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[project]
name = "dstack-plugin-server"
version = "0.1.0"
description = "Example plugin server"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"fastapi[standard]>=0.115.12",
"dstack>=0.19.8"
]
91 changes: 61 additions & 30 deletions src/dstack/_internal/server/services/plugins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from importlib import import_module
from typing import Dict

from backports.entry_points_selectable import entry_points # backport for Python 3.9

Expand All @@ -12,50 +13,80 @@

_PLUGINS: list[Plugin] = []

_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}

def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
plugins_entrypoints = entry_points(group="dstack.plugins")
plugins_to_load = enabled_plugins.copy()
for entrypoint in plugins_entrypoints:
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
)
continue

class PluginEntrypoint:
def __init__(self, name: str, import_path: str, is_builtin: bool = False):
self.name = name
self.import_path = import_path
self.is_builtin = is_builtin

def load(self):
module_path, _, class_name = self.import_path.partition(":")
try:
module_path, _, class_name = entrypoint.value.partition(":")
module = import_module(module_path)
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
self.name,
class_name,
module_path,
)
return None
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
self.name,
class_name,
)
return None
return plugin_class()
except ImportError:
logger.warning(
(
"Failed to load plugin %s when importing %s."
" Ensure the module is on the import path."
),
entrypoint.name,
entrypoint.value,
self.name,
self.import_path,
)
continue
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
return None


def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
entrypoints: dict[str, PluginEntrypoint] = {}
plugins_to_load = enabled_plugins.copy()
for entrypoint in entry_points(group="dstack.plugins"):
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
class_name,
module_path,
)
continue
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
entrypoint.name,
class_name,
else:
entrypoints[entrypoint.name] = PluginEntrypoint(
entrypoint.name, entrypoint.value, is_builtin=False
)
continue
plugins_to_load.remove(entrypoint.name)
_PLUGINS.append(plugin_class())
logger.info("Loaded plugin %s", entrypoint.name)

for name, import_path in _BUILTIN_PLUGINS.items():
if name not in enabled_plugins:
logger.info(
("Found not enabled builtin plugin %s. Plugin will not be loaded."),
name,
)
else:
entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)

for plugin_name, plugin_entrypoint in entrypoints.items():
plugin_instance = plugin_entrypoint.load()
if plugin_instance is not None:
_PLUGINS.append(plugin_instance)
plugins_to_load.remove(plugin_name)
logger.info("Loaded plugin %s", plugin_name)

if plugins_to_load:
logger.warning("Enabled plugins not found: %s", plugins_to_load)

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions src/dstack/plugins/builtin/rest_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ruff: noqa: F401
from dstack.plugins.builtin.rest_plugin._models import (
FleetSpecRequest,
FleetSpecResponse,
GatewaySpecRequest,
GatewaySpecResponse,
RunSpecRequest,
RunSpecResponse,
SpecApplyRequest,
SpecApplyResponse,
VolumeSpecRequest,
VolumeSpecResponse,
)
from dstack.plugins.builtin.rest_plugin._plugin import (
PLUGIN_SERVICE_URI_ENV_VAR_NAME,
CustomApplyPolicy,
RESTPlugin,
)
48 changes: 48 additions & 0 deletions src/dstack/plugins/builtin/rest_plugin/_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Generic, Optional, TypeVar

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from dstack._internal.core.models.fleets import FleetSpec
from dstack._internal.core.models.gateways import GatewaySpec
from dstack._internal.core.models.runs import RunSpec
from dstack._internal.core.models.volumes import VolumeSpec

SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)


class SpecApplyRequest(BaseModel, Generic[SpecType]):
user: Annotated[str, Field(description="The name of the user making the apply request")]
project: Annotated[str, Field(description="The name of the project the request is for")]
spec: Annotated[SpecType, Field(description="The spec to be applied")]

# Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable"
# error. This issue doesn't happen though when running the code in pytest, only when running the server.
def dict(self, *args, **kwargs):
d = super().dict(*args, **kwargs)
d.pop("__orig_class__", None)
return d


RunSpecRequest = SpecApplyRequest[RunSpec]
FleetSpecRequest = SpecApplyRequest[FleetSpec]
VolumeSpecRequest = SpecApplyRequest[VolumeSpec]
GatewaySpecRequest = SpecApplyRequest[GatewaySpec]


class SpecApplyResponse(BaseModel, Generic[SpecType]):
spec: Annotated[
SpecType,
Field(
description="The spec to apply, original spec if error otherwise original or mutated by plugin service if approved"
),
]
error: Annotated[
Optional[str], Field(description="Error message if request is rejected", min_length=1)
] = None


RunSpecResponse = SpecApplyResponse[RunSpec]
FleetSpecResponse = SpecApplyResponse[FleetSpec]
VolumeSpecResponse = SpecApplyResponse[VolumeSpec]
GatewaySpecResponse = SpecApplyResponse[GatewaySpec]
Loading