From a756cf489c98bac0642080952c93d981cd0e6d9d Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Tue, 13 May 2025 14:26:10 -0400 Subject: [PATCH 1/9] Add REST plugin --- pyproject.toml | 6 ++ src/plugins/rest_plugin/README.md | 1 + src/plugins/rest_plugin/__init__.py | 0 src/plugins/rest_plugin/pyproject.toml | 14 +++ src/plugins/rest_plugin/src/__init__.py | 0 src/plugins/rest_plugin/src/rest_plugin.py | 63 +++++++++++++ src/tests/plugins/__init__.py | 0 src/tests/plugins/test_rest_plugin.py | 102 +++++++++++++++++++++ 8 files changed, 186 insertions(+) create mode 100644 src/plugins/rest_plugin/README.md create mode 100644 src/plugins/rest_plugin/__init__.py create mode 100644 src/plugins/rest_plugin/pyproject.toml create mode 100644 src/plugins/rest_plugin/src/__init__.py create mode 100644 src/plugins/rest_plugin/src/rest_plugin.py create mode 100644 src/tests/plugins/__init__.py create mode 100644 src/tests/plugins/test_rest_plugin.py diff --git a/pyproject.toml b/pyproject.toml index fac0183e28..c8b7dba702 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,9 @@ pattern = '\s*|]*>\s*|\s*|]*>\s*|\ replacement = '' ignore-case = true +[tool.uv.workspace] +members = ["src/plugins/rest_plugin"] + [dependency-groups] dev = [ "build>=1.2.2.post1", @@ -176,3 +179,6 @@ nebius = [ all = [ "dstack[gateway,server,aws,azure,gcp,datacrunch,kubernetes,lambda,nebius,oci]", ] + +[project.entry-points."dstack.plugins"] +rest_plugin = "plugins.rest_plugin.src.rest_plugin:RESTPlugin" diff --git a/src/plugins/rest_plugin/README.md b/src/plugins/rest_plugin/README.md new file mode 100644 index 0000000000..c2c8c48456 --- /dev/null +++ b/src/plugins/rest_plugin/README.md @@ -0,0 +1 @@ +[TODO] \ No newline at end of file diff --git a/src/plugins/rest_plugin/__init__.py b/src/plugins/rest_plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/plugins/rest_plugin/pyproject.toml b/src/plugins/rest_plugin/pyproject.toml new file mode 100644 index 0000000000..aee85ec7e2 --- /dev/null +++ b/src/plugins/rest_plugin/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "rest-plugin" +version = "0.1.0" +description = "A dstack plugin that enables validation and mutation of run specifications via REST API" +readme = "README.md" +requires-python = ">=3.9" +dependencies = [] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src"] diff --git a/src/plugins/rest_plugin/src/__init__.py b/src/plugins/rest_plugin/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/plugins/rest_plugin/src/rest_plugin.py b/src/plugins/rest_plugin/src/rest_plugin.py new file mode 100644 index 0000000000..2381590ad1 --- /dev/null +++ b/src/plugins/rest_plugin/src/rest_plugin.py @@ -0,0 +1,63 @@ +import json +import os +import pydantic +import requests +from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger +from dstack.plugins._models import ApplySpec + +logger = get_plugin_logger(__name__) + +PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" + +class PreApplyPolicy(ApplyPolicy): + def __init__(self): + self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME) + if not self._plugin_service_uri: + logger.error(f"Cannot create policy as {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") + raise ServerError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") + + def _call_plugin_service(self, user: str, project: str, spec: ApplySpec, endpoint: str) -> ApplySpec: + # Make request to plugin service with run params + params = { + "user": user, + "project": project, + "spec": spec.json() + } + response = None + try: + response = requests.post(f"{self._plugin_service_uri}/{endpoint}", json=json.dumps(params)) + response.raise_for_status() + spec_json = json.loads(response.text) + spec = RunSpec(**spec_json) + except requests.RequestException as e: + logger.error("Failed to call plugin service: %s", e) + if response: + logger.error(f"Error response from plugin service:\n{response.text}") + logger.info("Returning original run spec") + return spec + except pydantic.ValidationError as e: + # TODO: check response error status and report if plugin service rejected request as invalid + logger.exception(f"Plugin service returned invalid response:\n{response.text if response else None}") + logger.info("Returning original run spec") + return spec + logger.info(f"Using RunSpec from plugin service:\n{spec}") + return spec + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + return self._call_plugin_service(user, project, spec, '/runs/pre_apply') + + def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: + return self._call_plugin_service(user, project, spec, '/fleets/pre_apply') + + def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: + return self._call_plugin_service(user, project, spec, '/volumes/pre_apply') + + def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: + return self._call_plugin_service(user, project, spec, '/gateways/pre_apply') + +class RESTPlugin(Plugin): + def get_apply_policies(self) -> list[ApplyPolicy]: + return [PreApplyPolicy()] diff --git a/src/tests/plugins/__init__.py b/src/tests/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py new file mode 100644 index 0000000000..79beb57088 --- /dev/null +++ b/src/tests/plugins/test_rest_plugin.py @@ -0,0 +1,102 @@ +from dstack._internal.core.errors import ServerError +from dstack._internal.server.models import ProjectModel, UserModel +from plugins.rest_plugin.src.rest_plugin import PreApplyPolicy, PLUGIN_SERVICE_URI_ENV_VAR_NAME +import pytest +from sqlalchemy.ext.asyncio import AsyncSession +from pydantic import parse_obj_as +import os +import json +import requests +from unittest.mock import Mock + +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import Range +from dstack._internal.server.testing.common import ( + create_project, + create_user, + create_repo, + get_run_spec, +) +from dstack._internal.server.testing.conf import session, test_db # noqa: F401 +from dstack._internal.server.services import encryption as encryption # import for side-effect +import pytest_asyncio +from unittest import mock + + +async def create_run_spec( + session: AsyncSession, + project: ProjectModel, + replicas: str = 1, +) -> RunSpec: + repo = await create_repo(session=session, project_id=project.id) + run_name = "test-run" + profile = Profile(name="test-profile") + spec = get_run_spec( + repo_id=repo.name, + run_name=run_name, + profile=profile, + configuration=ServiceConfiguration( + commands=["echo hello"], + port=8000, + replicas=parse_obj_as(Range[int], replicas) + ), + ) + return spec + +@pytest_asyncio.fixture +async def project(session): + return await create_project(session=session) + +@pytest_asyncio.fixture +async def user(session): + return await create_user(session=session) + +@pytest_asyncio.fixture +async def run_spec(session, project): + return await create_run_spec(session=session, project=project) + + +class TestRESTPlugin: + @pytest.mark.asyncio + async def test_on_run_apply_plugin_service_uri_not_set(self): + with pytest.raises(ServerError): + policy = PreApplyPolicy() + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_on_run_apply_plugin_service_returns_mutated_spec(self, test_db, user, project, run_spec): + policy = PreApplyPolicy() + mock_response = Mock() + run_spec_dict = run_spec.dict() + run_spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"} + mock_response.text = json.dumps(run_spec_dict) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + result = policy.on_apply(user=user.name, project=project.name, spec=run_spec) + assert result == RunSpec(**run_spec_dict) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, run_spec): + policy = PreApplyPolicy() + with mock.patch("requests.post", side_effect=requests.RequestException("fail")): + result = policy.on_apply(user=user.name, project=project.name, spec=run_spec) + assert result == run_spec + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_on_run_apply_plugin_service_returns_invalid_spec(self, test_db, user, project, run_spec): + policy = PreApplyPolicy() + mock_response = Mock() + mock_response.text = json.dumps({"invalid-key": "abc"}) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + result = policy.on_apply(user.name, project=project.name, spec=run_spec) + # return original run spec + assert result == run_spec + \ No newline at end of file From fa46f3a4fe1dd632f02039b8cfaf78a6bf24b23b Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Tue, 20 May 2025 11:16:26 -0400 Subject: [PATCH 2/9] Change rest-plugin to a builtin plugin --- pyproject.toml | 6 - .../_internal/server/services/plugins.py | 93 ++++++++---- src/dstack/plugins/builtin/rest_plugin.py | 97 ++++++++++++ src/plugins/rest_plugin/README.md | 1 - src/plugins/rest_plugin/__init__.py | 0 src/plugins/rest_plugin/pyproject.toml | 14 -- src/plugins/rest_plugin/src/__init__.py | 0 src/plugins/rest_plugin/src/rest_plugin.py | 63 -------- .../_internal/server/services/test_plugins.py | 106 +++++++++---- src/tests/conftest.py | 2 + src/tests/plugins/test_rest_plugin.py | 142 +++++++++++++----- 11 files changed, 345 insertions(+), 179 deletions(-) create mode 100644 src/dstack/plugins/builtin/rest_plugin.py delete mode 100644 src/plugins/rest_plugin/README.md delete mode 100644 src/plugins/rest_plugin/__init__.py delete mode 100644 src/plugins/rest_plugin/pyproject.toml delete mode 100644 src/plugins/rest_plugin/src/__init__.py delete mode 100644 src/plugins/rest_plugin/src/rest_plugin.py diff --git a/pyproject.toml b/pyproject.toml index c8b7dba702..fac0183e28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,6 @@ pattern = '\s*|]*>\s*|\s*|]*>\s*|\ replacement = '' ignore-case = true -[tool.uv.workspace] -members = ["src/plugins/rest_plugin"] - [dependency-groups] dev = [ "build>=1.2.2.post1", @@ -179,6 +176,3 @@ nebius = [ all = [ "dstack[gateway,server,aws,azure,gcp,datacrunch,kubernetes,lambda,nebius,oci]", ] - -[project.entry-points."dstack.plugins"] -rest_plugin = "plugins.rest_plugin.src.rest_plugin:RESTPlugin" diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py index a8e5be8a05..a9db27cfb1 100644 --- a/src/dstack/_internal/server/services/plugins.py +++ b/src/dstack/_internal/server/services/plugins.py @@ -1,5 +1,6 @@ import itertools from importlib import import_module +from typing import Dict from backports.entry_points_selectable import entry_points # backport for Python 3.9 @@ -12,50 +13,80 @@ _PLUGINS: list[Plugin] = [] +_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"} -def load_plugins(enabled_plugins: list[str]): - _PLUGINS.clear() - plugins_entrypoints = entry_points(group="dstack.plugins") - plugins_to_load = enabled_plugins.copy() - for entrypoint in plugins_entrypoints: - if entrypoint.name not in enabled_plugins: - logger.info( - ("Found not enabled plugin %s. Plugin will not be loaded."), - entrypoint.name, - ) - continue + +class PluginEntrypoint: + def __init__(self, name: str, import_path: str, is_builtin: bool = False): + self.name = name + self.import_path = import_path + self.is_builtin = is_builtin + + def load(self): + module_path, _, class_name = self.import_path.partition(":") try: - module_path, _, class_name = entrypoint.value.partition(":") module = import_module(module_path) + plugin_class = getattr(module, class_name, None) + if plugin_class is None: + logger.warning( + ("Failed to load plugin %s: plugin class %s not found in module %s."), + self.name, + class_name, + module_path, + ) + return None + if not issubclass(plugin_class, Plugin): + logger.warning( + ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), + self.name, + class_name, + ) + return None + return plugin_class() except ImportError: logger.warning( ( "Failed to load plugin %s when importing %s." " Ensure the module is on the import path." ), - entrypoint.name, - entrypoint.value, + self.name, + self.import_path, ) - continue - plugin_class = getattr(module, class_name, None) - if plugin_class is None: - logger.warning( - ("Failed to load plugin %s: plugin class %s not found in module %s."), + return None + + +def load_plugins(enabled_plugins: list[str]): + _PLUGINS.clear() + entrypoints: dict[str, PluginEntrypoint] = {} + plugins_to_load = enabled_plugins.copy() + for entrypoint in entry_points(group="dstack.plugins"): + if entrypoint.name not in enabled_plugins: + logger.info( + ("Found not enabled plugin %s. Plugin will not be loaded."), entrypoint.name, - class_name, - module_path, ) continue - if not issubclass(plugin_class, Plugin): - logger.warning( - ("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."), - entrypoint.name, - class_name, + else: + entrypoints[entrypoint.name] = PluginEntrypoint( + entrypoint.name, entrypoint.value, is_builtin=False ) - continue - plugins_to_load.remove(entrypoint.name) - _PLUGINS.append(plugin_class()) - logger.info("Loaded plugin %s", entrypoint.name) + + for name, import_path in _BUILTIN_PLUGINS.items(): + if name not in enabled_plugins: + logger.info( + ("Found not enabled builtin plugin %s. Plugin will not be loaded."), + name, + ) + else: + entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True) + + for plugin_name, plugin_entrypoint in entrypoints.items(): + plugin_instance = plugin_entrypoint.load() + if plugin_instance is not None: + _PLUGINS.append(plugin_instance) + plugins_to_load.remove(plugin_name) + logger.info("Loaded plugin %s", plugin_name) + if plugins_to_load: logger.warning("Enabled plugins not found: %s", plugins_to_load) @@ -65,7 +96,7 @@ def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec for policy in policies: try: spec = policy.on_apply(user=user, project=project, spec=spec) - except ValueError as e: + except Exception as e: msg = None if len(e.args) > 0: msg = e.args[0] diff --git a/src/dstack/plugins/builtin/rest_plugin.py b/src/dstack/plugins/builtin/rest_plugin.py new file mode 100644 index 0000000000..d246d0e367 --- /dev/null +++ b/src/dstack/plugins/builtin/rest_plugin.py @@ -0,0 +1,97 @@ +import json +import os +from typing import Generic, TypeVar + +import requests +from pydantic import BaseModel, ValidationError + +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger +from dstack.plugins._models import ApplySpec + +logger = get_plugin_logger(__name__) + +PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" +PLUGIN_REQUEST_TIMEOUT = 8 # in seconds + +SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) + + +class SpecRequest(BaseModel, Generic[SpecType]): + user: str + project: str + spec: SpecType + + +RunSpecRequest = SpecRequest[RunSpec] +FleetSpecRequest = SpecRequest[FleetSpec] +VolumeSpecRequest = SpecRequest[VolumeSpec] +GatewaySpecRequest = SpecRequest[GatewaySpec] + + +class CustomApplyPolicy(ApplyPolicy): + def __init__(self): + self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME) + logger.info(f"Found plugin service at {self._plugin_service_uri}") + if not self._plugin_service_uri: + logger.error( + f"Cannot create policy because {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set" + ) + raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") + + def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> ApplySpec: + response = None + try: + response = requests.post( + f"{self._plugin_service_uri}{endpoint}", + json=spec_request.dict(), + headers={"accept": "application/json", "Content-Type": "application/json"}, + timeout=PLUGIN_REQUEST_TIMEOUT, + ) + response.raise_for_status() + spec_json = json.loads(response.text) + return spec_json + except requests.exceptions.ConnectionError as e: + logger.error( + f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e + ) + raise e + except requests.RequestException as e: + logger.error("Request to the plugin service failed: %s", e) + if response: + logger.error(f"Error response from plugin service:\n{response.text}") + raise e + except ValidationError as e: + # Received 200 code but response body is invalid + logger.exception( + f"Plugin service returned invalid response:\n{response.text if response else None}" + ) + raise e + + def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: + spec_request = RunSpecRequest(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_run_apply") + return RunSpec(**spec_json) + + def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: + spec_request = FleetSpecRequest(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_fleet_apply") + return FleetSpec(**spec_json) + + def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: + spec_request = VolumeSpecRequest(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_volume_apply") + return VolumeSpec(**spec_json) + + def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: + spec_request = GatewaySpecRequest(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_gateway_apply") + return GatewaySpec(**spec_json) + + +class RESTPlugin(Plugin): + def get_apply_policies(self) -> list[ApplyPolicy]: + return [CustomApplyPolicy()] diff --git a/src/plugins/rest_plugin/README.md b/src/plugins/rest_plugin/README.md deleted file mode 100644 index c2c8c48456..0000000000 --- a/src/plugins/rest_plugin/README.md +++ /dev/null @@ -1 +0,0 @@ -[TODO] \ No newline at end of file diff --git a/src/plugins/rest_plugin/__init__.py b/src/plugins/rest_plugin/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/plugins/rest_plugin/pyproject.toml b/src/plugins/rest_plugin/pyproject.toml deleted file mode 100644 index aee85ec7e2..0000000000 --- a/src/plugins/rest_plugin/pyproject.toml +++ /dev/null @@ -1,14 +0,0 @@ -[project] -name = "rest-plugin" -version = "0.1.0" -description = "A dstack plugin that enables validation and mutation of run specifications via REST API" -readme = "README.md" -requires-python = ">=3.9" -dependencies = [] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["src"] diff --git a/src/plugins/rest_plugin/src/__init__.py b/src/plugins/rest_plugin/src/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/plugins/rest_plugin/src/rest_plugin.py b/src/plugins/rest_plugin/src/rest_plugin.py deleted file mode 100644 index 2381590ad1..0000000000 --- a/src/plugins/rest_plugin/src/rest_plugin.py +++ /dev/null @@ -1,63 +0,0 @@ -import json -import os -import pydantic -import requests -from dstack._internal.core.errors import ServerError -from dstack._internal.core.models.fleets import FleetSpec -from dstack._internal.core.models.gateways import GatewaySpec -from dstack._internal.core.models.volumes import VolumeSpec -from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger -from dstack.plugins._models import ApplySpec - -logger = get_plugin_logger(__name__) - -PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" - -class PreApplyPolicy(ApplyPolicy): - def __init__(self): - self._plugin_service_uri = os.getenv(PLUGIN_SERVICE_URI_ENV_VAR_NAME) - if not self._plugin_service_uri: - logger.error(f"Cannot create policy as {PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") - raise ServerError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") - - def _call_plugin_service(self, user: str, project: str, spec: ApplySpec, endpoint: str) -> ApplySpec: - # Make request to plugin service with run params - params = { - "user": user, - "project": project, - "spec": spec.json() - } - response = None - try: - response = requests.post(f"{self._plugin_service_uri}/{endpoint}", json=json.dumps(params)) - response.raise_for_status() - spec_json = json.loads(response.text) - spec = RunSpec(**spec_json) - except requests.RequestException as e: - logger.error("Failed to call plugin service: %s", e) - if response: - logger.error(f"Error response from plugin service:\n{response.text}") - logger.info("Returning original run spec") - return spec - except pydantic.ValidationError as e: - # TODO: check response error status and report if plugin service rejected request as invalid - logger.exception(f"Plugin service returned invalid response:\n{response.text if response else None}") - logger.info("Returning original run spec") - return spec - logger.info(f"Using RunSpec from plugin service:\n{spec}") - return spec - def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: - return self._call_plugin_service(user, project, spec, '/runs/pre_apply') - - def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: - return self._call_plugin_service(user, project, spec, '/fleets/pre_apply') - - def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: - return self._call_plugin_service(user, project, spec, '/volumes/pre_apply') - - def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: - return self._call_plugin_service(user, project, spec, '/gateways/pre_apply') - -class RESTPlugin(Plugin): - def get_apply_policies(self) -> list[ApplyPolicy]: - return [PreApplyPolicy()] diff --git a/src/tests/_internal/server/services/test_plugins.py b/src/tests/_internal/server/services/test_plugins.py index 460764c4d6..ca5f0bfac6 100644 --- a/src/tests/_internal/server/services/test_plugins.py +++ b/src/tests/_internal/server/services/test_plugins.py @@ -1,4 +1,5 @@ import logging +from importlib import import_module from importlib.metadata import EntryPoint from unittest.mock import MagicMock, patch @@ -6,6 +7,7 @@ from dstack._internal.server.services.plugins import _PLUGINS, load_plugins from dstack.plugins import Plugin +from dstack.plugins.builtin.rest_plugin import RESTPlugin class DummyPlugin1(Plugin): @@ -30,55 +32,105 @@ def clear_plugins(): class TestLoadPlugins: @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") - def test_load_single_plugin(self, mock_import_module, mock_entry_points, caplog): + @pytest.mark.parametrize( + ["plugin_name", "plugin_module_path", "plugin_class"], + [ + ("plugin1", "dummy.plugins", DummyPlugin1), + ("rest_plugin", "dstack.plugins.builtin.rest_plugin", RESTPlugin), + ], + ) + def test_load_single_plugin( + self, + mock_import_module, + mock_entry_points, + caplog, + plugin_name, + plugin_module_path, + plugin_class, + ): mock_entry_points.return_value = [ EntryPoint( - name="plugin1", - value="dummy.plugins:DummyPlugin1", + name=plugin_name, + value=f"{plugin_module_path}:{plugin_class.__name__}", group="dstack.plugins", ) ] mock_module = MagicMock() - mock_module.DummyPlugin1 = DummyPlugin1 - mock_import_module.return_value = mock_module + setattr(mock_module, plugin_class.__name__, plugin_class) + # if it's a built-in plugin, do the real import + mock_import_module.side_effect = ( + lambda module_path: import_module(module_path) + if module_path.startswith("dstack.plugins.builtin") + else mock_module + ) with caplog.at_level(logging.INFO): - load_plugins(["plugin1"]) + load_plugins([plugin_name]) assert len(_PLUGINS) == 1 - assert isinstance(_PLUGINS[0], DummyPlugin1) + assert isinstance(_PLUGINS[0], plugin_class) mock_entry_points.assert_called_once_with(group="dstack.plugins") - mock_import_module.assert_called_once_with("dummy.plugins") - assert "Loaded plugin plugin1" in caplog.text + mock_import_module.assert_called_once_with(plugin_module_path) + assert f"Loaded plugin {plugin_name}" in caplog.text @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") - def test_load_multiple_plugins(self, mock_import_module, mock_entry_points, caplog): - mock_entry_points.return_value = [ - EntryPoint( - name="plugin1", - value="dummy.plugins:DummyPlugin1", - group="dstack.plugins", + @pytest.mark.parametrize( + ["plugin_names", "plugin_module_paths", "plugin_classes"], + [ + ( + ["plugin1", "plugin2"], + ["dummy.plugins", "dummy.plugins"], + [DummyPlugin1, DummyPlugin2], + ), + ( + ["plugin1", "plugin2", "rest_plugin"], + ["dummy.plugins", "dummy.plugins", "dstack.plugins.builtin.rest_plugin"], + [DummyPlugin1, DummyPlugin2, RESTPlugin], ), + ], + ids=["multiple_plugins_without_builtin_plugin", "multiple_plugins_with_builtin_plugin"], + ) + def test_load_multiple_plugins( + self, + mock_import_module, + mock_entry_points, + caplog, + plugin_names, + plugin_module_paths, + plugin_classes, + ): + mock_entry_points.return_value = [ EntryPoint( - name="plugin2", - value="dummy.plugins:DummyPlugin2", + name=plugin_name, + value=f"{plugin_module_path}:{plugin_class.__name__}", group="dstack.plugins", - ), + ) + for plugin_name, plugin_module_path, plugin_class in zip( + plugin_names, plugin_module_paths, plugin_classes + ) ] mock_module = MagicMock() - mock_module.DummyPlugin1 = DummyPlugin1 - mock_module.DummyPlugin2 = DummyPlugin2 - mock_import_module.return_value = mock_module + + for plugin_class, plugin_module_path in zip(plugin_classes, plugin_module_paths): + if not plugin_module_path.startswith("dstack.plugins.builtin"): + setattr(mock_module, plugin_class.__name__, plugin_class) + + mock_import_module.side_effect = ( + lambda module_path: import_module(module_path) + if module_path.startswith("dstack.plugins.builtin") + else mock_module + ) with caplog.at_level(logging.INFO): - load_plugins(["plugin1", "plugin2"]) + load_plugins(plugin_names) + + assert len(_PLUGINS) == len(plugin_names) + for i, plugin_class in enumerate(plugin_classes): + assert isinstance(_PLUGINS[i], plugin_class) - assert len(_PLUGINS) == 2 - assert isinstance(_PLUGINS[0], DummyPlugin1) - assert isinstance(_PLUGINS[1], DummyPlugin2) - assert "Loaded plugin plugin1" in caplog.text - assert "Loaded plugin plugin2" in caplog.text + for plugin_name in plugin_names: + assert f"Loaded plugin {plugin_name}" in caplog.text @patch("dstack._internal.server.services.plugins.entry_points") @patch("dstack._internal.server.services.plugins.import_module") diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 9b0ef039a6..28d2e010a9 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -2,6 +2,8 @@ import pytest +from dstack._internal.server.testing.conf import session, test_db # noqa: F401 + def pytest_configure(config): config.addinivalue_line("markers", "ui: mark test as testing UI to run only with --runui") diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py index 79beb57088..979ccdc55d 100644 --- a/src/tests/plugins/test_rest_plugin.py +++ b/src/tests/plugins/test_rest_plugin.py @@ -1,28 +1,35 @@ -from dstack._internal.core.errors import ServerError -from dstack._internal.server.models import ProjectModel, UserModel -from plugins.rest_plugin.src.rest_plugin import PreApplyPolicy, PLUGIN_SERVICE_URI_ENV_VAR_NAME -import pytest -from sqlalchemy.ext.asyncio import AsyncSession -from pydantic import parse_obj_as -import os import json -import requests +import os +from unittest import mock from unittest.mock import Mock -from dstack._internal.core.models.runs import RunSpec +import pydantic +import pytest +import pytest_asyncio +import requests +from pydantic import parse_obj_as +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec +from dstack._internal.core.models.gateways import GatewayConfiguration, GatewaySpec from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import Range +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec +from dstack._internal.server.models import ProjectModel +from dstack._internal.server.services import encryption as encryption from dstack._internal.server.testing.common import ( create_project, - create_user, create_repo, + create_user, + get_fleet_spec, get_run_spec, + get_volume_configuration, ) -from dstack._internal.server.testing.conf import session, test_db # noqa: F401 -from dstack._internal.server.services import encryption as encryption # import for side-effect -import pytest_asyncio -from unittest import mock +from dstack.plugins.builtin.rest_plugin import PLUGIN_SERVICE_URI_ENV_VAR_NAME, CustomApplyPolicy async def create_run_spec( @@ -38,65 +45,126 @@ async def create_run_spec( run_name=run_name, profile=profile, configuration=ServiceConfiguration( - commands=["echo hello"], - port=8000, - replicas=parse_obj_as(Range[int], replicas) + commands=["echo hello"], port=8000, replicas=parse_obj_as(Range[int], replicas) ), ) return spec + +async def create_fleet_spec(): + name = "test-fleet-spec" + fleet_conf = FleetConfiguration(name=name) + return get_fleet_spec(conf=fleet_conf) + + +async def create_volume_spec(): + return VolumeSpec(configuration=get_volume_configuration()) + + +async def create_gateway_spec(): + configuration = GatewayConfiguration( + name="test-gateway-config", + backend=BackendType.AWS, + region="us-central", + ) + return GatewaySpec(configuration=configuration) + + @pytest_asyncio.fixture async def project(session): return await create_project(session=session) + @pytest_asyncio.fixture async def user(session): return await create_user(session=session) + @pytest_asyncio.fixture -async def run_spec(session, project): - return await create_run_spec(session=session, project=project) +async def spec(request, session, project): + if request.param == "run_spec": + return await create_run_spec(session, project) + elif request.param == "fleet_spec": + return await create_fleet_spec() + elif request.param == "volume_spec": + return await create_volume_spec() + elif request.param == "gateway_spec": + return await create_gateway_spec() + else: + raise ValueError(f"Unknown spec fixture: {request.param}") class TestRESTPlugin: @pytest.mark.asyncio async def test_on_run_apply_plugin_service_uri_not_set(self): with pytest.raises(ServerError): - policy = PreApplyPolicy() + CustomApplyPolicy() @pytest.mark.asyncio @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_on_run_apply_plugin_service_returns_mutated_spec(self, test_db, user, project, run_spec): - policy = PreApplyPolicy() + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_run_apply_plugin_service_returns_mutated_spec( + self, test_db, user, project, spec + ): + policy = CustomApplyPolicy() mock_response = Mock() - run_spec_dict = run_spec.dict() - run_spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"} - mock_response.text = json.dumps(run_spec_dict) + spec_dict = spec.dict() + + if isinstance(spec, (RunSpec, FleetSpec)): + spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"} + else: + spec_dict["configuration_path"] = "/path/to/something" + + mock_response.text = json.dumps(spec_dict) mock_response.raise_for_status = Mock() with mock.patch("requests.post", return_value=mock_response): - result = policy.on_apply(user=user.name, project=project.name, spec=run_spec) - assert result == RunSpec(**run_spec_dict) + result = policy.on_apply(user=user.name, project=project.name, spec=spec) + assert result == type(spec)(**spec_dict) @pytest.mark.asyncio @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, run_spec): - policy = PreApplyPolicy() + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, spec): + policy = CustomApplyPolicy() with mock.patch("requests.post", side_effect=requests.RequestException("fail")): - result = policy.on_apply(user=user.name, project=project.name, spec=run_spec) - assert result == run_spec + with pytest.raises(requests.RequestException): + policy.on_apply(user=user.name, project=project.name, spec=spec) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_run_apply_plugin_service_connection_fails( + self, test_db, user, project, spec + ): + policy = CustomApplyPolicy() + with mock.patch( + "requests.post", side_effect=requests.ConnectionError("Failed to connect") + ): + with pytest.raises(requests.ConnectionError): + policy.on_apply(user=user.name, project=project.name, spec=spec) @pytest.mark.asyncio @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_on_run_apply_plugin_service_returns_invalid_spec(self, test_db, user, project, run_spec): - policy = PreApplyPolicy() + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + async def test_on_run_apply_plugin_service_returns_invalid_spec( + self, test_db, user, project, spec + ): + policy = CustomApplyPolicy() mock_response = Mock() mock_response.text = json.dumps({"invalid-key": "abc"}) mock_response.raise_for_status = Mock() with mock.patch("requests.post", return_value=mock_response): - result = policy.on_apply(user.name, project=project.name, spec=run_spec) - # return original run spec - assert result == run_spec - \ No newline at end of file + with pytest.raises(pydantic.ValidationError): + policy.on_apply(user.name, project=project.name, spec=spec) From dc1796995ab0a2e58cae5e10c15ba4d8f01c9213 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Tue, 20 May 2025 12:39:56 -0400 Subject: [PATCH 3/9] Add plugin server example --- .../example_plugin_server/.python-version | 1 + .../plugins/example_plugin_server/README.md | 30 +++++++++++++ .../example_plugin_server/app/__init__.py | 0 .../plugins/example_plugin_server/app/main.py | 43 +++++++++++++++++++ .../example_plugin_server/app/models.py | 22 ++++++++++ .../example_plugin_server/app/utils.py | 7 +++ .../example_plugin_server/pyproject.toml | 10 +++++ 7 files changed, 113 insertions(+) create mode 100644 examples/plugins/example_plugin_server/.python-version create mode 100644 examples/plugins/example_plugin_server/README.md create mode 100644 examples/plugins/example_plugin_server/app/__init__.py create mode 100644 examples/plugins/example_plugin_server/app/main.py create mode 100644 examples/plugins/example_plugin_server/app/models.py create mode 100644 examples/plugins/example_plugin_server/app/utils.py create mode 100644 examples/plugins/example_plugin_server/pyproject.toml diff --git a/examples/plugins/example_plugin_server/.python-version b/examples/plugins/example_plugin_server/.python-version new file mode 100644 index 0000000000..2c0733315e --- /dev/null +++ b/examples/plugins/example_plugin_server/.python-version @@ -0,0 +1 @@ +3.11 diff --git a/examples/plugins/example_plugin_server/README.md b/examples/plugins/example_plugin_server/README.md new file mode 100644 index 0000000000..ce96dfd8f0 --- /dev/null +++ b/examples/plugins/example_plugin_server/README.md @@ -0,0 +1,30 @@ +## Overview + +If you wish to hook up your own plugin server through `dstack` builtin ` rest-plugin`, here's a basic example on how to do so. + +## Steps + + +1. Install required dependencies for the plugin server: + + ```bash + uv sync + ``` + +1. Start the plugin server locally: + + ```bash + fastapi dev app/main.py + ``` + +1. Enable `rest-plugin` in `dstack` `server/config.yaml`: + + ```yaml + plugins: + - rest_plugin + ``` + +1. Point the `dstack` server to your plugin server: + ```bash + export DSTACK_PLUGIN_SERVICE_URI=http://127.0.0.1:8000 + ``` diff --git a/examples/plugins/example_plugin_server/app/__init__.py b/examples/plugins/example_plugin_server/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/plugins/example_plugin_server/app/main.py b/examples/plugins/example_plugin_server/app/main.py new file mode 100644 index 0000000000..f00cd04429 --- /dev/null +++ b/examples/plugins/example_plugin_server/app/main.py @@ -0,0 +1,43 @@ +import logging + +from fastapi import FastAPI + +from app.models import FleetSpecRequest, GatewaySpecRequest, RunSpecRequest, VolumeSpecRequest +from app.utils import configure_logging + +configure_logging() +logger = logging.getLogger(__name__) + +app = FastAPI() + + +@app.post("/apply_policies/on_run_apply") +async def on_run_apply(request: RunSpecRequest): + logger.info( + f"Received run spec request from user {request.user} and project {request.project}" + ) + return request.spec + + +@app.post("/apply_policies/on_fleet_apply") +async def on_fleet_apply(request: FleetSpecRequest): + logger.info( + f"Received fleet spec request from user {request.user} and project {request.project}" + ) + return request.spec + + +@app.post("/apply_policies/on_volume_apply") +async def on_volume_apply(request: VolumeSpecRequest): + logger.info( + f"Received volume spec request from user {request.user} and project {request.project}" + ) + return request.spec + + +@app.post("/apply_policies/on_gateway_apply") +async def on_gateway_apply(request: GatewaySpecRequest): + logger.info( + f"Received gateway spec request from user {request.user} and project {request.project}" + ) + return request.spec diff --git a/examples/plugins/example_plugin_server/app/models.py b/examples/plugins/example_plugin_server/app/models.py new file mode 100644 index 0000000000..d15dc290d5 --- /dev/null +++ b/examples/plugins/example_plugin_server/app/models.py @@ -0,0 +1,22 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec + +SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) + + +class SpecRequest(BaseModel, Generic[SpecType]): + user: str + project: str + spec: SpecType + + +RunSpecRequest = SpecRequest[RunSpec] +FleetSpecRequest = SpecRequest[FleetSpec] +VolumeSpecRequest = SpecRequest[VolumeSpec] +GatewaySpecRequest = SpecRequest[GatewaySpec] diff --git a/examples/plugins/example_plugin_server/app/utils.py b/examples/plugins/example_plugin_server/app/utils.py new file mode 100644 index 0000000000..b07406682f --- /dev/null +++ b/examples/plugins/example_plugin_server/app/utils.py @@ -0,0 +1,7 @@ +import logging +import os + + +def configure_logging(): + log_level = os.getenv("LOG_LEVEL", "INFO").upper() + logging.basicConfig(level=log_level) diff --git a/examples/plugins/example_plugin_server/pyproject.toml b/examples/plugins/example_plugin_server/pyproject.toml new file mode 100644 index 0000000000..abf5c63dba --- /dev/null +++ b/examples/plugins/example_plugin_server/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dstack-plugin-server" +version = "0.1.0" +description = "Example plugin server" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "fastapi[standard]>=0.115.12", + "dstack>=0.19.8" +] From a4bacac531d9206339d7b95764768866c3a981b7 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Tue, 20 May 2025 14:56:41 -0400 Subject: [PATCH 4/9] Document rest-plugin in guides --- docs/docs/guides/plugins.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/docs/guides/plugins.md b/docs/docs/guides/plugins.md index ae8c287b95..2f662d553e 100644 --- a/docs/docs/guides/plugins.md +++ b/docs/docs/guides/plugins.md @@ -113,4 +113,9 @@ class ExamplePolicy(ApplyPolicy): +## Built-in Plugins + +### REST Plugin +If you'd like to apply custom policies within your organization, you can set up your own plugin API server and integrate it with `dstack` via the `rest-plugin`. To get started, check out the [plugijn server example](/examples/plugins/example_plugin_server/README.md). + For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin). From e757a77789072cf8c810e9c47e60bc1e9e396d84 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Tue, 20 May 2025 22:01:55 -0400 Subject: [PATCH 5/9] Refactor rest-plugin + polish models --- .../plugins/example_plugin_server/app/main.py | 31 ++++-- .../example_plugin_server/app/models.py | 22 ---- .../_internal/server/services/plugins.py | 2 +- src/dstack/plugins/builtin/__init__.py | 0 src/dstack/plugins/builtin/models.py | 38 +++++++ src/dstack/plugins/builtin/rest_plugin.py | 100 +++++++++++------- src/tests/plugins/test_rest_plugin.py | 26 +++-- 7 files changed, 134 insertions(+), 85 deletions(-) delete mode 100644 examples/plugins/example_plugin_server/app/models.py create mode 100644 src/dstack/plugins/builtin/__init__.py create mode 100644 src/dstack/plugins/builtin/models.py diff --git a/examples/plugins/example_plugin_server/app/main.py b/examples/plugins/example_plugin_server/app/main.py index f00cd04429..17456aedd6 100644 --- a/examples/plugins/example_plugin_server/app/main.py +++ b/examples/plugins/example_plugin_server/app/main.py @@ -2,8 +2,17 @@ from fastapi import FastAPI -from app.models import FleetSpecRequest, GatewaySpecRequest, RunSpecRequest, VolumeSpecRequest from app.utils import configure_logging +from dstack.plugins.builtin.models import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) configure_logging() logger = logging.getLogger(__name__) @@ -12,32 +21,36 @@ @app.post("/apply_policies/on_run_apply") -async def on_run_apply(request: RunSpecRequest): +async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse: logger.info( f"Received run spec request from user {request.user} and project {request.project}" ) - return request.spec + response = RunSpecResponse(spec=request.spec, error=None) + return response @app.post("/apply_policies/on_fleet_apply") -async def on_fleet_apply(request: FleetSpecRequest): +async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse: logger.info( f"Received fleet spec request from user {request.user} and project {request.project}" ) - return request.spec + response = FleetSpecResponse(request.spec, error=None) + return response @app.post("/apply_policies/on_volume_apply") -async def on_volume_apply(request: VolumeSpecRequest): +async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse: logger.info( f"Received volume spec request from user {request.user} and project {request.project}" ) - return request.spec + response = VolumeSpecResponse(request.spec, error=None) + return response @app.post("/apply_policies/on_gateway_apply") -async def on_gateway_apply(request: GatewaySpecRequest): +async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse: logger.info( f"Received gateway spec request from user {request.user} and project {request.project}" ) - return request.spec + response = GatewaySpecResponse(request.spec, error=None) + return response diff --git a/examples/plugins/example_plugin_server/app/models.py b/examples/plugins/example_plugin_server/app/models.py deleted file mode 100644 index d15dc290d5..0000000000 --- a/examples/plugins/example_plugin_server/app/models.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Generic, TypeVar - -from pydantic import BaseModel - -from dstack._internal.core.models.fleets import FleetSpec -from dstack._internal.core.models.gateways import GatewaySpec -from dstack._internal.core.models.runs import RunSpec -from dstack._internal.core.models.volumes import VolumeSpec - -SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) - - -class SpecRequest(BaseModel, Generic[SpecType]): - user: str - project: str - spec: SpecType - - -RunSpecRequest = SpecRequest[RunSpec] -FleetSpecRequest = SpecRequest[FleetSpec] -VolumeSpecRequest = SpecRequest[VolumeSpec] -GatewaySpecRequest = SpecRequest[GatewaySpec] diff --git a/src/dstack/_internal/server/services/plugins.py b/src/dstack/_internal/server/services/plugins.py index a9db27cfb1..99699ef731 100644 --- a/src/dstack/_internal/server/services/plugins.py +++ b/src/dstack/_internal/server/services/plugins.py @@ -96,7 +96,7 @@ def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec for policy in policies: try: spec = policy.on_apply(user=user, project=project, spec=spec) - except Exception as e: + except ValueError as e: msg = None if len(e.args) > 0: msg = e.args[0] diff --git a/src/dstack/plugins/builtin/__init__.py b/src/dstack/plugins/builtin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/dstack/plugins/builtin/models.py b/src/dstack/plugins/builtin/models.py new file mode 100644 index 0000000000..f809938658 --- /dev/null +++ b/src/dstack/plugins/builtin/models.py @@ -0,0 +1,38 @@ +from typing import Generic, TypeVar + +from pydantic import BaseModel + +from dstack._internal.core.models.fleets import FleetSpec +from dstack._internal.core.models.gateways import GatewaySpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.core.models.volumes import VolumeSpec + +SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) + + +class SpecApplyRequest(BaseModel, Generic[SpecType]): + user: str + project: str + spec: SpecType + + def dict(self, *args, **kwargs): + d = super().dict(*args, **kwargs) + d.pop("__orig_class__", None) + return d + + +RunSpecRequest = SpecApplyRequest[RunSpec] +FleetSpecRequest = SpecApplyRequest[FleetSpec] +VolumeSpecRequest = SpecApplyRequest[VolumeSpec] +GatewaySpecRequest = SpecApplyRequest[GatewaySpec] + + +class SpecApplyResponse(BaseModel, Generic[SpecType]): + spec: SpecType + error: str | None = None + + +RunSpecResponse = SpecApplyResponse[RunSpec] +FleetSpecResponse = SpecApplyResponse[FleetSpec] +VolumeSpecResponse = SpecApplyResponse[VolumeSpec] +GatewaySpecResponse = SpecApplyResponse[GatewaySpec] diff --git a/src/dstack/plugins/builtin/rest_plugin.py b/src/dstack/plugins/builtin/rest_plugin.py index d246d0e367..d746b0f96f 100644 --- a/src/dstack/plugins/builtin/rest_plugin.py +++ b/src/dstack/plugins/builtin/rest_plugin.py @@ -1,9 +1,8 @@ import json import os -from typing import Generic, TypeVar import requests -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.fleets import FleetSpec @@ -11,26 +10,24 @@ from dstack._internal.core.models.volumes import VolumeSpec from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger from dstack.plugins._models import ApplySpec +from dstack.plugins.builtin.models import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + SpecApplyRequest, + SpecApplyResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) logger = get_plugin_logger(__name__) PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" PLUGIN_REQUEST_TIMEOUT = 8 # in seconds -SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec) - - -class SpecRequest(BaseModel, Generic[SpecType]): - user: str - project: str - spec: SpecType - - -RunSpecRequest = SpecRequest[RunSpec] -FleetSpecRequest = SpecRequest[FleetSpec] -VolumeSpecRequest = SpecRequest[VolumeSpec] -GatewaySpecRequest = SpecRequest[GatewaySpec] - class CustomApplyPolicy(ApplyPolicy): def __init__(self): @@ -42,7 +39,12 @@ def __init__(self): ) raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set") - def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> ApplySpec: + def _check_request_rejected(self, response: SpecApplyResponse): + if response.error is not None: + logger.error(f"Plugin service rejected apply request: {response.error}") + raise ServerClientError(f"Apply request rejected: {response.error}") + + def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) -> ApplySpec: response = None try: response = requests.post( @@ -58,38 +60,58 @@ def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> Appl logger.error( f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e ) - raise e + raise ServerClientError( + f"Could not connect to plugin service at {self._plugin_service_uri}" + ) except requests.RequestException as e: logger.error("Request to the plugin service failed: %s", e) - if response: - logger.error(f"Error response from plugin service:\n{response.text}") - raise e - except ValidationError as e: - # Received 200 code but response body is invalid - logger.exception( - f"Plugin service returned invalid response:\n{response.text if response else None}" - ) - raise e + raise ServerClientError("Request to the plugin service failed") + + def _on_apply(self, request_cls, response_cls, endpoint, user, project, spec): + try: + spec_request = request_cls(user=user, project=project, spec=spec) + spec_json = self._call_plugin_service(spec_request, endpoint) + response = response_cls(**spec_json) + self._check_request_rejected(response) + return response.spec + except ValidationError: + logger.error(f"Plugin service returned invalid response:\n{spec_json}") + raise ServerClientError("Plugin service returned an invalid response") def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec: - spec_request = RunSpecRequest(user=user, project=project, spec=spec) - spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_run_apply") - return RunSpec(**spec_json) + return self._on_apply( + RunSpecRequest, RunSpecResponse, "/apply_policies/on_run_apply", user, project, spec + ) def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec: - spec_request = FleetSpecRequest(user=user, project=project, spec=spec) - spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_fleet_apply") - return FleetSpec(**spec_json) + return self._on_apply( + FleetSpecRequest, + FleetSpecResponse, + "/apply_policies/on_fleet_apply", + user, + project, + spec, + ) def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec: - spec_request = VolumeSpecRequest(user=user, project=project, spec=spec) - spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_volume_apply") - return VolumeSpec(**spec_json) + return self._on_apply( + VolumeSpecRequest, + VolumeSpecResponse, + "/apply_policies/on_volume_apply", + user, + project, + spec, + ) def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec: - spec_request = GatewaySpecRequest(user=user, project=project, spec=spec) - spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_gateway_apply") - return GatewaySpec(**spec_json) + return self._on_apply( + GatewaySpecRequest, + GatewaySpecResponse, + "/apply_policies/on_gateway_apply", + user, + project, + spec, + ) class RESTPlugin(Plugin): diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py index 979ccdc55d..5a99f19c00 100644 --- a/src/tests/plugins/test_rest_plugin.py +++ b/src/tests/plugins/test_rest_plugin.py @@ -10,7 +10,7 @@ from pydantic import parse_obj_as from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.errors import ServerError +from dstack._internal.core.errors import ServerClientError, ServerError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec @@ -106,23 +106,23 @@ async def test_on_run_apply_plugin_service_uri_not_set(self): @pytest.mark.parametrize( "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True ) - async def test_on_run_apply_plugin_service_returns_mutated_spec( + async def test_on_apply_plugin_service_returns_mutated_spec( self, test_db, user, project, spec ): policy = CustomApplyPolicy() mock_response = Mock() - spec_dict = spec.dict() + response_dict = {"spec": spec.dict(), "error": None} if isinstance(spec, (RunSpec, FleetSpec)): - spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"} + response_dict["spec"]["profile"]["tags"] = {"env": "test", "team": "qa"} else: - spec_dict["configuration_path"] = "/path/to/something" + response_dict["spec"]["configuration_path"] = "/path/to/something" - mock_response.text = json.dumps(spec_dict) + mock_response.text = json.dumps(response_dict) mock_response.raise_for_status = Mock() with mock.patch("requests.post", return_value=mock_response): result = policy.on_apply(user=user.name, project=project.name, spec=spec) - assert result == type(spec)(**spec_dict) + assert result == type(spec)(**response_dict["spec"]) @pytest.mark.asyncio @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) @@ -130,10 +130,10 @@ async def test_on_run_apply_plugin_service_returns_mutated_spec( @pytest.mark.parametrize( "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True ) - async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, spec): + async def test_on_apply_plugin_service_call_fails(self, test_db, user, project, spec): policy = CustomApplyPolicy() with mock.patch("requests.post", side_effect=requests.RequestException("fail")): - with pytest.raises(requests.RequestException): + with pytest.raises(ServerClientError): policy.on_apply(user=user.name, project=project.name, spec=spec) @pytest.mark.asyncio @@ -142,14 +142,12 @@ async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, proje @pytest.mark.parametrize( "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True ) - async def test_on_run_apply_plugin_service_connection_fails( - self, test_db, user, project, spec - ): + async def test_on_apply_plugin_service_connection_fails(self, test_db, user, project, spec): policy = CustomApplyPolicy() with mock.patch( "requests.post", side_effect=requests.ConnectionError("Failed to connect") ): - with pytest.raises(requests.ConnectionError): + with pytest.raises(ServerClientError): policy.on_apply(user=user.name, project=project.name, spec=spec) @pytest.mark.asyncio @@ -158,7 +156,7 @@ async def test_on_run_apply_plugin_service_connection_fails( @pytest.mark.parametrize( "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True ) - async def test_on_run_apply_plugin_service_returns_invalid_spec( + async def test_on_apply_plugin_service_returns_invalid_spec( self, test_db, user, project, spec ): policy = CustomApplyPolicy() From f9036f560987554fbbeeabcefc6898566f303be3 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Wed, 21 May 2025 11:48:20 -0400 Subject: [PATCH 6/9] Restructure rest_plugin modules --- .../plugins/example_plugin_server/app/main.py | 2 +- .../plugins/builtin/rest_plugin/__init__.py | 18 ++++++++++++++++++ .../{models.py => rest_plugin/_models.py} | 2 ++ .../{rest_plugin.py => rest_plugin/_plugin.py} | 5 ++--- src/tests/plugins/test_rest_plugin.py | 3 +-- 5 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 src/dstack/plugins/builtin/rest_plugin/__init__.py rename src/dstack/plugins/builtin/{models.py => rest_plugin/_models.py} (82%) rename src/dstack/plugins/builtin/{rest_plugin.py => rest_plugin/_plugin.py} (96%) diff --git a/examples/plugins/example_plugin_server/app/main.py b/examples/plugins/example_plugin_server/app/main.py index 17456aedd6..f20d6090cc 100644 --- a/examples/plugins/example_plugin_server/app/main.py +++ b/examples/plugins/example_plugin_server/app/main.py @@ -3,7 +3,7 @@ from fastapi import FastAPI from app.utils import configure_logging -from dstack.plugins.builtin.models import ( +from dstack.plugins.builtin.rest_plugin import ( FleetSpecRequest, FleetSpecResponse, GatewaySpecRequest, diff --git a/src/dstack/plugins/builtin/rest_plugin/__init__.py b/src/dstack/plugins/builtin/rest_plugin/__init__.py new file mode 100644 index 0000000000..4d5e0fe14a --- /dev/null +++ b/src/dstack/plugins/builtin/rest_plugin/__init__.py @@ -0,0 +1,18 @@ +# ruff: noqa: F401 +from dstack.plugins.builtin.rest_plugin._models import ( + FleetSpecRequest, + FleetSpecResponse, + GatewaySpecRequest, + GatewaySpecResponse, + RunSpecRequest, + RunSpecResponse, + SpecApplyRequest, + SpecApplyResponse, + VolumeSpecRequest, + VolumeSpecResponse, +) +from dstack.plugins.builtin.rest_plugin._plugin import ( + PLUGIN_SERVICE_URI_ENV_VAR_NAME, + CustomApplyPolicy, + RESTPlugin, +) diff --git a/src/dstack/plugins/builtin/models.py b/src/dstack/plugins/builtin/rest_plugin/_models.py similarity index 82% rename from src/dstack/plugins/builtin/models.py rename to src/dstack/plugins/builtin/rest_plugin/_models.py index f809938658..02847d2032 100644 --- a/src/dstack/plugins/builtin/models.py +++ b/src/dstack/plugins/builtin/rest_plugin/_models.py @@ -15,6 +15,8 @@ class SpecApplyRequest(BaseModel, Generic[SpecType]): project: str spec: SpecType + # Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable" + # error. This issue doesn't happen though when running the code in pytest, only when running the server. def dict(self, *args, **kwargs): d = super().dict(*args, **kwargs) d.pop("__orig_class__", None) diff --git a/src/dstack/plugins/builtin/rest_plugin.py b/src/dstack/plugins/builtin/rest_plugin/_plugin.py similarity index 96% rename from src/dstack/plugins/builtin/rest_plugin.py rename to src/dstack/plugins/builtin/rest_plugin/_plugin.py index d746b0f96f..6979b7e9f9 100644 --- a/src/dstack/plugins/builtin/rest_plugin.py +++ b/src/dstack/plugins/builtin/rest_plugin/_plugin.py @@ -8,9 +8,8 @@ from dstack._internal.core.models.fleets import FleetSpec from dstack._internal.core.models.gateways import GatewaySpec from dstack._internal.core.models.volumes import VolumeSpec -from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger -from dstack.plugins._models import ApplySpec -from dstack.plugins.builtin.models import ( +from dstack.plugins import ApplyPolicy, ApplySpec, Plugin, RunSpec, get_plugin_logger +from dstack.plugins.builtin.rest_plugin import ( FleetSpecRequest, FleetSpecResponse, GatewaySpecRequest, diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py index 5a99f19c00..7660d053de 100644 --- a/src/tests/plugins/test_rest_plugin.py +++ b/src/tests/plugins/test_rest_plugin.py @@ -3,7 +3,6 @@ from unittest import mock from unittest.mock import Mock -import pydantic import pytest import pytest_asyncio import requests @@ -164,5 +163,5 @@ async def test_on_apply_plugin_service_returns_invalid_spec( mock_response.text = json.dumps({"invalid-key": "abc"}) mock_response.raise_for_status = Mock() with mock.patch("requests.post", return_value=mock_response): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ServerClientError): policy.on_apply(user.name, project=project.name, spec=spec) From db046512c7b3d78af004d610f556e7799b83ecdf Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Wed, 21 May 2025 14:06:31 -0400 Subject: [PATCH 7/9] Doc updates --- docs/docs/guides/plugins.md | 9 +++++++-- examples/plugins/example_plugin_server/README.md | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/docs/guides/plugins.md b/docs/docs/guides/plugins.md index 2f662d553e..ac87aa10bc 100644 --- a/docs/docs/guides/plugins.md +++ b/docs/docs/guides/plugins.md @@ -116,6 +116,11 @@ class ExamplePolicy(ApplyPolicy): ## Built-in Plugins ### REST Plugin -If you'd like to apply custom policies within your organization, you can set up your own plugin API server and integrate it with `dstack` via the `rest-plugin`. To get started, check out the [plugijn server example](/examples/plugins/example_plugin_server/README.md). +`rest_plugin` is a builtin `dstack` plugin that allows writing your custom plugins as API servers, so you don't need to install plugins as Python packages. -For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin). +Plugins implemented as API servers have advantages over plugins implemented as Python packages in some cases: +* No dependency conflicts with `dstack`. +* You can use any programming language. +* If you run the `dstack` server via Docker, you don't need to extend the `dstack` server image with plugins or map them via volumes. + +To get started, check out the [plugin server example](/examples/plugins/example_plugin_server/README.md). diff --git a/examples/plugins/example_plugin_server/README.md b/examples/plugins/example_plugin_server/README.md index ce96dfd8f0..840184d593 100644 --- a/examples/plugins/example_plugin_server/README.md +++ b/examples/plugins/example_plugin_server/README.md @@ -1,6 +1,6 @@ ## Overview -If you wish to hook up your own plugin server through `dstack` builtin ` rest-plugin`, here's a basic example on how to do so. +If you wish to hook up your own plugin server through `dstack` builtin `rest_plugin`, here's a basic example on how to do so. ## Steps @@ -17,7 +17,7 @@ If you wish to hook up your own plugin server through `dstack` builtin ` rest-pl fastapi dev app/main.py ``` -1. Enable `rest-plugin` in `dstack` `server/config.yaml`: +1. Enable `rest_plugin` in `server/config.yaml`: ```yaml plugins: From 32533efd03fe194c8717f528be4d659c8eeabb44 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Thu, 22 May 2025 14:21:05 -0400 Subject: [PATCH 8/9] Additional type checks, type hints and field descriptions --- .../plugins/builtin/rest_plugin/_models.py | 22 ++++++---- .../plugins/builtin/rest_plugin/_plugin.py | 15 +++++-- src/tests/plugins/test_rest_plugin.py | 40 +++++++++++++++++++ 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/dstack/plugins/builtin/rest_plugin/_models.py b/src/dstack/plugins/builtin/rest_plugin/_models.py index 02847d2032..ee3a042464 100644 --- a/src/dstack/plugins/builtin/rest_plugin/_models.py +++ b/src/dstack/plugins/builtin/rest_plugin/_models.py @@ -1,6 +1,7 @@ -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated from dstack._internal.core.models.fleets import FleetSpec from dstack._internal.core.models.gateways import GatewaySpec @@ -11,9 +12,9 @@ class SpecApplyRequest(BaseModel, Generic[SpecType]): - user: str - project: str - spec: SpecType + user: Annotated[str, Field(description="The name of the user making the apply request")] + project: Annotated[str, Field(description="The name of the project the request is for")] + spec: Annotated[SpecType, Field(description="The spec to be applied")] # Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable" # error. This issue doesn't happen though when running the code in pytest, only when running the server. @@ -30,8 +31,15 @@ def dict(self, *args, **kwargs): class SpecApplyResponse(BaseModel, Generic[SpecType]): - spec: SpecType - error: str | None = None + spec: Annotated[ + SpecType, + Field( + description="The spec to apply, original spec if error otherwise original or mutated by plugin service if approved" + ), + ] + error: Annotated[ + Optional[str], Field(description="Error message if request is rejected", min_length=1) + ] = None RunSpecResponse = SpecApplyResponse[RunSpec] diff --git a/src/dstack/plugins/builtin/rest_plugin/_plugin.py b/src/dstack/plugins/builtin/rest_plugin/_plugin.py index 6979b7e9f9..48735d15ff 100644 --- a/src/dstack/plugins/builtin/rest_plugin/_plugin.py +++ b/src/dstack/plugins/builtin/rest_plugin/_plugin.py @@ -1,5 +1,6 @@ import json import os +from typing import Type import requests from pydantic import ValidationError @@ -25,7 +26,7 @@ logger = get_plugin_logger(__name__) PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI" -PLUGIN_REQUEST_TIMEOUT = 8 # in seconds +PLUGIN_REQUEST_TIMEOUT_SEC = 8 class CustomApplyPolicy(ApplyPolicy): @@ -50,7 +51,7 @@ def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) -> f"{self._plugin_service_uri}{endpoint}", json=spec_request.dict(), headers={"accept": "application/json", "Content-Type": "application/json"}, - timeout=PLUGIN_REQUEST_TIMEOUT, + timeout=PLUGIN_REQUEST_TIMEOUT_SEC, ) response.raise_for_status() spec_json = json.loads(response.text) @@ -66,7 +67,15 @@ def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) -> logger.error("Request to the plugin service failed: %s", e) raise ServerClientError("Request to the plugin service failed") - def _on_apply(self, request_cls, response_cls, endpoint, user, project, spec): + def _on_apply( + self, + request_cls: Type[SpecApplyRequest], + response_cls: Type[SpecApplyResponse], + endpoint: str, + user: str, + project: str, + spec: ApplySpec, + ) -> ApplySpec: try: spec_request = request_cls(user=user, project=project, spec=spec) spec_json = self._call_plugin_service(spec_request, endpoint) diff --git a/src/tests/plugins/test_rest_plugin.py b/src/tests/plugins/test_rest_plugin.py index 7660d053de..0b725ff95b 100644 --- a/src/tests/plugins/test_rest_plugin.py +++ b/src/tests/plugins/test_rest_plugin.py @@ -1,5 +1,6 @@ import json import os +from contextlib import nullcontext as does_not_raise from unittest import mock from unittest.mock import Mock @@ -165,3 +166,42 @@ async def test_on_apply_plugin_service_returns_invalid_spec( with mock.patch("requests.post", return_value=mock_response): with pytest.raises(ServerClientError): policy.on_apply(user.name, project=project.name, spec=spec) + + @pytest.mark.asyncio + @mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"}) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True + ) + @pytest.mark.parametrize( + ("error", "expectation"), + [ + pytest.param(None, does_not_raise(), id="error_none"), + pytest.param( + "", + pytest.raises( + ServerClientError, match="Plugin service returned an invalid response" + ), + id="error_empty_str", + ), + pytest.param( + "validation failed", + pytest.raises( + ServerClientError, match="Apply request rejected: validation failed" + ), + id="error_non_empty_str", + ), + ], + ) + async def test_on_apply_plugin_service_error_handling( + self, test_db, user, project, spec, error, expectation + ): + policy = CustomApplyPolicy() + mock_response = Mock() + response_dict = {"spec": spec.dict(), "error": error} + mock_response.text = json.dumps(response_dict) + mock_response.raise_for_status = Mock() + with mock.patch("requests.post", return_value=mock_response): + with expectation: + result = policy.on_apply(user=user.name, project=project.name, spec=spec) + assert result == type(spec)(**response_dict["spec"]) From 6f1f90ec00d12963ff524eb6a9772eb4c3a44108 Mon Sep 17 00:00:00 2001 From: Nadine Handal Date: Fri, 23 May 2025 10:14:18 -0400 Subject: [PATCH 9/9] Unskip postgres tests --- src/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 28d2e010a9..202d59058f 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -2,7 +2,7 @@ import pytest -from dstack._internal.server.testing.conf import session, test_db # noqa: F401 +from dstack._internal.server.testing.conf import postgres_container, session, test_db # noqa: F401 def pytest_configure(config):