Skip to content

Commit 30da8a7

Browse files
committed
Fix Postgres remote query changelog number
1 parent ef86938 commit 30da8a7

3 files changed

Lines changed: 66 additions & 111 deletions

File tree

postgres/datadog_checks/postgres/remote_query.py

Lines changed: 61 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,59 @@
88
import logging
99
from collections.abc import Iterable, Mapping, Sequence
1010
from dataclasses import dataclass
11-
from typing import TYPE_CHECKING, Any, Protocol
11+
from typing import TYPE_CHECKING, Any, Literal, Protocol
12+
13+
from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator
1214

1315
if TYPE_CHECKING:
1416
from datadog_checks.postgres import PostgreSql
1517

1618
LOGGER = logging.getLogger(__name__)
1719

1820
_ALLOWED_QUERY = 'SELECT 1 AS value'
19-
_REQUEST_FIELDS = frozenset({'target', 'query', 'limits'})
20-
_TARGET_FIELDS = frozenset({'host', 'port', 'dbname'})
21-
_LIMIT_FIELDS = frozenset({'maxRows', 'maxBytes', 'timeoutMs'})
2221

2322

24-
@dataclass(frozen=True)
25-
class RemoteQueryTarget:
26-
host: str
27-
port: int
28-
dbname: str
23+
class RemoteQueryTarget(BaseModel):
24+
model_config = ConfigDict(extra='forbid', frozen=True)
2925

26+
host: StrictStr = Field(min_length=1)
27+
port: StrictInt = Field(default=5432, ge=1, le=65535)
28+
dbname: StrictStr = Field(min_length=1)
3029

31-
@dataclass(frozen=True)
32-
class RemoteQueryLimits:
33-
max_rows: int = 10
34-
max_bytes: int = 1_048_576
35-
timeout_ms: int = 5_000
30+
@field_validator('host')
31+
@classmethod
32+
def normalize_host(cls, value: str) -> str:
33+
host = value.strip().lower()
34+
if host.endswith('.'):
35+
host = host[:-1]
36+
if not host:
37+
raise ValueError('host must be a non-empty string')
38+
return host
3639

40+
@field_validator('dbname')
41+
@classmethod
42+
def validate_dbname(cls, value: str) -> str:
43+
if not value:
44+
raise ValueError('dbname must be a non-empty string')
45+
if value != value.strip():
46+
raise ValueError('dbname must not contain surrounding whitespace')
47+
return value
48+
49+
50+
class RemoteQueryLimits(BaseModel):
51+
model_config = ConfigDict(extra='forbid', frozen=True)
52+
53+
max_rows: StrictInt = Field(default=10, alias='maxRows', ge=1)
54+
max_bytes: StrictInt = Field(default=1_048_576, alias='maxBytes', ge=1)
55+
timeout_ms: StrictInt = Field(default=5_000, alias='timeoutMs', ge=1)
56+
57+
58+
class RemoteQueryRequest(BaseModel):
59+
model_config = ConfigDict(extra='forbid', frozen=True)
3760

38-
@dataclass(frozen=True)
39-
class RemoteQueryRequest:
4061
target: RemoteQueryTarget
41-
limits: RemoteQueryLimits
62+
query: Literal['SELECT 1 AS value']
63+
limits: RemoteQueryLimits = Field(default_factory=RemoteQueryLimits)
4264

4365

4466
@dataclass(frozen=True)
@@ -53,12 +75,12 @@ class PostgresCheckRegistry(Protocol):
5375
def iter_postgres_checks(self) -> Iterable['PostgreSql']: ...
5476

5577

56-
def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegistry) -> dict[str, Any]:
57-
request_or_error = _parse_request(request)
58-
if isinstance(request_or_error, dict):
59-
return request_or_error
78+
def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[str, Any]:
79+
try:
80+
parsed_request = RemoteQueryRequest.model_validate(request)
81+
except ValidationError as e:
82+
return _validation_error(e)
6083

61-
parsed_request = request_or_error
6284
target = parsed_request.target
6385
limits = parsed_request.limits
6486

@@ -73,75 +95,10 @@ def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegi
7395

7496

7597
def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget:
76-
host = target.get('host')
77-
if not isinstance(host, str) or not host.strip():
78-
raise ValueError('host must be a non-empty string')
79-
80-
dbname = target.get('dbname')
81-
if not isinstance(dbname, str) or not dbname:
82-
raise ValueError('dbname must be a non-empty string')
83-
if dbname != dbname.strip():
84-
raise ValueError('dbname must not contain surrounding whitespace')
85-
86-
return RemoteQueryTarget(
87-
host=_normalize_host(host), port=_int_in_range(target.get('port', 5432), 'port', maximum=65535), dbname=dbname
88-
)
89-
90-
91-
def _parse_request(value: Any) -> RemoteQueryRequest | dict[str, Any]:
92-
if not isinstance(value, Mapping):
93-
return _error('invalid_request', 'Remote query request must be a mapping.')
94-
95-
unknown_fields_error = _unknown_fields_error(value, _REQUEST_FIELDS, 'request')
96-
if unknown_fields_error is not None:
97-
return unknown_fields_error
98-
99-
target_or_error = _parse_target(value.get('target'))
100-
if isinstance(target_or_error, dict):
101-
return target_or_error
102-
103-
if not _is_allowed_query(value.get('query')):
104-
return _error('query_rejected', 'Only the canonical SELECT 1 proof query is allowed.')
105-
106-
limits_or_error = _parse_limits(value.get('limits', {}))
107-
if isinstance(limits_or_error, dict):
108-
return limits_or_error
109-
110-
return RemoteQueryRequest(target=target_or_error, limits=limits_or_error)
111-
112-
113-
def _parse_target(value: Any) -> RemoteQueryTarget | dict[str, Any]:
114-
if not isinstance(value, Mapping):
115-
return _error('invalid_selector', 'Target selector must be a mapping.')
116-
117-
unknown_fields_error = _unknown_fields_error(value, _TARGET_FIELDS, 'target')
118-
if unknown_fields_error is not None:
119-
return unknown_fields_error
120-
12198
try:
122-
return normalize_target(value)
123-
except ValueError as e:
124-
return _error('invalid_selector', str(e))
125-
126-
127-
def _parse_limits(value: Any) -> RemoteQueryLimits | dict[str, Any]:
128-
if value is None:
129-
value = {}
130-
if not isinstance(value, Mapping):
131-
return _error('invalid_request', 'Limits must be a mapping.')
132-
133-
unknown_fields_error = _unknown_fields_error(value, _LIMIT_FIELDS, 'limits')
134-
if unknown_fields_error is not None:
135-
return unknown_fields_error
136-
137-
try:
138-
return RemoteQueryLimits(
139-
max_rows=_int_in_range(value.get('maxRows', 10), 'maxRows'),
140-
max_bytes=_int_in_range(value.get('maxBytes', 1_048_576), 'maxBytes'),
141-
timeout_ms=_int_in_range(value.get('timeoutMs', 5_000), 'timeoutMs'),
142-
)
143-
except ValueError as e:
144-
return _error('invalid_request', str(e))
99+
return RemoteQueryTarget.model_validate(target)
100+
except ValidationError as e:
101+
raise ValueError(_validation_message(e)) from e
145102

146103

147104
def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']:
@@ -197,28 +154,24 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re
197154
}
198155

199156

200-
def _unknown_fields_error(
201-
value: Mapping[str, Any], allowed_fields: frozenset[str], label: str
202-
) -> dict[str, Any] | None:
203-
unknown_fields = _unknown_field_names(value, allowed_fields)
204-
if unknown_fields:
205-
return _error('invalid_request', _unknown_fields_message(unknown_fields, label))
206-
return None
207-
208-
209-
def _unknown_field_names(value: Mapping[str, Any], allowed_fields: frozenset[str]) -> list[str]:
210-
return sorted(str(field) for field in value if field not in allowed_fields)
157+
def _validation_error(error: ValidationError) -> dict[str, Any]:
158+
return _error('invalid_request', _validation_message(error))
211159

212160

213-
def _unknown_fields_message(unknown_fields: list[str], label: str) -> str:
214-
field_label = 'field' if len(unknown_fields) == 1 else 'fields'
215-
return f"{label} contains unknown {field_label}: {', '.join(unknown_fields)}"
161+
def _validation_message(error: ValidationError) -> str:
162+
details = []
163+
for item in error.errors(include_input=False):
164+
location = _validation_location(item.get('loc', ()))
165+
message = item.get('msg', 'Invalid value')
166+
if location:
167+
details.append(f'{location}: {message}')
168+
else:
169+
details.append(message)
170+
return 'Invalid remote query request: {}'.format('; '.join(details))
216171

217172

218-
def _is_allowed_query(value: Any) -> bool:
219-
if not isinstance(value, str):
220-
return False
221-
return value.strip().rstrip(';').strip() == _ALLOWED_QUERY
173+
def _validation_location(location: tuple[Any, ...]) -> str:
174+
return '.'.join(str(part) for part in location)
222175

223176

224177
def _normalize_host(value: str) -> str:

postgres/tests/test_remote_query.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,14 @@ def test_execute_closed_pool_returns_target_unavailable_without_recreating_crede
244244
assert pool.requested_dbnames == []
245245

246246

247-
def test_execute_rejects_non_canonical_query_before_pool_access():
247+
@pytest.mark.parametrize('query', ['SELECT current_database()', 'SELECT 1 AS value;', ' SELECT 1 AS value'])
248+
def test_execute_rejects_non_canonical_query_before_pool_access(query):
248249
pool = FakePool()
249-
request = valid_request(query='SELECT current_database()')
250+
request = valid_request(query=query)
250251

251252
response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)]))
252253

253254
assert response['status'] == 'FAILED'
254-
assert response_code(response) == 'query_rejected'
255+
assert response_code(response) == 'invalid_request'
256+
assert 'query' in response['error']['message']
255257
assert pool.requested_dbnames == []

0 commit comments

Comments
 (0)