1+ from dstack ._internal .core .errors import ServerError
2+ from dstack ._internal .server .models import ProjectModel , UserModel
3+ from plugins .rest_plugin .src .rest_plugin import PreApplyPolicy , PLUGIN_SERVICE_URI_ENV_VAR_NAME
4+ import pytest
5+ from sqlalchemy .ext .asyncio import AsyncSession
6+ from pydantic import parse_obj_as
7+ import os
8+ import json
9+ import requests
10+ from unittest .mock import Mock
11+
12+ from dstack ._internal .core .models .runs import RunSpec
13+ from dstack ._internal .core .models .configurations import ServiceConfiguration
14+ from dstack ._internal .core .models .profiles import Profile
15+ from dstack ._internal .core .models .resources import Range
16+ from dstack ._internal .server .testing .common import (
17+ create_project ,
18+ create_user ,
19+ create_repo ,
20+ get_run_spec ,
21+ )
22+ from dstack ._internal .server .testing .conf import session , test_db # noqa: F401
23+ from dstack ._internal .server .services import encryption as encryption # import for side-effect
24+ import pytest_asyncio
25+ from unittest import mock
26+
27+
28+ async def create_run_spec (
29+ session : AsyncSession ,
30+ project : ProjectModel ,
31+ replicas : str = 1 ,
32+ ) -> RunSpec :
33+ repo = await create_repo (session = session , project_id = project .id )
34+ run_name = "test-run"
35+ profile = Profile (name = "test-profile" )
36+ spec = get_run_spec (
37+ repo_id = repo .name ,
38+ run_name = run_name ,
39+ profile = profile ,
40+ configuration = ServiceConfiguration (
41+ commands = ["echo hello" ],
42+ port = 8000 ,
43+ replicas = parse_obj_as (Range [int ], replicas )
44+ ),
45+ )
46+ return spec
47+
48+ @pytest_asyncio .fixture
49+ async def project (session ):
50+ return await create_project (session = session )
51+
52+ @pytest_asyncio .fixture
53+ async def user (session ):
54+ return await create_user (session = session )
55+
56+ @pytest_asyncio .fixture
57+ async def run_spec (session , project ):
58+ return await create_run_spec (session = session , project = project )
59+
60+
61+ class TestRESTPlugin :
62+ @pytest .mark .asyncio
63+ async def test_on_run_apply_plugin_service_uri_not_set (self ):
64+ with pytest .raises (ServerError ):
65+ policy = PreApplyPolicy ()
66+
67+ @pytest .mark .asyncio
68+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
69+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
70+ async def test_on_run_apply_plugin_service_returns_mutated_spec (self , test_db , user , project , run_spec ):
71+ policy = PreApplyPolicy ()
72+ mock_response = Mock ()
73+ run_spec_dict = run_spec .dict ()
74+ run_spec_dict ["profile" ]["tags" ] = {"env" : "test" , "team" : "qa" }
75+ mock_response .text = json .dumps (run_spec_dict )
76+ mock_response .raise_for_status = Mock ()
77+ with mock .patch ("requests.post" , return_value = mock_response ):
78+ result = policy .on_apply (user = user .name , project = project .name , spec = run_spec )
79+ assert result == RunSpec (** run_spec_dict )
80+
81+ @pytest .mark .asyncio
82+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
83+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
84+ async def test_on_run_apply_plugin_service_call_fails (self , test_db , user , project , run_spec ):
85+ policy = PreApplyPolicy ()
86+ with mock .patch ("requests.post" , side_effect = requests .RequestException ("fail" )):
87+ result = policy .on_apply (user = user .name , project = project .name , spec = run_spec )
88+ assert result == run_spec
89+
90+ @pytest .mark .asyncio
91+ @mock .patch .dict (os .environ , {PLUGIN_SERVICE_URI_ENV_VAR_NAME : "http://mock" })
92+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
93+ async def test_on_run_apply_plugin_service_returns_invalid_spec (self , test_db , user , project , run_spec ):
94+ policy = PreApplyPolicy ()
95+ mock_response = Mock ()
96+ mock_response .text = json .dumps ({"invalid-key" : "abc" })
97+ mock_response .raise_for_status = Mock ()
98+ with mock .patch ("requests.post" , return_value = mock_response ):
99+ result = policy .on_apply (user .name , project = project .name , spec = run_spec )
100+ # return original run spec
101+ assert result == run_spec
102+
0 commit comments