88import logging
99from collections .abc import Iterable , Mapping , Sequence
1010from dataclasses import dataclass
11- from typing import TYPE_CHECKING , Any , Protocol
11+ from typing import TYPE_CHECKING , Any , Literal , Protocol
12+
13+ from pydantic import BaseModel , ConfigDict , Field , StrictInt , StrictStr , ValidationError , field_validator
1214
1315if TYPE_CHECKING :
1416 from datadog_checks .postgres import PostgreSql
1517
1618LOGGER = logging .getLogger (__name__ )
1719
1820_ALLOWED_QUERY = 'SELECT 1 AS value'
19- _REQUEST_FIELDS = frozenset ({'target' , 'query' , 'limits' })
20- _TARGET_FIELDS = frozenset ({'host' , 'port' , 'dbname' })
21- _LIMIT_FIELDS = frozenset ({'maxRows' , 'maxBytes' , 'timeoutMs' })
2221
2322
24- @dataclass (frozen = True )
25- class RemoteQueryTarget :
26- host : str
27- port : int
28- dbname : str
23+ class RemoteQueryTarget (BaseModel ):
24+ model_config = ConfigDict (extra = 'forbid' , frozen = True )
2925
26+ host : StrictStr = Field (min_length = 1 )
27+ port : StrictInt = Field (default = 5432 , ge = 1 , le = 65535 )
28+ dbname : StrictStr = Field (min_length = 1 )
3029
31- @dataclass (frozen = True )
32- class RemoteQueryLimits :
33- max_rows : int = 10
34- max_bytes : int = 1_048_576
35- timeout_ms : int = 5_000
30+ @field_validator ('host' )
31+ @classmethod
32+ def normalize_host (cls , value : str ) -> str :
33+ host = value .strip ().lower ()
34+ if host .endswith ('.' ):
35+ host = host [:- 1 ]
36+ if not host :
37+ raise ValueError ('host must be a non-empty string' )
38+ return host
3639
40+ @field_validator ('dbname' )
41+ @classmethod
42+ def validate_dbname (cls , value : str ) -> str :
43+ if not value :
44+ raise ValueError ('dbname must be a non-empty string' )
45+ if value != value .strip ():
46+ raise ValueError ('dbname must not contain surrounding whitespace' )
47+ return value
48+
49+
50+ class RemoteQueryLimits (BaseModel ):
51+ model_config = ConfigDict (extra = 'forbid' , frozen = True )
52+
53+ max_rows : StrictInt = Field (default = 10 , alias = 'maxRows' , ge = 1 )
54+ max_bytes : StrictInt = Field (default = 1_048_576 , alias = 'maxBytes' , ge = 1 )
55+ timeout_ms : StrictInt = Field (default = 5_000 , alias = 'timeoutMs' , ge = 1 )
56+
57+
58+ class RemoteQueryRequest (BaseModel ):
59+ model_config = ConfigDict (extra = 'forbid' , frozen = True )
3760
38- @dataclass (frozen = True )
39- class RemoteQueryRequest :
4061 target : RemoteQueryTarget
41- limits : RemoteQueryLimits
62+ query : Literal ['SELECT 1 AS value' ]
63+ limits : RemoteQueryLimits = Field (default_factory = RemoteQueryLimits )
4264
4365
4466@dataclass (frozen = True )
@@ -53,12 +75,12 @@ class PostgresCheckRegistry(Protocol):
5375 def iter_postgres_checks (self ) -> Iterable ['PostgreSql' ]: ...
5476
5577
56- def execute_remote_query (request : Mapping [str , Any ], registry : PostgresCheckRegistry ) -> dict [str , Any ]:
57- request_or_error = _parse_request (request )
58- if isinstance (request_or_error , dict ):
59- return request_or_error
78+ def execute_remote_query (request : Any , registry : PostgresCheckRegistry ) -> dict [str , Any ]:
79+ try :
80+ parsed_request = RemoteQueryRequest .model_validate (request )
81+ except ValidationError as e :
82+ return _validation_error (e )
6083
61- parsed_request = request_or_error
6284 target = parsed_request .target
6385 limits = parsed_request .limits
6486
@@ -73,75 +95,10 @@ def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegi
7395
7496
7597def normalize_target (target : Mapping [str , Any ]) -> RemoteQueryTarget :
76- host = target .get ('host' )
77- if not isinstance (host , str ) or not host .strip ():
78- raise ValueError ('host must be a non-empty string' )
79-
80- dbname = target .get ('dbname' )
81- if not isinstance (dbname , str ) or not dbname :
82- raise ValueError ('dbname must be a non-empty string' )
83- if dbname != dbname .strip ():
84- raise ValueError ('dbname must not contain surrounding whitespace' )
85-
86- return RemoteQueryTarget (
87- host = _normalize_host (host ), port = _int_in_range (target .get ('port' , 5432 ), 'port' , maximum = 65535 ), dbname = dbname
88- )
89-
90-
91- def _parse_request (value : Any ) -> RemoteQueryRequest | dict [str , Any ]:
92- if not isinstance (value , Mapping ):
93- return _error ('invalid_request' , 'Remote query request must be a mapping.' )
94-
95- unknown_fields_error = _unknown_fields_error (value , _REQUEST_FIELDS , 'request' )
96- if unknown_fields_error is not None :
97- return unknown_fields_error
98-
99- target_or_error = _parse_target (value .get ('target' ))
100- if isinstance (target_or_error , dict ):
101- return target_or_error
102-
103- if not _is_allowed_query (value .get ('query' )):
104- return _error ('query_rejected' , 'Only the canonical SELECT 1 proof query is allowed.' )
105-
106- limits_or_error = _parse_limits (value .get ('limits' , {}))
107- if isinstance (limits_or_error , dict ):
108- return limits_or_error
109-
110- return RemoteQueryRequest (target = target_or_error , limits = limits_or_error )
111-
112-
113- def _parse_target (value : Any ) -> RemoteQueryTarget | dict [str , Any ]:
114- if not isinstance (value , Mapping ):
115- return _error ('invalid_selector' , 'Target selector must be a mapping.' )
116-
117- unknown_fields_error = _unknown_fields_error (value , _TARGET_FIELDS , 'target' )
118- if unknown_fields_error is not None :
119- return unknown_fields_error
120-
12198 try :
122- return normalize_target (value )
123- except ValueError as e :
124- return _error ('invalid_selector' , str (e ))
125-
126-
127- def _parse_limits (value : Any ) -> RemoteQueryLimits | dict [str , Any ]:
128- if value is None :
129- value = {}
130- if not isinstance (value , Mapping ):
131- return _error ('invalid_request' , 'Limits must be a mapping.' )
132-
133- unknown_fields_error = _unknown_fields_error (value , _LIMIT_FIELDS , 'limits' )
134- if unknown_fields_error is not None :
135- return unknown_fields_error
136-
137- try :
138- return RemoteQueryLimits (
139- max_rows = _int_in_range (value .get ('maxRows' , 10 ), 'maxRows' ),
140- max_bytes = _int_in_range (value .get ('maxBytes' , 1_048_576 ), 'maxBytes' ),
141- timeout_ms = _int_in_range (value .get ('timeoutMs' , 5_000 ), 'timeoutMs' ),
142- )
143- except ValueError as e :
144- return _error ('invalid_request' , str (e ))
99+ return RemoteQueryTarget .model_validate (target )
100+ except ValidationError as e :
101+ raise ValueError (_validation_message (e )) from e
145102
146103
147104def _resolve_matches (target : RemoteQueryTarget , checks : Iterable ['PostgreSql' ]) -> list ['PostgreSql' ]:
@@ -197,28 +154,24 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re
197154 }
198155
199156
200- def _unknown_fields_error (
201- value : Mapping [str , Any ], allowed_fields : frozenset [str ], label : str
202- ) -> dict [str , Any ] | None :
203- unknown_fields = _unknown_field_names (value , allowed_fields )
204- if unknown_fields :
205- return _error ('invalid_request' , _unknown_fields_message (unknown_fields , label ))
206- return None
207-
208-
209- def _unknown_field_names (value : Mapping [str , Any ], allowed_fields : frozenset [str ]) -> list [str ]:
210- return sorted (str (field ) for field in value if field not in allowed_fields )
157+ def _validation_error (error : ValidationError ) -> dict [str , Any ]:
158+ return _error ('invalid_request' , _validation_message (error ))
211159
212160
213- def _unknown_fields_message (unknown_fields : list [str ], label : str ) -> str :
214- field_label = 'field' if len (unknown_fields ) == 1 else 'fields'
215- return f"{ label } contains unknown { field_label } : { ', ' .join (unknown_fields )} "
161+ def _validation_message (error : ValidationError ) -> str :
162+ details = []
163+ for item in error .errors (include_input = False ):
164+ location = _validation_location (item .get ('loc' , ()))
165+ message = item .get ('msg' , 'Invalid value' )
166+ if location :
167+ details .append (f'{ location } : { message } ' )
168+ else :
169+ details .append (message )
170+ return 'Invalid remote query request: {}' .format ('; ' .join (details ))
216171
217172
218- def _is_allowed_query (value : Any ) -> bool :
219- if not isinstance (value , str ):
220- return False
221- return value .strip ().rstrip (';' ).strip () == _ALLOWED_QUERY
173+ def _validation_location (location : tuple [Any , ...]) -> str :
174+ return '.' .join (str (part ) for part in location )
222175
223176
224177def _normalize_host (value : str ) -> str :
0 commit comments