1313
1414from pydantic import BaseModel , ConfigDict , Field , StrictInt , StrictStr , ValidationError , field_validator
1515
16- RemoteQuerySql = Literal [
17- 'SELECT 1 AS value' ,
18- 'SELECT city, country FROM cities ORDER BY city' ,
19- "SELECT repeat('x', 1048576) AS payload" ,
20- "SELECT repeat('x', 2097152) AS payload" ,
21- "SELECT repeat('x', 4194304) AS payload" ,
22- "SELECT repeat('x', 8388608) AS payload" ,
23- "SELECT repeat('x', 16777216) AS payload" ,
24- "SELECT repeat('x', 33554432) AS payload" ,
25- ]
26-
2716RemoteQueryCopySql = Literal [
2817 'SELECT 1 AS value' ,
2918 'SELECT city, country FROM cities ORDER BY city' ,
@@ -73,16 +62,6 @@ def validate_dbname(cls, value: str) -> str:
7362 return value
7463
7564
76- class RemoteQueryLimits (BaseModel ):
77- """Validate the future-facing limits contract for the initial safe query slice."""
78-
79- model_config = ConfigDict (extra = 'forbid' , frozen = True )
80-
81- max_rows : StrictInt = Field (default = 10 , alias = 'maxRows' , ge = 1 )
82- max_bytes : StrictInt = Field (default = 1_048_576 , alias = 'maxBytes' , ge = 1 )
83- timeout_ms : StrictInt = Field (default = 5_000 , alias = 'timeoutMs' , ge = 1 )
84-
85-
8665class RemoteQueryCopyLimits (BaseModel ):
8766 """Validate byte-streaming limits for COPY export mode."""
8867
@@ -94,16 +73,6 @@ class RemoteQueryCopyLimits(BaseModel):
9473 timeout_ms : StrictInt = Field (default = 30_000 , alias = 'timeoutMs' , ge = 1 )
9574
9675
97- class RemoteQueryRequest (BaseModel ):
98- """Accept only exact proof queries until broader SQL execution is implemented."""
99-
100- model_config = ConfigDict (extra = 'forbid' , frozen = True )
101-
102- target : RemoteQueryTarget
103- query : RemoteQuerySql
104- limits : RemoteQueryLimits = Field (default_factory = RemoteQueryLimits )
105-
106-
10776class RemoteQueryCopyRequest (BaseModel ):
10877 """Accept only explicit COPY byte-stream export requests."""
10978
@@ -150,20 +119,6 @@ class PostgresCheckRegistry(Protocol):
150119 def iter_postgres_checks (self ) -> Iterable ['PostgreSql' ]: ...
151120
152121
153- def execute_agent_rpc_json (request_json : str | bytes | bytearray , check : 'PostgreSql' ) -> str :
154- try :
155- request = json .loads (request_json )
156- except (TypeError , ValueError ):
157- response = _error ('invalid_request' , 'Invalid remote query request: request_json must be a valid JSON object.' )
158- else :
159- if not isinstance (request , Mapping ):
160- response = _error ('invalid_request' , 'Invalid remote query request: request_json must be a JSON object.' )
161- else :
162- response = execute_remote_query (request , StaticPostgresCheckRegistry ([check ]))
163-
164- return json .dumps (response , default = str )
165-
166-
167122def execute_agent_rpc_stream_copy (
168123 request_json : str | bytes | bytearray , check : 'PostgreSql' , emit : CopyStreamEmit
169124) -> None :
@@ -227,26 +182,6 @@ def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegis
227182 yield from _iter_copy_stream_events (matches [0 ], parsed_request , started_at )
228183
229184
230- def execute_remote_query (request : Any , registry : PostgresCheckRegistry ) -> dict [str , Any ]:
231- try :
232- parsed_request = RemoteQueryRequest .model_validate (request )
233- except ValidationError as e :
234- return _validation_error (e )
235-
236- target = parsed_request .target
237- limits = parsed_request .limits
238- query = parsed_request .query
239-
240- matches = _resolve_matches (target , registry .iter_postgres_checks ())
241- LOGGER .debug ('Remote query target match count: %d' , len (matches ))
242- if not matches :
243- return _error ('target_not_found' , 'No loaded Postgres integration instance matched target selector.' )
244- if len (matches ) > 1 :
245- return _error ('target_ambiguous' , 'More than one loaded Postgres integration instance matched target selector.' )
246-
247- return _execute_safe_query (matches [0 ], target , query , limits )
248-
249-
250185def normalize_target (target : Mapping [str , Any ]) -> RemoteQueryTarget :
251186 try :
252187 return RemoteQueryTarget .model_validate (target )
@@ -269,44 +204,6 @@ def _target_from_check(check: 'PostgreSql') -> RemoteQueryTarget | None:
269204 return None
270205
271206
272- def _execute_safe_query (
273- check : 'PostgreSql' , target : RemoteQueryTarget , query : RemoteQuerySql , limits : RemoteQueryLimits
274- ) -> dict [str , Any ]:
275- db_pool = getattr (check , 'db_pool' , None )
276- if db_pool is None :
277- return _error ('credentials_unavailable' , 'Matched Postgres check does not expose a connection pool.' )
278- if getattr (db_pool , 'is_closed' , lambda : False )():
279- return _error ('target_unavailable' , 'Matched Postgres check connection pool is closed.' , retryable = False )
280-
281- try :
282- with db_pool .get_connection (target .dbname ) as conn :
283- with conn .cursor () as cursor :
284- cursor .execute (query )
285- description = cursor .description
286- if description is None :
287- return _error ('query_failed' , 'Query did not return a result set.' )
288- raw_rows = cursor .fetchmany (limits .max_rows + 1 )
289- except RuntimeError :
290- return _error ('target_unavailable' , 'Matched Postgres check connection pool is unavailable.' , retryable = False )
291- except Exception :
292- LOGGER .exception ('Remote query execution failed' )
293- return _error ('query_failed' , 'Remote query execution failed.' )
294-
295- # max_bytes and timeout_ms are validated for the API contract but enforced in a follow-up slice.
296- truncated = len (raw_rows ) > limits .max_rows
297- response_columns = _response_columns (description , raw_rows )
298- rows = [_response_row (response_columns , row ) for row in raw_rows [: limits .max_rows ]]
299- bytes_returned = len (json .dumps ({'columns' : response_columns , 'rows' : rows }, default = str ).encode ('utf-8' ))
300-
301- return {
302- 'status' : 'SUCCEEDED' ,
303- 'columns' : response_columns ,
304- 'rows' : rows ,
305- 'truncated' : truncated ,
306- 'stats' : {'rowCount' : len (rows ), 'bytesReturned' : bytes_returned },
307- }
308-
309-
310207def _iter_copy_stream_events (
311208 check : 'PostgreSql' , request : RemoteQueryCopyRequest , started_at : float
312209) -> Iterator [CopyStreamEvent ]:
@@ -512,46 +409,6 @@ def _emit_copy_event(emit: CopyStreamEmit, event: CopyStreamEvent) -> None:
512409 emit (event .event_type , json .dumps (event .metadata , default = str ), event .payload )
513410
514411
515- def _response_columns (description : Sequence [Any ], rows : Sequence [Sequence [Any ]]) -> list [dict [str , str ]]:
516- return [
517- {'name' : _column_name (column ), 'type' : _column_type (index , rows )} for index , column in enumerate (description )
518- ]
519-
520-
521- def _column_name (column : Any ) -> str :
522- name = getattr (column , 'name' , None )
523- if name is not None :
524- return str (name )
525- return str (column [0 ])
526-
527-
528- def _column_type (index : int , rows : Sequence [Sequence [Any ]]) -> str :
529- for row in rows :
530- if row [index ] is not None :
531- return _value_type (row [index ])
532- return 'unknown'
533-
534-
535- def _value_type (value : Any ) -> str :
536- if isinstance (value , bool ):
537- return 'boolean'
538- if isinstance (value , int ):
539- return 'integer'
540- if isinstance (value , float ):
541- return 'number'
542- if isinstance (value , str ):
543- return 'string'
544- return type (value ).__name__
545-
546-
547- def _response_row (columns : Sequence [Mapping [str , str ]], row : Sequence [Any ]) -> dict [str , Any ]:
548- return {column ['name' ]: row [index ] for index , column in enumerate (columns )}
549-
550-
551- def _validation_error (error : ValidationError ) -> dict [str , Any ]:
552- return _error ('invalid_request' , _validation_message (error ))
553-
554-
555412def _validation_message (error : ValidationError ) -> str :
556413 details = []
557414 for item in error .errors (include_input = False ):
@@ -566,7 +423,3 @@ def _validation_message(error: ValidationError) -> str:
566423
567424def _validation_location (location : tuple [Any , ...]) -> str :
568425 return '.' .join (str (part ) for part in location )
569-
570-
571- def _error (code : str , message : str , retryable : bool = False ) -> dict [str , Any ]:
572- return {'status' : 'FAILED' , 'error' : {'code' : code , 'message' : message , 'retryable' : retryable }}
0 commit comments