Skip to content

Commit 32533ef

Browse files
committed
Additional type checks, type hints and field descriptions
1 parent db04651 commit 32533ef

3 files changed

Lines changed: 67 additions & 10 deletions

File tree

src/dstack/plugins/builtin/rest_plugin/_models.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Generic, TypeVar
1+
from typing import Generic, Optional, TypeVar
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, Field
4+
from typing_extensions import Annotated
45

56
from dstack._internal.core.models.fleets import FleetSpec
67
from dstack._internal.core.models.gateways import GatewaySpec
@@ -11,9 +12,9 @@
1112

1213

1314
class SpecApplyRequest(BaseModel, Generic[SpecType]):
14-
user: str
15-
project: str
16-
spec: SpecType
15+
user: Annotated[str, Field(description="The name of the user making the apply request")]
16+
project: Annotated[str, Field(description="The name of the project the request is for")]
17+
spec: Annotated[SpecType, Field(description="The spec to be applied")]
1718

1819
# Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable"
1920
# error. This issue doesn't happen though when running the code in pytest, only when running the server.
@@ -30,8 +31,15 @@ def dict(self, *args, **kwargs):
3031

3132

3233
class SpecApplyResponse(BaseModel, Generic[SpecType]):
33-
spec: SpecType
34-
error: str | None = None
34+
spec: Annotated[
35+
SpecType,
36+
Field(
37+
description="The spec to apply, original spec if error otherwise original or mutated by plugin service if approved"
38+
),
39+
]
40+
error: Annotated[
41+
Optional[str], Field(description="Error message if request is rejected", min_length=1)
42+
] = None
3543

3644

3745
RunSpecResponse = SpecApplyResponse[RunSpec]

src/dstack/plugins/builtin/rest_plugin/_plugin.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
from typing import Type
34

45
import requests
56
from pydantic import ValidationError
@@ -25,7 +26,7 @@
2526
logger = get_plugin_logger(__name__)
2627

2728
PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
28-
PLUGIN_REQUEST_TIMEOUT = 8 # in seconds
29+
PLUGIN_REQUEST_TIMEOUT_SEC = 8
2930

3031

3132
class CustomApplyPolicy(ApplyPolicy):
@@ -50,7 +51,7 @@ def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) ->
5051
f"{self._plugin_service_uri}{endpoint}",
5152
json=spec_request.dict(),
5253
headers={"accept": "application/json", "Content-Type": "application/json"},
53-
timeout=PLUGIN_REQUEST_TIMEOUT,
54+
timeout=PLUGIN_REQUEST_TIMEOUT_SEC,
5455
)
5556
response.raise_for_status()
5657
spec_json = json.loads(response.text)
@@ -66,7 +67,15 @@ def _call_plugin_service(self, spec_request: SpecApplyRequest, endpoint: str) ->
6667
logger.error("Request to the plugin service failed: %s", e)
6768
raise ServerClientError("Request to the plugin service failed")
6869

69-
def _on_apply(self, request_cls, response_cls, endpoint, user, project, spec):
70+
def _on_apply(
71+
self,
72+
request_cls: Type[SpecApplyRequest],
73+
response_cls: Type[SpecApplyResponse],
74+
endpoint: str,
75+
user: str,
76+
project: str,
77+
spec: ApplySpec,
78+
) -> ApplySpec:
7079
try:
7180
spec_request = request_cls(user=user, project=project, spec=spec)
7281
spec_json = self._call_plugin_service(spec_request, endpoint)

src/tests/plugins/test_rest_plugin.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
from contextlib import nullcontext as does_not_raise
34
from unittest import mock
45
from unittest.mock import Mock
56

@@ -165,3 +166,42 @@ async def test_on_apply_plugin_service_returns_invalid_spec(
165166
with mock.patch("requests.post", return_value=mock_response):
166167
with pytest.raises(ServerClientError):
167168
policy.on_apply(user.name, project=project.name, spec=spec)
169+
170+
@pytest.mark.asyncio
171+
@mock.patch.dict(os.environ, {PLUGIN_SERVICE_URI_ENV_VAR_NAME: "http://mock"})
172+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
173+
@pytest.mark.parametrize(
174+
"spec", ["run_spec", "fleet_spec", "volume_spec", "gateway_spec"], indirect=True
175+
)
176+
@pytest.mark.parametrize(
177+
("error", "expectation"),
178+
[
179+
pytest.param(None, does_not_raise(), id="error_none"),
180+
pytest.param(
181+
"",
182+
pytest.raises(
183+
ServerClientError, match="Plugin service returned an invalid response"
184+
),
185+
id="error_empty_str",
186+
),
187+
pytest.param(
188+
"validation failed",
189+
pytest.raises(
190+
ServerClientError, match="Apply request rejected: validation failed"
191+
),
192+
id="error_non_empty_str",
193+
),
194+
],
195+
)
196+
async def test_on_apply_plugin_service_error_handling(
197+
self, test_db, user, project, spec, error, expectation
198+
):
199+
policy = CustomApplyPolicy()
200+
mock_response = Mock()
201+
response_dict = {"spec": spec.dict(), "error": error}
202+
mock_response.text = json.dumps(response_dict)
203+
mock_response.raise_for_status = Mock()
204+
with mock.patch("requests.post", return_value=mock_response):
205+
with expectation:
206+
result = policy.on_apply(user=user.name, project=project.name, spec=spec)
207+
assert result == type(spec)(**response_dict["spec"])

0 commit comments

Comments
 (0)