Skip to content

Commit e757a77

Browse files
committed
Refactor rest-plugin + polish models
1 parent a4bacac commit e757a77

7 files changed

Lines changed: 134 additions & 85 deletions

File tree

examples/plugins/example_plugin_server/app/main.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22

33
from fastapi import FastAPI
44

5-
from app.models import FleetSpecRequest, GatewaySpecRequest, RunSpecRequest, VolumeSpecRequest
65
from app.utils import configure_logging
6+
from dstack.plugins.builtin.models import (
7+
FleetSpecRequest,
8+
FleetSpecResponse,
9+
GatewaySpecRequest,
10+
GatewaySpecResponse,
11+
RunSpecRequest,
12+
RunSpecResponse,
13+
VolumeSpecRequest,
14+
VolumeSpecResponse,
15+
)
716

817
configure_logging()
918
logger = logging.getLogger(__name__)
@@ -12,32 +21,36 @@
1221

1322

1423
@app.post("/apply_policies/on_run_apply")
15-
async def on_run_apply(request: RunSpecRequest):
24+
async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse:
1625
logger.info(
1726
f"Received run spec request from user {request.user} and project {request.project}"
1827
)
19-
return request.spec
28+
response = RunSpecResponse(spec=request.spec, error=None)
29+
return response
2030

2131

2232
@app.post("/apply_policies/on_fleet_apply")
23-
async def on_fleet_apply(request: FleetSpecRequest):
33+
async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse:
2434
logger.info(
2535
f"Received fleet spec request from user {request.user} and project {request.project}"
2636
)
27-
return request.spec
37+
response = FleetSpecResponse(request.spec, error=None)
38+
return response
2839

2940

3041
@app.post("/apply_policies/on_volume_apply")
31-
async def on_volume_apply(request: VolumeSpecRequest):
42+
async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse:
3243
logger.info(
3344
f"Received volume spec request from user {request.user} and project {request.project}"
3445
)
35-
return request.spec
46+
response = VolumeSpecResponse(request.spec, error=None)
47+
return response
3648

3749

3850
@app.post("/apply_policies/on_gateway_apply")
39-
async def on_gateway_apply(request: GatewaySpecRequest):
51+
async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse:
4052
logger.info(
4153
f"Received gateway spec request from user {request.user} and project {request.project}"
4254
)
43-
return request.spec
55+
response = GatewaySpecResponse(request.spec, error=None)
56+
return response

examples/plugins/example_plugin_server/app/models.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/dstack/_internal/server/services/plugins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def apply_plugin_policies(user: str, project: str, spec: ApplySpec) -> ApplySpec
9696
for policy in policies:
9797
try:
9898
spec = policy.on_apply(user=user, project=project, spec=spec)
99-
except Exception as e:
99+
except ValueError as e:
100100
msg = None
101101
if len(e.args) > 0:
102102
msg = e.args[0]

src/dstack/plugins/builtin/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Generic, TypeVar
2+
3+
from pydantic import BaseModel
4+
5+
from dstack._internal.core.models.fleets import FleetSpec
6+
from dstack._internal.core.models.gateways import GatewaySpec
7+
from dstack._internal.core.models.runs import RunSpec
8+
from dstack._internal.core.models.volumes import VolumeSpec
9+
10+
SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)
11+
12+
13+
class SpecApplyRequest(BaseModel, Generic[SpecType]):
14+
user: str
15+
project: str
16+
spec: SpecType
17+
18+
def dict(self, *args, **kwargs):
19+
d = super().dict(*args, **kwargs)
20+
d.pop("__orig_class__", None)
21+
return d
22+
23+
24+
RunSpecRequest = SpecApplyRequest[RunSpec]
25+
FleetSpecRequest = SpecApplyRequest[FleetSpec]
26+
VolumeSpecRequest = SpecApplyRequest[VolumeSpec]
27+
GatewaySpecRequest = SpecApplyRequest[GatewaySpec]
28+
29+
30+
class SpecApplyResponse(BaseModel, Generic[SpecType]):
31+
spec: SpecType
32+
error: str | None = None
33+
34+
35+
RunSpecResponse = SpecApplyResponse[RunSpec]
36+
FleetSpecResponse = SpecApplyResponse[FleetSpec]
37+
VolumeSpecResponse = SpecApplyResponse[VolumeSpec]
38+
GatewaySpecResponse = SpecApplyResponse[GatewaySpec]

src/dstack/plugins/builtin/rest_plugin.py

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,33 @@
11
import json
22
import os
3-
from typing import Generic, TypeVar
43

54
import requests
6-
from pydantic import BaseModel, ValidationError
5+
from pydantic import ValidationError
76

87
from dstack._internal.core.errors import ServerClientError
98
from dstack._internal.core.models.fleets import FleetSpec
109
from dstack._internal.core.models.gateways import GatewaySpec
1110
from dstack._internal.core.models.volumes import VolumeSpec
1211
from dstack.plugins import ApplyPolicy, Plugin, RunSpec, get_plugin_logger
1312
from dstack.plugins._models import ApplySpec
13+
from dstack.plugins.builtin.models import (
14+
FleetSpecRequest,
15+
FleetSpecResponse,
16+
GatewaySpecRequest,
17+
GatewaySpecResponse,
18+
RunSpecRequest,
19+
RunSpecResponse,
20+
SpecApplyRequest,
21+
SpecApplyResponse,
22+
VolumeSpecRequest,
23+
VolumeSpecResponse,
24+
)
1425

1526
logger = get_plugin_logger(__name__)
1627

1728
PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
1829
PLUGIN_REQUEST_TIMEOUT = 8 # in seconds
1930

20-
SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)
21-
22-
23-
class SpecRequest(BaseModel, Generic[SpecType]):
24-
user: str
25-
project: str
26-
spec: SpecType
27-
28-
29-
RunSpecRequest = SpecRequest[RunSpec]
30-
FleetSpecRequest = SpecRequest[FleetSpec]
31-
VolumeSpecRequest = SpecRequest[VolumeSpec]
32-
GatewaySpecRequest = SpecRequest[GatewaySpec]
33-
3431

3532
class CustomApplyPolicy(ApplyPolicy):
3633
def __init__(self):
@@ -42,7 +39,12 @@ def __init__(self):
4239
)
4340
raise ServerClientError(f"{PLUGIN_SERVICE_URI_ENV_VAR_NAME} is not set")
4441

45-
def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> ApplySpec:
42+
def _check_request_rejected(self, response: SpecApplyResponse):
43+
if response.error is not None:
44+
logger.error(f"Plugin service rejected apply request: {response.error}")
45+
raise ServerClientError(f"Apply request rejected: {response.error}")
46+
47+
def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) -> ApplySpec:
4648
response = None
4749
try:
4850
response = requests.post(
@@ -58,38 +60,58 @@ def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> Appl
5860
logger.error(
5961
f"Could not connect to plugin service at {self._plugin_service_uri}: %s", e
6062
)
61-
raise e
63+
raise ServerClientError(
64+
f"Could not connect to plugin service at {self._plugin_service_uri}"
65+
)
6266
except requests.RequestException as e:
6367
logger.error("Request to the plugin service failed: %s", e)
64-
if response:
65-
logger.error(f"Error response from plugin service:\n{response.text}")
66-
raise e
67-
except ValidationError as e:
68-
# Received 200 code but response body is invalid
69-
logger.exception(
70-
f"Plugin service returned invalid response:\n{response.text if response else None}"
71-
)
72-
raise e
68+
raise ServerClientError("Request to the plugin service failed")
69+
70+
def _on_apply(self, request_cls, response_cls, endpoint, user, project, spec):
71+
try:
72+
spec_request = request_cls(user=user, project=project, spec=spec)
73+
spec_json = self._call_plugin_service(spec_request, endpoint)
74+
response = response_cls(**spec_json)
75+
self._check_request_rejected(response)
76+
return response.spec
77+
except ValidationError:
78+
logger.error(f"Plugin service returned invalid response:\n{spec_json}")
79+
raise ServerClientError("Plugin service returned an invalid response")
7380

7481
def on_run_apply(self, user: str, project: str, spec: RunSpec) -> RunSpec:
75-
spec_request = RunSpecRequest(user=user, project=project, spec=spec)
76-
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_run_apply")
77-
return RunSpec(**spec_json)
82+
return self._on_apply(
83+
RunSpecRequest, RunSpecResponse, "/apply_policies/on_run_apply", user, project, spec
84+
)
7885

7986
def on_fleet_apply(self, user: str, project: str, spec: FleetSpec) -> FleetSpec:
80-
spec_request = FleetSpecRequest(user=user, project=project, spec=spec)
81-
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_fleet_apply")
82-
return FleetSpec(**spec_json)
87+
return self._on_apply(
88+
FleetSpecRequest,
89+
FleetSpecResponse,
90+
"/apply_policies/on_fleet_apply",
91+
user,
92+
project,
93+
spec,
94+
)
8395

8496
def on_volume_apply(self, user: str, project: str, spec: VolumeSpec) -> VolumeSpec:
85-
spec_request = VolumeSpecRequest(user=user, project=project, spec=spec)
86-
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_volume_apply")
87-
return VolumeSpec(**spec_json)
97+
return self._on_apply(
98+
VolumeSpecRequest,
99+
VolumeSpecResponse,
100+
"/apply_policies/on_volume_apply",
101+
user,
102+
project,
103+
spec,
104+
)
88105

89106
def on_gateway_apply(self, user: str, project: str, spec: GatewaySpec) -> GatewaySpec:
90-
spec_request = GatewaySpecRequest(user=user, project=project, spec=spec)
91-
spec_json = self._call_plugin_service(spec_request, "/apply_policies/on_gateway_apply")
92-
return GatewaySpec(**spec_json)
107+
return self._on_apply(
108+
GatewaySpecRequest,
109+
GatewaySpecResponse,
110+
"/apply_policies/on_gateway_apply",
111+
user,
112+
project,
113+
spec,
114+
)
93115

94116

95117
class RESTPlugin(Plugin):

src/tests/plugins/test_rest_plugin.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic import parse_obj_as
1111
from sqlalchemy.ext.asyncio import AsyncSession
1212

13-
from dstack._internal.core.errors import ServerError
13+
from dstack._internal.core.errors import ServerClientError, ServerError
1414
from dstack._internal.core.models.backends.base import BackendType
1515
from dstack._internal.core.models.configurations import ServiceConfiguration
1616
from dstack._internal.core.models.fleets import FleetConfiguration, FleetSpec
@@ -106,34 +106,34 @@ async def test_on_run_apply_plugin_service_uri_not_set(self):
106106
@pytest.mark.parametrize(
107107
"spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True
108108
)
109-
async def test_on_run_apply_plugin_service_returns_mutated_spec(
109+
async def test_on_apply_plugin_service_returns_mutated_spec(
110110
self, test_db, user, project, spec
111111
):
112112
policy = CustomApplyPolicy()
113113
mock_response = Mock()
114-
spec_dict = spec.dict()
114+
response_dict = {"spec": spec.dict(), "error": None}
115115

116116
if isinstance(spec, (RunSpec, FleetSpec)):
117-
spec_dict["profile"]["tags"] = {"env": "test", "team": "qa"}
117+
response_dict["spec"]["profile"]["tags"] = {"env": "test", "team": "qa"}
118118
else:
119-
spec_dict["configuration_path"] = "/path/to/something"
119+
response_dict["spec"]["configuration_path"] = "/path/to/something"
120120

121-
mock_response.text = json.dumps(spec_dict)
121+
mock_response.text = json.dumps(response_dict)
122122
mock_response.raise_for_status = Mock()
123123
with mock.patch("requests.post", return_value=mock_response):
124124
result = policy.on_apply(user=user.name, project=project.name, spec=spec)
125-
assert result == type(spec)(**spec_dict)
125+
assert result == type(spec)(**response_dict["spec"])
126126

127127
@pytest.mark.asyncio
128128
@mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"})
129129
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
130130
@pytest.mark.parametrize(
131131
"spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True
132132
)
133-
async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, project, spec):
133+
async def test_on_apply_plugin_service_call_fails(self, test_db, user, project, spec):
134134
policy = CustomApplyPolicy()
135135
with mock.patch("requests.post", side_effect=requests.RequestException("fail")):
136-
with pytest.raises(requests.RequestException):
136+
with pytest.raises(ServerClientError):
137137
policy.on_apply(user=user.name, project=project.name, spec=spec)
138138

139139
@pytest.mark.asyncio
@@ -142,14 +142,12 @@ async def test_on_run_apply_plugin_service_call_fails(self, test_db, user, proje
142142
@pytest.mark.parametrize(
143143
"spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True
144144
)
145-
async def test_on_run_apply_plugin_service_connection_fails(
146-
self, test_db, user, project, spec
147-
):
145+
async def test_on_apply_plugin_service_connection_fails(self, test_db, user, project, spec):
148146
policy = CustomApplyPolicy()
149147
with mock.patch(
150148
"requests.post", side_effect=requests.ConnectionError("Failed to connect")
151149
):
152-
with pytest.raises(requests.ConnectionError):
150+
with pytest.raises(ServerClientError):
153151
policy.on_apply(user=user.name, project=project.name, spec=spec)
154152

155153
@pytest.mark.asyncio
@@ -158,7 +156,7 @@ async def test_on_run_apply_plugin_service_connection_fails(
158156
@pytest.mark.parametrize(
159157
"spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True
160158
)
161-
async def test_on_run_apply_plugin_service_returns_invalid_spec(
159+
async def test_on_apply_plugin_service_returns_invalid_spec(
162160
self, test_db, user, project, spec
163161
):
164162
policy = CustomApplyPolicy()

0 commit comments

Comments
 (0)