diff --git a/docs/docs/guides/plugins.md b/docs/docs/guides/plugins.md index ae8c287b95..ac87aa10bc 100644 --- a/docs/docs/guides/plugins.md +++ b/docs/docs/guides/plugins.md @@ -113,4 +113,14 @@ class ExamplePolicy(ApplyPolicy): -For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin). +## Built-in Plugins + +### REST Plugin +`rest_plugin` is a builtin `dstack` plugin that allows writing your custom plugins as API servers, so you don't need to install plugins as Python packages. + +Plugins implemented as API servers have advantages over plugins implemented as Python packages in some cases: +* No dependency conflicts with `dstack`. +* You can use any programming language. +* If you run the `dstack` server via Docker, you don't need to extend the `dstack` server image with plugins or map them via volumes. + +To get started, check out the [plugin server example](/examples/plugins/example_plugin_server/README.md). diff --git a/examples/plugins/example_plugin_server/.python-version b/examples/plugins/example_plugin_server/.python-version new file mode 100644 index 0000000000..2c0733315e --- /dev/null +++ b/examples/plugins/example_plugin_server/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/examples/plugins/example_plugin_server/README.md b/examples/plugins/example_plugin_server/README.md new file mode 100644 index 0000000000..840184d593 --- /dev/null +++ b/examples/plugins/example_plugin_server/README.md @@ -0,0 +1,30 @@ +## Overview + +If you wish to hook up your own plugin server through `dstack` builtin `rest_plugin`, here's a basic example on how to do so. + +## Steps + + +1. Install required dependencies for the plugin server: + + ```bash + uv sync + ``` + +1. Start the plugin server locally: + + ```bash + fastapi dev app/main.py + ``` + +1. Enable `rest_plugin` in `server/config.yaml`: + + ```yaml + plugins: + - rest_plugin + ``` + +1. Point the `dstack` server to your plugin server: + ```bash + export DSTACK_PLUGIN_SERVICE_URI=http://127.0.0.1:8000 + ``` diff --git a/examples/plugins/example_plugin_server/app/__init__.py b/examples/plugins/example_plugin_server/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/plugins/example_plugin_server/app/main.py b/examples/plugins/example_plugin_server/app/main.py new file mode 100644 index 0000000000..f20d6090cc --- /dev/null +++ b/examples/plugins/example_plugin_server/app/main.py @@ -0,0 +1,56 @@ +import logging + +from fastapi import FastAPI + +from app.utils import configure_logging +from dstack.plugins.builtin.rest_plugin import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) + +configure_logging() +logger = logging.getLogger(__name__) + +app = FastAPI() + + +@app.post("/apply_policies/on_run_apply") +async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse: + logger.info( + f"Received run spec request from user {request.user} and project {request.project}" + ) + response = RunSpecResponse(spec=request.spec, error=None) + return response + + +@app.post("/apply_policies/on_fleet_apply") +async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse: + logger.info( + f"Received fleet spec request from user {request.user} and project {request.project}" + ) + response = FleetSpecResponse(request.spec, error=None) + return response + + +@app.post("/apply_policies/on_volume_apply") +async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse: + logger.info( + f"Received volume spec request from user {request.user} and project {request.project}" + ) + response = VolumeSpecResponse(request.spec, error=None) + return response + + +@app.post("/apply_policies/on_gateway_apply") +async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse: + logger.info( + f"Received gateway spec request from user {request.user} and project {request.project}" + ) + response = GatewaySpecResponse(request.spec, error=None) + return response diff --git a/examples/plugins/example_plugin_server/app/utils.py b/examples/plugins/example_plugin_server/app/utils.py new file mode 100644 index 0000000000..b07406682f --- /dev/null +++ b/examples/plugins/example_plugin_server/app/utils.py @@ -0,0 +1,7 @@ +import logging +import os + + +def configure_logging(): + log_level = os.getenv("LOG_LEVEL", "INFO").upper() + logging.basicConfig(level=log_level) diff --git a/examples/plugins/example_plugin_server/pyproject.toml b/examples/plugins/example_plugin_server/pyproject.toml new file mode 100644 index 0000000000..abf5c63dba --- /dev/null +++ b/examples/plugins/example_plugin_server/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dstack-plugin-server" +version = "0.1.0" +description = "Example plugin server" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "fastapi[standard]>=0.115.12", + "dstack>=0.19.8" +] diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py index a8e5be8a05..99699ef731 100644 --- a/src/dstack/_internal/server/services/plugins.py +++ b/src/dstack/_internal/server/services/plugins.py @@ -1,5 +1,6 @@ import itertools from importlib import import_module +from typing import Dict from backports.entry_points_selectable import entry_points # backport for Python 3.9 @@ -12,50 +13,80 @@ _PLUGINS: list[Plugin] = [] +_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"} -def load_plugins(enabled_plugins: list[str]): - _PLUGINS.clear() - plugins_entrypoints = entry_points(group="dstack.plugins") - plugins_to_load = enabled_plugins.copy() - for entrypoint in plugins_entrypoints: - if entrypoint.name not in enabled_plugins: - logger.info( - ("Found not enabled plugin %s. Plugin will not be loaded."), - entrypoint.name, - ) - continue + +class PluginEntrypoint: + def __init__(self, name: str, import_path: str, is_builtin: bool = False): + self.name = name + self.import_path = import_path + self.is_builtin = is_builtin + + def load(self): + module_path, _, class_name = self.import_path.partition(":") try: - module_path, _, class_name = entrypoint.value.partition(":") module = import_module(module_path) + plugin_class = getattr(module, class_name, None) + if plugin_class is None: + logger.warning( + ("Failed to load plugin %s: plugin class %s not found in module %s."), + self.name, + class_name, + module_path, + ) + return None + if not issubclass(plugin_class, Plugin): + logger.warning( + ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), + self.name, + class_name, + ) + return None + return plugin_class() except ImportError: logger.warning( ( "Failed to load plugin %s when importing %s." " Ensure the module is on the import path." ), - entrypoint.name, - entrypoint.value, + self.name, + self.import_path, ) - continue - plugin_class = getattr(module, class_name, None) - if plugin_class is None: - logger.warning( - ("Failed to load plugin %s: plugin class %s not found in module %s."), + return None + + +def load_plugins(enabled_plugins: list[str]): + _PLUGINS.clear() + entrypoints: dict[str, PluginEntrypoint] = {} + plugins_to_load = enabled_plugins.copy() + for entrypoint in entry_points(group="dstack.plugins"): + if entrypoint.name not in enabled_plugins: + logger.info( + ("Found not enabled plugin %s. Plugin will not be loaded."), entrypoint.name, - class_name, - module_path, ) continue - if not issubclass(plugin_class, Plugin): - logger.warning( - ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), - entrypoint.name, - class_name, + else: + entrypoints[entrypoint.name] = PluginEntrypoint( + entrypoint.name, entrypoint.value, is_builtin=False ) - continue - plugins_to_load.remove(entrypoint.name) - _PLUGINS.append(plugin_class()) - logger.info("Loaded plugin %s", entrypoint.name) + + for name, import_path in _BUILTIN_PLUGINS.items(): + if name not in enabled_plugins: + logger.info( + ("Found not enabled builtin plugin %s. Plugin will not be loaded."), + name, + ) + else: + entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True) + + for plugin_name, plugin_entrypoint in entrypoints.items(): + plugin_instance = plugin_entrypoint.load() + if plugin_instance is not None: + _PLUGINS.append(plugin_instance) + plugins_to_load.remove(plugin_name) + logger.info("Loaded plugin %s", plugin_name) + if plugins_to_load: logger.warning("Enabled plugins not found: %s", plugins_to_load) diff --git a/src/dstack/plugins/builtin/__init__.py b/src/dstack/plugins/builtin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/plugins/builtin/rest_plugin/__init__.py b/src/dstack/plugins/builtin/rest_plugin/__init__.py new file mode 100644 index 0000000000..4d5e0fe14a --- /dev/null +++ b/src/dstack/plugins/builtin/rest_plugin/__init__.py @@ -0,0 +1,18 @@ +# ruff: noqa: F401 +from dstack.plugins.builtin.rest_plugin._models import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + SpecApplyRequest, + SpecApplyResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) +from dstack.plugins.builtin.rest_plugin._plugin import ( + PLUGIN_SERVICE_URI_ENV_VAR_NAME, + CustomApplyPolicy, + RESTPlugin, +) diff --git a/src/dstack/plugins/builtin/rest_plugin/_models.py b/src/dstack/plugins/builtin/rest_plugin/_models.py new file mode 100644 index 0000000000..ee3a042464 --- /dev/null +++ b/src/dstack/plugins/builtin/rest_plugin/_models.py @@ -0,0 +1,48 @@ +from typing import Generic, Optional, TypeVar + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec + +SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) + + +class SpecApplyRequest(BaseModel, Generic[SpecType]): + user: Annotated[str, Field(description="The name of the user making the apply request")] + project: Annotated[str, Field(description="The name of the project the request is for")] + spec: Annotated[SpecType, Field(description="The spec to be applied")] + + # Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable" + # error. This issue doesn't happen though when running the code in pytest, only when running the server. + def dict(self, *args, **kwargs): + d = super().dict(*args, **kwargs) + d.pop("__orig_class__", None) + return d + + +RunSpecRequest = SpecApplyRequest[RunSpec] +FleetSpecRequest = SpecApplyRequest[FleetSpec] +VolumeSpecRequest = SpecApplyRequest[VolumeSpec] +GatewaySpecRequest = SpecApplyRequest[GatewaySpec] + + +class SpecApplyResponse(BaseModel, Generic[SpecType]): + spec: Annotated[ + SpecType, + Field( + description="The spec to apply, original spec if error otherwise original or mutated by plugin service if approved" + ), + ] + error: Annotated[ + Optional[str], Field(description="Error message if request is rejected", min_length=1) + ] = None + + +RunSpecResponse = SpecApplyResponse[RunSpec] +FleetSpecResponse = SpecApplyResponse[FleetSpec] +VolumeSpecResponse = SpecApplyResponse[VolumeSpec] +GatewaySpecResponse = SpecApplyResponse[GatewaySpec] diff --git a/src/dstack/plugins/builtin/rest_plugin/_plugin.py b/src/dstack/plugins/builtin/rest_plugin/_plugin.py new file mode 100644 index 0000000000..48735d15ff --- /dev/null +++ b/src/dstack/plugins/builtin/rest_plugin/_plugin.py @@ -0,0 +1,127 @@ +import json +import os +from typing import Type + +import requests +from pydantic import ValidationError + +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack.plugins import ApplyPolicy, ApplySpec, Plugin, RunSpec, get_plugin_logger +from dstack.plugins.builtin.rest_plugin import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + SpecApplyRequest, + SpecApplyResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) + +logger = get_plugin_logger(__name__) + +PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" +PLUGIN_REQUEST_TIMEOUT_SEC = 8 + + +class CustomApplyPolicy(ApplyPolicy): + def __init__(self): + self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME) + logger.info(f"Found plugin service at {self._plugin_service_uri}") + if not self._plugin_service_uri: + logger.error( + f"Cannot create policy because {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set" + ) + raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") + + def _check_request_rejected(self, response: SpecApplyResponse): + if response.error is not None: + logger.error(f"Plugin service rejected apply request: {response.error}") + raise ServerClientError(f"Apply request rejected: {response.error}") + + def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) -> ApplySpec: + response = None + try: + response = requests.post( + f"{self._plugin_service_uri}{endpoint}", + json=spec_request.dict(), + headers={"accept": "application/json", "Content-Type": "application/json"}, + timeout=PLUGIN_REQUEST_TIMEOUT_SEC, + ) + response.raise_for_status() + spec_json = json.loads(response.text) + return spec_json + except requests.exceptions.ConnectionError as e: + logger.error( + f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e + ) + raise ServerClientError( + f"Could not connect to plugin service at {self._plugin_service_uri}" + ) + except requests.RequestException as e: + logger.error("Request to the plugin service failed: %s", e) + raise ServerClientError("Request to the plugin service failed") + + def _on_apply( + self, + request_cls: Type[SpecApplyRequest], + response_cls: Type[SpecApplyResponse], + endpoint: str, + user: str, + project: str, + spec: ApplySpec, + ) -> ApplySpec: + try: + spec_request = request_cls(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, endpoint) + response = response_cls(**spec_json) + self._check_request_rejected(response) + return response.spec + except ValidationError: + logger.error(f"Plugin service returned invalid response:\n{spec_json}") + raise ServerClientError("Plugin service returned an invalid response") + + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + return self._on_apply( + RunSpecRequest, RunSpecResponse, "/apply_policies/on_run_apply", user, project, spec + ) + + def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: + return self._on_apply( + FleetSpecRequest, + FleetSpecResponse, + "/apply_policies/on_fleet_apply", + user, + project, + spec, + ) + + def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: + return self._on_apply( + VolumeSpecRequest, + VolumeSpecResponse, + "/apply_policies/on_volume_apply", + user, + project, + spec, + ) + + def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: + return self._on_apply( + GatewaySpecRequest, + GatewaySpecResponse, + "/apply_policies/on_gateway_apply", + user, + project, + spec, + ) + + +class RESTPlugin(Plugin): + def get_apply_policies(self) -> list[ApplyPolicy]: + return [CustomApplyPolicy()] diff --git a/src/tests/_internal/server/services/test_plugins.py b/src/tests/_internal/server/services/test_plugins.py index 460764c4d6..ca5f0bfac6 100644 --- a/src/tests/_internal/server/services/test_plugins.py +++ b/src/tests/_internal/server/services/test_plugins.py @@ -1,4 +1,5 @@ import logging +from importlib import import_module from importlib.metadata import EntryPoint from unittest.mock import MagicMock, patch @@ -6,6 +7,7 @@ from dstack._internal.server.services.plugins import _PLUGINS, load_plugins from dstack.plugins import Plugin +from dstack.plugins.builtin.rest_plugin import RESTPlugin class DummyPlugin1(Plugin): @@ -30,55 +32,105 @@ def clear_plugins(): class TestLoadPlugins: @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") - def test_load_single_plugin(self, mock_import_module, mock_entry_points, caplog): + @pytest.mark.parametrize( + ["plugin_name", "plugin_module_path", "plugin_class"], + [ + ("plugin1", "dummy.plugins", DummyPlugin1), + ("rest_plugin", "dstack.plugins.builtin.rest_plugin", RESTPlugin), + ], + ) + def test_load_single_plugin( + self, + mock_import_module, + mock_entry_points, + caplog, + plugin_name, + plugin_module_path, + plugin_class, + ): mock_entry_points.return_value = [ EntryPoint( - name="plugin1", - value="dummy.plugins:DummyPlugin1", + name=plugin_name, + value=f"{plugin_module_path}:{plugin_class.__name__}", group="dstack.plugins", ) ] mock_module = MagicMock() - mock_module.DummyPlugin1 = DummyPlugin1 - mock_import_module.return_value = mock_module + setattr(mock_module, plugin_class.__name__, plugin_class) + # if it's a built-in plugin, do the real import + mock_import_module.side_effect = ( + lambda module_path: import_module(module_path) + if module_path.startswith("dstack.plugins.builtin") + else mock_module + ) with caplog.at_level(logging.INFO): - load_plugins(["plugin1"]) + load_plugins([plugin_name]) assert len(_PLUGINS) == 1 - assert isinstance(_PLUGINS[0], DummyPlugin1) + assert isinstance(_PLUGINS[0], plugin_class) mock_entry_points.assert_called_once_with(group="dstack.plugins") - mock_import_module.assert_called_once_with("dummy.plugins") - assert "Loaded plugin plugin1" in caplog.text + mock_import_module.assert_called_once_with(plugin_module_path) + assert f"Loaded plugin {plugin_name}" in caplog.text @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") - def test_load_multiple_plugins(self, mock_import_module, mock_entry_points, caplog): - mock_entry_points.return_value = [ - EntryPoint( - name="plugin1", - value="dummy.plugins:DummyPlugin1", - group="dstack.plugins", + @pytest.mark.parametrize( + ["plugin_names", "plugin_module_paths", "plugin_classes"], + [ + ( + ["plugin1", "plugin2"], + ["dummy.plugins", "dummy.plugins"], + [DummyPlugin1, DummyPlugin2], + ), + ( + ["plugin1", "plugin2", "rest_plugin"], + ["dummy.plugins", "dummy.plugins", "dstack.plugins.builtin.rest_plugin"], + [DummyPlugin1, DummyPlugin2, RESTPlugin], ), + ], + ids=["multiple_plugins_without_builtin_plugin", "multiple_plugins_with_builtin_plugin"], + ) + def test_load_multiple_plugins( + self, + mock_import_module, + mock_entry_points, + caplog, + plugin_names, + plugin_module_paths, + plugin_classes, + ): + mock_entry_points.return_value = [ EntryPoint( - name="plugin2", - value="dummy.plugins:DummyPlugin2", + name=plugin_name, + value=f"{plugin_module_path}:{plugin_class.__name__}", group="dstack.plugins", - ), + ) + for plugin_name, plugin_module_path, plugin_class in zip( + plugin_names, plugin_module_paths, plugin_classes + ) ] mock_module = MagicMock() - mock_module.DummyPlugin1 = DummyPlugin1 - mock_module.DummyPlugin2 = DummyPlugin2 - mock_import_module.return_value = mock_module + + for plugin_class, plugin_module_path in zip(plugin_classes, plugin_module_paths): + if not plugin_module_path.startswith("dstack.plugins.builtin"): + setattr(mock_module, plugin_class.__name__, plugin_class) + + mock_import_module.side_effect = ( + lambda module_path: import_module(module_path) + if module_path.startswith("dstack.plugins.builtin") + else mock_module + ) with caplog.at_level(logging.INFO): - load_plugins(["plugin1", "plugin2"]) + load_plugins(plugin_names) + + assert len(_PLUGINS) == len(plugin_names) + for i, plugin_class in enumerate(plugin_classes): + assert isinstance(_PLUGINS[i], plugin_class) - assert len(_PLUGINS) == 2 - assert isinstance(_PLUGINS[0], DummyPlugin1) - assert isinstance(_PLUGINS[1], DummyPlugin2) - assert "Loaded plugin plugin1" in caplog.text - assert "Loaded plugin plugin2" in caplog.text + for plugin_name in plugin_names: + assert f"Loaded plugin {plugin_name}" in caplog.text @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 9b0ef039a6..202d59058f 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -2,6 +2,8 @@ import pytest +from dstack._internal.server.testing.conf import postgres_container, session, test_db # noqa: F401 + def pytest_configure(config): config.addinivalue_line("markers", "ui: mark test as testing UI to run only with --runui") diff --git a/src/tests/plugins/__init__.py b/src/tests/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py new file mode 100644 index 0000000000..0b725ff95b --- /dev/null +++ b/src/tests/plugins/test_rest_plugin.py @@ -0,0 +1,207 @@ +import json +import os +from contextlib import nullcontext as does_not_raise +from unittest import mock +from unittest.mock import Mock + +import pytest +import pytest_asyncio +import requests +from pydantic import parse_obj_as +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ServerClientError, ServerError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec +from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import Range +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack._internal.server.models import ProjectModel +from dstack._internal.server.services import encryption as encryption +from dstack._internal.server.testing.common import ( + create_project, + create_repo, + create_user, + get_fleet_spec, + get_run_spec, + get_volume_configuration, +) +from dstack.plugins.builtin.rest_plugin import PLUGIN_SERVICE_URI_ENV_VAR_NAME, CustomApplyPolicy + + +async def create_run_spec( + session: AsyncSession, + project: ProjectModel, + replicas: str = 1, +) -> RunSpec: + repo = await create_repo(session=session, project_id=project.id) + run_name = "test-run" + profile = Profile(name="test-profile") + spec = get_run_spec( + repo_id=repo.name, + run_name=run_name, + profile=profile, + configuration=ServiceConfiguration( + commands=["echo hello"], port=8000, replicas=parse_obj_as(Range[int], replicas) + ), + ) + return spec + + +async def create_fleet_spec(): + name = "test-fleet-spec" + fleet_conf = FleetConfiguration(name=name) + return get_fleet_spec(conf=fleet_conf) + + +async def create_volume_spec(): + return VolumeSpec(configuration=get_volume_configuration()) + + +async def create_gateway_spec(): + configuration = GatewayConfiguration( + name="test-gateway-config", + backend=BackendType.AWS, + region="us-central", + ) + return GatewaySpec(configuration=configuration) + + +@pytest_asyncio.fixture +async def project(session): + return await create_project(session=session) + + +@pytest_asyncio.fixture +async def user(session): + return await create_user(session=session) + + +@pytest_asyncio.fixture +async def spec(request, session, project): + if request.param == "run_spec": + return await create_run_spec(session, project) + elif request.param == "fleet_spec": + return await create_fleet_spec() + elif request.param == "volume_spec": + return await create_volume_spec() + elif request.param == "gateway_spec": + return await create_gateway_spec() + else: + raise ValueError(f"Unknown spec fixture: {request.param}") + + +class TestRESTPlugin: + @pytest.mark.asyncio + async def test_on_run_apply_plugin_service_uri_not_set(self): + with pytest.raises(ServerError): + CustomApplyPolicy() + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_apply_plugin_service_returns_mutated_spec( + self, test_db, user, project, spec + ): + policy = CustomApplyPolicy() + mock_response = Mock() + response_dict = {"spec": spec.dict(), "error": None} + + if isinstance(spec, (RunSpec, FleetSpec)): + response_dict["spec"]["profile"]["tags"] = {"env": "test", "team": "qa"} + else: + response_dict["spec"]["configuration_path"] = "/path/to/something" + + mock_response.text = json.dumps(response_dict) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + result = policy.on_apply(user=user.name, project=project.name, spec=spec) + assert result == type(spec)(**response_dict["spec"]) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_apply_plugin_service_call_fails(self, test_db, user, project, spec): + policy = CustomApplyPolicy() + with mock.patch("requests.post", side_effect=requests.RequestException("fail")): + with pytest.raises(ServerClientError): + policy.on_apply(user=user.name, project=project.name, spec=spec) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_apply_plugin_service_connection_fails(self, test_db, user, project, spec): + policy = CustomApplyPolicy() + with mock.patch( + "requests.post", side_effect=requests.ConnectionError("Failed to connect") + ): + with pytest.raises(ServerClientError): + policy.on_apply(user=user.name, project=project.name, spec=spec) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_apply_plugin_service_returns_invalid_spec( + self, test_db, user, project, spec + ): + policy = CustomApplyPolicy() + mock_response = Mock() + mock_response.text = json.dumps({"invalid-key": "abc"}) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + with pytest.raises(ServerClientError): + policy.on_apply(user.name, project=project.name, spec=spec) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + @pytest.mark.parametrize( + ("error", "expectation"), + [ + pytest.param(None, does_not_raise(), id="error_none"), + pytest.param( + "", + pytest.raises( + ServerClientError, match="Plugin service returned an invalid response" + ), + id="error_empty_str", + ), + pytest.param( + "validation failed", + pytest.raises( + ServerClientError, match="Apply request rejected: validation failed" + ), + id="error_non_empty_str", + ), + ], + ) + async def test_on_apply_plugin_service_error_handling( + self, test_db, user, project, spec, error, expectation + ): + policy = CustomApplyPolicy() + mock_response = Mock() + response_dict = {"spec": spec.dict(), "error": error} + mock_response.text = json.dumps(response_dict) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + with expectation: + result = policy.on_apply(user=user.name, project=project.name, spec=spec) + assert result == type(spec)(**response_dict["spec"])