11import json
22import os
3- from typing import Generic , TypeVar
43
54import requests
6- from pydantic import BaseModel , ValidationError
5+ from pydantic import ValidationError
76
87from dstack ._internal .core .errors import ServerClientError
98from dstack ._internal .core .models .fleets import FleetSpec
109from dstack ._internal .core .models .gateways import GatewaySpec
1110from dstack ._internal .core .models .volumes import VolumeSpec
1211from dstack .plugins import ApplyPolicy , Plugin , RunSpec , get_plugin_logger
1312from dstack .plugins ._models import ApplySpec
13+ from dstack .plugins .builtin .models import (
14+ FleetSpecRequest ,
15+ FleetSpecResponse ,
16+ GatewaySpecRequest ,
17+ GatewaySpecResponse ,
18+ RunSpecRequest ,
19+ RunSpecResponse ,
20+ SpecApplyRequest ,
21+ SpecApplyResponse ,
22+ VolumeSpecRequest ,
23+ VolumeSpecResponse ,
24+ )
1425
1526logger = get_plugin_logger (__name__ )
1627
1728PLUGIN_SERVICE_URI_ENV_VAR_NAME = "DSTACK_PLUGIN_SERVICE_URI"
1829PLUGIN_REQUEST_TIMEOUT = 8 # in seconds
1930
20- SpecType = TypeVar ("SpecType" , RunSpec , FleetSpec , VolumeSpec , GatewaySpec )
21-
22-
23- class SpecRequest (BaseModel , Generic [SpecType ]):
24- user : str
25- project : str
26- spec : SpecType
27-
28-
29- RunSpecRequest = SpecRequest [RunSpec ]
30- FleetSpecRequest = SpecRequest [FleetSpec ]
31- VolumeSpecRequest = SpecRequest [VolumeSpec ]
32- GatewaySpecRequest = SpecRequest [GatewaySpec ]
33-
3431
3532class CustomApplyPolicy (ApplyPolicy ):
3633 def __init__ (self ):
@@ -42,7 +39,12 @@ def __init__(self):
4239 )
4340 raise ServerClientError (f"{ PLUGIN_SERVICE_URI_ENV_VAR_NAME } is not set" )
4441
45- def _call_plugin_service (self , spec_request : SpecRequest , endpoint : str ) -> ApplySpec :
42+ def _check_request_rejected (self , response : SpecApplyResponse ):
43+ if response .error is not None :
44+ logger .error (f"Plugin service rejected apply request: { response .error } " )
45+ raise ServerClientError (f"Apply request rejected: { response .error } " )
46+
47+ def _call_plugin_service (self , spec_request : SpecApplyRequest , endpoint : str ) -> ApplySpec :
4648 response = None
4749 try :
4850 response = requests .post (
@@ -58,38 +60,58 @@ def _call_plugin_service(self, spec_request: SpecRequest, endpoint: str) -> Appl
5860 logger .error (
5961 f"Could not connect to plugin service at { self ._plugin_service_uri } : %s" , e
6062 )
61- raise e
63+ raise ServerClientError (
64+ f"Could not connect to plugin service at { self ._plugin_service_uri } "
65+ )
6266 except requests .RequestException as e :
6367 logger .error ("Request to the plugin service failed: %s" , e )
64- if response :
65- logger .error (f"Error response from plugin service:\n { response .text } " )
66- raise e
67- except ValidationError as e :
68- # Received 200 code but response body is invalid
69- logger .exception (
70- f"Plugin service returned invalid response:\n { response .text if response else None } "
71- )
72- raise e
68+ raise ServerClientError ("Request to the plugin service failed" )
69+
70+ def _on_apply (self , request_cls , response_cls , endpoint , user , project , spec ):
71+ try :
72+ spec_request = request_cls (user = user , project = project , spec = spec )
73+ spec_json = self ._call_plugin_service (spec_request , endpoint )
74+ response = response_cls (** spec_json )
75+ self ._check_request_rejected (response )
76+ return response .spec
77+ except ValidationError :
78+ logger .error (f"Plugin service returned invalid response:\n { spec_json } " )
79+ raise ServerClientError ("Plugin service returned an invalid response" )
7380
7481 def on_run_apply (self , user : str , project : str , spec : RunSpec ) -> RunSpec :
75- spec_request = RunSpecRequest ( user = user , project = project , spec = spec )
76- spec_json = self . _call_plugin_service ( spec_request , "/apply_policies/on_run_apply" )
77- return RunSpec ( ** spec_json )
82+ return self . _on_apply (
83+ RunSpecRequest , RunSpecResponse , "/apply_policies/on_run_apply" , user , project , spec
84+ )
7885
7986 def on_fleet_apply (self , user : str , project : str , spec : FleetSpec ) -> FleetSpec :
80- spec_request = FleetSpecRequest (user = user , project = project , spec = spec )
81- spec_json = self ._call_plugin_service (spec_request , "/apply_policies/on_fleet_apply" )
82- return FleetSpec (** spec_json )
87+ return self ._on_apply (
88+ FleetSpecRequest ,
89+ FleetSpecResponse ,
90+ "/apply_policies/on_fleet_apply" ,
91+ user ,
92+ project ,
93+ spec ,
94+ )
8395
8496 def on_volume_apply (self , user : str , project : str , spec : VolumeSpec ) -> VolumeSpec :
85- spec_request = VolumeSpecRequest (user = user , project = project , spec = spec )
86- spec_json = self ._call_plugin_service (spec_request , "/apply_policies/on_volume_apply" )
87- return VolumeSpec (** spec_json )
97+ return self ._on_apply (
98+ VolumeSpecRequest ,
99+ VolumeSpecResponse ,
100+ "/apply_policies/on_volume_apply" ,
101+ user ,
102+ project ,
103+ spec ,
104+ )
88105
89106 def on_gateway_apply (self , user : str , project : str , spec : GatewaySpec ) -> GatewaySpec :
90- spec_request = GatewaySpecRequest (user = user , project = project , spec = spec )
91- spec_json = self ._call_plugin_service (spec_request , "/apply_policies/on_gateway_apply" )
92- return GatewaySpec (** spec_json )
107+ return self ._on_apply (
108+ GatewaySpecRequest ,
109+ GatewaySpecResponse ,
110+ "/apply_policies/on_gateway_apply" ,
111+ user ,
112+ project ,
113+ spec ,
114+ )
93115
94116
95117class RESTPlugin (Plugin ):
0 commit comments