Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 62 additions & 31 deletions src/dstack/_internal/server/services/plugins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from importlib import import_module
from typing import Dict

from backports.entry_points_selectable import entry_points # backport for Python 3.9

Expand All @@ -12,50 +13,80 @@

_PLUGINS: list[Plugin] = []

_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}

def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
plugins_entrypoints = entry_points(group="dstack.plugins")
plugins_to_load = enabled_plugins.copy()
for entrypoint in plugins_entrypoints:
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
)
continue

class PluginEntrypoint:
def __init__(self, name: str, import_path: str, is_builtin: bool = False):
self.name = name
self.import_path = import_path
self.is_builtin = is_builtin

def load(self):
module_path, _, class_name = self.import_path.partition(":")
try:
module_path, _, class_name = entrypoint.value.partition(":")
module = import_module(module_path)
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
self.name,
class_name,
module_path,
)
return None
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
self.name,
class_name,
)
return None
return plugin_class()
except ImportError:
logger.warning(
(
"Failed to load plugin %s when importing %s."
" Ensure the module is on the import path."
),
entrypoint.name,
entrypoint.value,
self.name,
self.import_path,
)
continue
plugin_class = getattr(module, class_name, None)
if plugin_class is None:
logger.warning(
("Failed to load plugin %s: plugin class %s not found in module %s."),
return None


def load_plugins(enabled_plugins: list[str]):
_PLUGINS.clear()
entrypoints: dict[str, PluginEntrypoint] = {}
plugins_to_load = enabled_plugins.copy()
for entrypoint in entry_points(group="dstack.plugins"):
if entrypoint.name not in enabled_plugins:
logger.info(
("Found not enabled plugin %s. Plugin will not be loaded."),
entrypoint.name,
class_name,
module_path,
)
continue
if not issubclass(plugin_class, Plugin):
logger.warning(
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
entrypoint.name,
class_name,
else:
entrypoints[entrypoint.name] = PluginEntrypoint(
entrypoint.name, entrypoint.value, is_builtin=False
)
continue
plugins_to_load.remove(entrypoint.name)
_PLUGINS.append(plugin_class())
logger.info("Loaded plugin %s", entrypoint.name)

for name, import_path in _BUILTIN_PLUGINS.items():
if name not in enabled_plugins:
logger.info(
("Found not enabled builtin plugin %s. Plugin will not be loaded."),
name,
)
else:
entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)

for plugin_name, plugin_entrypoint in entrypoints.items():
plugin_instance = plugin_entrypoint.load()
if plugin_instance is not None:
_PLUGINS.append(plugin_instance)
plugins_to_load.remove(plugin_name)
logger.info("Loaded plugin %s", plugin_name)

if plugins_to_load:
logger.warning("Enabled plugins not found: %s", plugins_to_load)

Expand All @@ -65,7 +96,7 @@ def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec
for policy in policies:
try:
spec = policy.on_apply(user=user, project=project, spec=spec)
except ValueError as e:
except Exception as e:
msg = None
if len(e.args) > 0:
msg = e.args[0]
Expand Down
97 changes: 97 additions & 0 deletions src/dstack/plugins/builtin/rest_plugin.py
Comment thread
Nadine-H marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import json
import os
from typing import Generic, TypeVar

import requests
from pydantic import BaseModel, ValidationError

from dstack._internal.core.errors import ServerClientError
from dstack._internal.core.models.fleets import FleetSpec
from dstack._internal.core.models.gateways import GatewaySpec
from dstack._internal.core.models.volumes import VolumeSpec
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
from dstack.plugins._models import ApplySpec
Comment thread
Nadine-H marked this conversation as resolved.
Outdated

logger = get_plugin_logger(__name__)

PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
PLUGIN_REQUEST_TIMEOUT = 8 # in seconds

SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)


class SpecRequest(BaseModel, Generic[SpecType]):
user: str
project: str
spec: SpecType


RunSpecRequest = SpecRequest[RunSpec]
FleetSpecRequest = SpecRequest[FleetSpec]
VolumeSpecRequest = SpecRequest[VolumeSpec]
GatewaySpecRequest = SpecRequest[GatewaySpec]


class CustomApplyPolicy(ApplyPolicy):
Comment thread
Nadine-H marked this conversation as resolved.
Outdated
def __init__(self):
self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME)
logger.info(f"Found plugin service at {self._plugin_service_uri}")
if not self._plugin_service_uri:
logger.error(
f"Cannot create policy because {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set"
)
raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set")

def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> ApplySpec:
response = None
try:
response = requests.post(
f"{self._plugin_service_uri}{endpoint}",
json=spec_request.dict(),
headers={"accept": "application/json", "Content-Type": "application/json"},
timeout=PLUGIN_REQUEST_TIMEOUT,
)
response.raise_for_status()
spec_json = json.loads(response.text)
Comment thread
Nadine-H marked this conversation as resolved.
Outdated
return spec_json
except requests.exceptions.ConnectionError as e:
logger.error(
f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e
)
raise e
except requests.RequestException as e:
logger.error("Request to the plugin service failed: %s", e)
if response:
logger.error(f"Error response from plugin service:\n{response.text}")
raise e
except ValidationError as e:
# Received 200 code but response body is invalid
logger.exception(
f"Plugin service returned invalid response:\n{response.text if response else None}"
)
raise e

def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
spec_request = RunSpecRequest(user=user, project=project, spec=spec)
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_run_apply")
return RunSpec(**spec_json)

def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec:
spec_request = FleetSpecRequest(user=user, project=project, spec=spec)
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_fleet_apply")
return FleetSpec(**spec_json)

def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec:
spec_request = VolumeSpecRequest(user=user, project=project, spec=spec)
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_volume_apply")
return VolumeSpec(**spec_json)

def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
spec_request = GatewaySpecRequest(user=user, project=project, spec=spec)
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_gateway_apply")
return GatewaySpec(**spec_json)


class RESTPlugin(Plugin):
def get_apply_policies(self) -> list[ApplyPolicy]:
return [CustomApplyPolicy()]
106 changes: 79 additions & 27 deletions src/tests/_internal/server/services/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from importlib import import_module
from importlib.metadata import EntryPoint
from unittest.mock import MagicMock, patch

import pytest

from dstack._internal.server.services.plugins import _PLUGINS, load_plugins
from dstack.plugins import Plugin
from dstack.plugins.builtin.rest_plugin import RESTPlugin


class DummyPlugin1(Plugin):
Expand All @@ -30,55 +32,105 @@ def clear_plugins():
class TestLoadPlugins:
@patch("dstack._internal.server.services.plugins.entry_points")
@patch("dstack._internal.server.services.plugins.import_module")
def test_load_single_plugin(self, mock_import_module, mock_entry_points, caplog):
@pytest.mark.parametrize(
["plugin_name", "plugin_module_path", "plugin_class"],
[
("plugin1", "dummy.plugins", DummyPlugin1),
("rest_plugin", "dstack.plugins.builtin.rest_plugin", RESTPlugin),
],
)
def test_load_single_plugin(
self,
mock_import_module,
mock_entry_points,
caplog,
plugin_name,
plugin_module_path,
plugin_class,
):
mock_entry_points.return_value = [
EntryPoint(
name="plugin1",
value="dummy.plugins:DummyPlugin1",
name=plugin_name,
value=f"{plugin_module_path}:{plugin_class.__name__}",
group="dstack.plugins",
)
]
mock_module = MagicMock()
mock_module.DummyPlugin1 = DummyPlugin1
mock_import_module.return_value = mock_module
setattr(mock_module, plugin_class.__name__, plugin_class)
# if it's a built-in plugin, do the real import
mock_import_module.side_effect = (
lambda module_path: import_module(module_path)
if module_path.startswith("dstack.plugins.builtin")
else mock_module
)

with caplog.at_level(logging.INFO):
load_plugins(["plugin1"])
load_plugins([plugin_name])

assert len(_PLUGINS) == 1
assert isinstance(_PLUGINS[0], DummyPlugin1)
assert isinstance(_PLUGINS[0], plugin_class)
mock_entry_points.assert_called_once_with(group="dstack.plugins")
mock_import_module.assert_called_once_with("dummy.plugins")
assert "Loaded plugin plugin1" in caplog.text
mock_import_module.assert_called_once_with(plugin_module_path)
assert f"Loaded plugin {plugin_name}" in caplog.text

@patch("dstack._internal.server.services.plugins.entry_points")
@patch("dstack._internal.server.services.plugins.import_module")
def test_load_multiple_plugins(self, mock_import_module, mock_entry_points, caplog):
mock_entry_points.return_value = [
EntryPoint(
name="plugin1",
value="dummy.plugins:DummyPlugin1",
group="dstack.plugins",
@pytest.mark.parametrize(
["plugin_names", "plugin_module_paths", "plugin_classes"],
[
(
["plugin1", "plugin2"],
["dummy.plugins", "dummy.plugins"],
[DummyPlugin1, DummyPlugin2],
),
(
["plugin1", "plugin2", "rest_plugin"],
["dummy.plugins", "dummy.plugins", "dstack.plugins.builtin.rest_plugin"],
[DummyPlugin1, DummyPlugin2, RESTPlugin],
),
],
ids=["multiple_plugins_without_builtin_plugin", "multiple_plugins_with_builtin_plugin"],
)
def test_load_multiple_plugins(
self,
mock_import_module,
mock_entry_points,
caplog,
plugin_names,
plugin_module_paths,
plugin_classes,
):
mock_entry_points.return_value = [
EntryPoint(
name="plugin2",
value="dummy.plugins:DummyPlugin2",
name=plugin_name,
value=f"{plugin_module_path}:{plugin_class.__name__}",
group="dstack.plugins",
),
)
for plugin_name, plugin_module_path, plugin_class in zip(
plugin_names, plugin_module_paths, plugin_classes
)
]
mock_module = MagicMock()
mock_module.DummyPlugin1 = DummyPlugin1
mock_module.DummyPlugin2 = DummyPlugin2
mock_import_module.return_value = mock_module

for plugin_class, plugin_module_path in zip(plugin_classes, plugin_module_paths):
if not plugin_module_path.startswith("dstack.plugins.builtin"):
setattr(mock_module, plugin_class.__name__, plugin_class)

mock_import_module.side_effect = (
lambda module_path: import_module(module_path)
if module_path.startswith("dstack.plugins.builtin")
else mock_module
)

with caplog.at_level(logging.INFO):
load_plugins(["plugin1", "plugin2"])
load_plugins(plugin_names)

assert len(_PLUGINS) == len(plugin_names)
for i, plugin_class in enumerate(plugin_classes):
assert isinstance(_PLUGINS[i], plugin_class)

assert len(_PLUGINS) == 2
assert isinstance(_PLUGINS[0], DummyPlugin1)
assert isinstance(_PLUGINS[1], DummyPlugin2)
assert "Loaded plugin plugin1" in caplog.text
assert "Loaded plugin plugin2" in caplog.text
for plugin_name in plugin_names:
assert f"Loaded plugin {plugin_name}" in caplog.text

@patch("dstack._internal.server.services.plugins.entry_points")
@patch("dstack._internal.server.services.plugins.import_module")
Expand Down
2 changes: 2 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

from dstack._internal.server.testing.conf import session, test_db # noqa: F401


def pytest_configure(config):
config.addinivalue_line("markers", "ui: mark test as testing UI to run only with --runui")
Expand Down
Empty file added src/tests/plugins/__init__.py
Empty file.
Loading