Skip to content

Commit ce610ea

Browse files
committed
Add Postgres remote query proof executor
1 parent 48a7ac4 commit ce610ea

4 files changed

Lines changed: 534 additions & 0 deletions

File tree

postgres/changelog.d/23476.added

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a remote query POC executor for Postgres.
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# (C) Datadog, Inc. 2026-present
2+
# All rights reserved
3+
# Licensed under Simplified BSD License (see LICENSE)
4+
5+
from __future__ import annotations
6+
7+
import json
8+
import logging
9+
from collections.abc import Iterable, Mapping, Sequence
10+
from dataclasses import dataclass
11+
from typing import TYPE_CHECKING, Any, Protocol
12+
13+
if TYPE_CHECKING:
14+
from datadog_checks.postgres import PostgreSql
15+
16+
# Module-level logger used for rejected requests and execution failures.
LOGGER = logging.getLogger(__name__)

# The only statement this proof-of-concept executor accepts and runs.
_ALLOWED_QUERY = 'SELECT 1 AS value'

# Lower-case request keys that are treated as datastore credential material.
# Any request containing one of these keys (at any nesting level) is rejected.
# Kept as a frozenset so the deny-list cannot be mutated at runtime.
_CREDENTIAL_FIELD_NAMES = frozenset(
    {
        'access_token',
        'connection_string',
        'connectionstring',
        'dsn',
        'passwd',
        'password',
        'pwd',
        'refresh_token',
        'ssl_cert',
        'ssl_key',
        'ssl_password',
        'ssl_root_cert',
        'sslcert',
        'sslkey',
        'sslpassword',
        'sslrootcert',
        'token',
        'url',
        'user',
        'username',
    }
)
41+
42+
43+
@dataclass(frozen=True)
class RemoteQueryTarget:
    """Immutable selector identifying one Postgres endpoint and database.

    Instances are compared field-by-field, which is how a requested target
    is matched against the targets derived from loaded checks.
    """

    # Host name of the Postgres server.
    host: str
    # TCP port of the Postgres server.
    port: int
    # Name of the database on that server.
    dbname: str
48+
49+
50+
@dataclass(frozen=True)
class RemoteQueryLimits:
    """Immutable bundle of the per-request result and time limits."""

    # Maximum number of rows to return.
    max_rows: int = 10
    # Maximum payload size in bytes (default is 1 MiB).
    max_bytes: int = 1_048_576
    # Timeout in milliseconds (default is 5 seconds).
    timeout_ms: int = 5_000
55+
56+
57+
@dataclass(frozen=True)
class StaticPostgresCheckRegistry:
    """Registry backed by a fixed, caller-supplied sequence of Postgres checks."""

    # The checks this registry exposes, in the order they were supplied.
    checks: Sequence['PostgreSql']

    def iter_postgres_checks(self) -> Iterable['PostgreSql']:
        """Return a fresh iterator over the stored checks."""
        registered = self.checks
        return iter(registered)
63+
64+
65+
class PostgresCheckRegistry(Protocol):
    """Structural interface for anything that can enumerate Postgres checks."""

    def iter_postgres_checks(self) -> Iterable['PostgreSql']:
        """Return an iterable of the Postgres check instances to search."""
        ...
67+
68+
69+
def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegistry) -> dict[str, Any]:
    """Validate a remote query request and run it against the single matching check.

    Validation happens in a fixed order: request shape, credential-shaped
    keys, target selector, query text, limits, and finally target resolution.
    The first failure short-circuits with a ``FAILED`` response.

    Args:
        request: Request mapping; the ``target``, ``query`` and ``limits`` keys are read.
        registry: Source of the loaded Postgres checks to match the target against.

    Returns:
        The response dict produced by ``_execute_select_1`` on success, or a
        ``FAILED`` response dict describing the first validation error.
    """
    if not isinstance(request, Mapping):
        return _error('invalid_request', 'Remote query request must be a mapping.')

    # Credentials must come from the matched check, never from the caller.
    if _contains_credential_field(request):
        LOGGER.warning('Rejected remote query request containing credential-shaped fields')
        return _error('request_contains_credentials', 'Request must not contain datastore credential material.')

    # The parsers return either the parsed value or a ready-made error dict.
    parsed_target = _parse_target(request.get('target'))
    if isinstance(parsed_target, dict):
        return parsed_target

    if not _is_allowed_query(request.get('query')):
        return _error('query_rejected', 'Only the canonical SELECT 1 proof query is allowed.')

    parsed_limits = _parse_limits(request.get('limits', {}))
    if isinstance(parsed_limits, dict):
        return parsed_limits

    candidates = _resolve_matches(parsed_target, registry.iter_postgres_checks())
    match_count = len(candidates)
    LOGGER.debug('Remote query target match count: %d', match_count)
    if match_count == 0:
        return _error('target_not_found', 'No loaded Postgres integration instance matched target selector.')
    if match_count > 1:
        return _error('target_ambiguous', 'More than one loaded Postgres integration instance matched target selector.')

    return _execute_select_1(candidates[0], parsed_target, parsed_limits)
98+
99+
100+
def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget:
    """Build a normalized ``RemoteQueryTarget`` from a raw selector mapping.

    Args:
        target: Mapping with ``host`` and ``dbname`` keys and an optional
            ``port`` key (defaults to 5432 when absent).

    Returns:
        The target with host and port normalized; ``dbname`` is kept verbatim.

    Raises:
        ValueError: If ``host`` or ``dbname`` is missing, not a string, or
            empty, if ``dbname`` has surrounding whitespace, or if the port
            is not a valid port number.
    """
    raw_host = target.get('host')
    if not (isinstance(raw_host, str) and raw_host.strip()):
        raise ValueError('host must be a non-empty string')

    database = target.get('dbname')
    if not (isinstance(database, str) and database):
        raise ValueError('dbname must be a non-empty string')
    # Database names are matched exactly, so padded names are rejected rather than trimmed.
    if database.strip() != database:
        raise ValueError('dbname must not contain surrounding whitespace')

    normalized_host = _normalize_host(raw_host)
    normalized_port = _normalize_port(target.get('port', 5432))
    return RemoteQueryTarget(host=normalized_host, port=normalized_port, dbname=database)
112+
113+
114+
def _parse_target(value: Any) -> RemoteQueryTarget | dict[str, Any]:
    """Parse the request's target selector without raising.

    Returns:
        The normalized ``RemoteQueryTarget``, or an ``invalid_selector`` error
        response dict when the value is not a mapping or fails validation.
    """
    if isinstance(value, Mapping):
        try:
            return normalize_target(value)
        except ValueError as exc:
            # The validation message is forwarded to the caller as-is.
            return _error('invalid_selector', str(exc))

    return _error('invalid_selector', 'Target selector must be a mapping.')
122+
123+
124+
def _parse_limits(value: Any) -> RemoteQueryLimits | dict[str, Any]:
    """Parse the request's limits without raising.

    ``None`` is treated as an empty mapping. The ``maxRows``, ``maxBytes``
    and ``timeoutMs`` keys are read, defaulting to 10, 1_048_576 and 5_000.

    Returns:
        The parsed ``RemoteQueryLimits``, or an ``invalid_request`` error
        response dict when the value is not a mapping or a limit is invalid.
    """
    raw_limits = {} if value is None else value
    if not isinstance(raw_limits, Mapping):
        return _error('invalid_request', 'Limits must be a mapping.')

    try:
        row_cap = _positive_int(raw_limits.get('maxRows', 10), 'maxRows')
        byte_cap = _positive_int(raw_limits.get('maxBytes', 1_048_576), 'maxBytes')
        timeout = _positive_int(raw_limits.get('timeoutMs', 5_000), 'timeoutMs')
    except ValueError as exc:
        return _error('invalid_request', str(exc))

    return RemoteQueryLimits(max_rows=row_cap, max_bytes=byte_cap, timeout_ms=timeout)
138+
139+
140+
def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']:
    """Return every check whose configured host, port and dbname equal ``target``.

    Checks without a ``_config`` attribute, or whose configuration cannot be
    normalized, are skipped rather than treated as errors.
    """
    selected = []
    for check in checks:
        cfg = getattr(check, '_config', None)
        if cfg is None:
            continue
        try:
            configured = RemoteQueryTarget(
                host=_normalize_host(cfg.host),
                port=_normalize_port(cfg.port),
                dbname=cfg.dbname,
            )
        except (AttributeError, ValueError):
            # Missing or malformed connection settings: this check cannot match.
            continue
        else:
            if configured == target:
                selected.append(check)
    return selected
157+
158+
159+
def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: RemoteQueryLimits) -> dict[str, Any]:
    """Run the canonical proof query through the matched check's pool and build the response.

    Args:
        check: Matched Postgres check; its ``db_pool`` attribute supplies the connection.
        target: Normalized selector; ``target.dbname`` is passed to ``db_pool.get_connection``.
        limits: ``max_rows`` caps the number of rows returned and ``max_bytes``
            caps the JSON-encoded size of the returned columns and rows.

    Returns:
        A ``SUCCEEDED`` response with ``columns``, ``rows``, ``truncated`` and
        ``stats``, or a ``FAILED`` response when the pool is missing, closed or
        unavailable, or when the query fails. ``truncated`` is true when rows
        were dropped to honor either ``max_rows`` or ``max_bytes``.
    """
    db_pool = getattr(check, 'db_pool', None)
    if db_pool is None:
        return _error('credentials_unavailable', 'Matched Postgres check does not expose a connection pool.')
    # A pool that has no ``is_closed`` method is treated as open.
    if getattr(db_pool, 'is_closed', lambda: False)():
        return _error('target_unavailable', 'Matched Postgres check connection pool is closed.', retryable=False)

    # NOTE(review): limits.timeout_ms is parsed and validated but not applied to
    # the statement here - confirm how the timeout should be enforced on a
    # connection shared with the check before wiring it in.
    try:
        with db_pool.get_connection(target.dbname) as conn:
            with conn.cursor() as cursor:
                cursor.execute(_ALLOWED_QUERY)
                if cursor.description is None:
                    return _error('query_failed', 'Query did not return a result set.')
                columns = [_column_name(column) for column in cursor.description]
                # Fetch one extra row so truncation can be detected.
                raw_rows = cursor.fetchmany(limits.max_rows + 1)
    except RuntimeError:
        return _error('target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False)
    except Exception:
        LOGGER.exception('Remote query execution failed')
        return _error('query_failed', 'Remote query execution failed.')

    truncated = len(raw_rows) > limits.max_rows
    rows = [_row_to_dict(columns, row) for row in raw_rows[: limits.max_rows]]
    response_columns = [{'name': name, 'type': _infer_type(rows, name)} for name in columns]

    # Enforce max_bytes: keep rows in order until the next one would push the
    # JSON payload over budget. The running size mirrors json.dumps() output:
    # each row after the first costs its own encoding plus a 2-byte ", "
    # separator. The column envelope is always returned, even if it alone
    # exceeds the budget.
    running_size = _json_size({'columns': response_columns, 'rows': []})
    kept_rows: list[dict[str, Any]] = []
    for row in rows:
        row_cost = _json_size(row) + (2 if kept_rows else 0)
        if running_size + row_cost > limits.max_bytes:
            truncated = True
            break
        kept_rows.append(row)
        running_size += row_cost

    bytes_returned = _json_size({'columns': response_columns, 'rows': kept_rows})

    return {
        'status': 'SUCCEEDED',
        'columns': response_columns,
        'rows': kept_rows,
        'truncated': truncated,
        'stats': {'rowCount': len(kept_rows), 'bytesReturned': bytes_returned},
    }


def _json_size(payload: Any) -> int:
    """Return the UTF-8 byte length of ``payload`` encoded as JSON (non-JSON values via ``str``)."""
    return len(json.dumps(payload, default=str).encode('utf-8'))
192+
193+
194+
def _contains_credential_field(value: Any) -> bool:
    """Report whether any mapping key nested in ``value`` looks like a credential field.

    Mappings, lists and tuples are searched at every nesting level. Keys are
    compared case-insensitively and with all non-alphanumeric characters
    removed, so ``accessToken``, ``access-token`` and ``ACCESS_TOKEN`` all
    match the ``access_token`` entry of ``_CREDENTIAL_FIELD_NAMES``.

    Args:
        value: Arbitrary request data; non-container values never match.

    Returns:
        True if at least one key matches a credential field name, else False.
    """

    def fold(name: str) -> str:
        # Canonical comparison form: lower-case letters and digits only.
        return ''.join(ch for ch in name.lower() if ch.isalnum())

    denied = {fold(name) for name in _CREDENTIAL_FIELD_NAMES}

    # Walk iteratively so deeply nested untrusted input cannot exhaust the
    # interpreter's recursion limit. Visited containers are kept alive in
    # ``visited`` so their ids stay unique and cyclic structures terminate.
    pending = [value]
    visited: dict[int, Any] = {}
    while pending:
        current = pending.pop()
        is_mapping = isinstance(current, Mapping)
        if not is_mapping and not isinstance(current, (list, tuple)):
            continue
        if id(current) in visited:
            continue
        visited[id(current)] = current

        if is_mapping:
            for key, nested_value in current.items():
                if fold(str(key)) in denied:
                    return True
                pending.append(nested_value)
        else:
            pending.extend(current)
    return False
204+
205+
206+
def _is_allowed_query(value: Any) -> bool:
    """Return True only when ``value`` is the canonical proof query.

    Surrounding whitespace and trailing semicolons are ignored; the remaining
    text must equal ``_ALLOWED_QUERY`` exactly (case-sensitive).
    """
    if isinstance(value, str):
        without_terminator = value.strip().rstrip(';')
        return without_terminator.strip() == _ALLOWED_QUERY
    return False
210+
211+
212+
def _normalize_host(value: str) -> str:
    """Return ``value`` trimmed, lower-cased and without one trailing dot.

    Raises:
        ValueError: If nothing is left after normalization.
    """
    # removesuffix drops at most one trailing '.', e.g. the root label of a fully qualified name.
    normalized = value.strip().lower().removesuffix('.')
    if normalized:
        return normalized
    raise ValueError('host must be a non-empty string')
219+
220+
221+
def _normalize_port(value: Any) -> int:
    """Coerce ``value`` to a TCP port number.

    Args:
        value: An ``int`` or a string of ASCII digits.

    Returns:
        The port as an ``int`` in the range 1-65535.

    Raises:
        ValueError: With ``'port must be an integer'`` when the value is a
            bool, is neither int nor str, or is not a plain ASCII digit
            string; with ``'port must be between 1 and 65535'`` when the
            number is out of range.
    """
    # bool is a subclass of int; reject it so True/False never pass as ports.
    if isinstance(value, bool):
        raise ValueError('port must be an integer')
    if isinstance(value, int):
        port = value
    elif isinstance(value, str):
        # str.isdigit() alone also accepts characters such as superscript
        # digits that int() rejects, so require ASCII digits.
        if not (value.isascii() and value.isdigit()):
            raise ValueError('port must be an integer')
        try:
            port = int(value)
        except ValueError:
            # Raised e.g. for digit strings above the interpreter's conversion
            # limit. Keep the message uniform instead of leaking int()'s text.
            raise ValueError('port must be an integer') from None
    else:
        raise ValueError('port must be an integer')

    if not 1 <= port <= 65535:
        raise ValueError('port must be between 1 and 65535')
    return port
236+
237+
238+
def _positive_int(value: Any, field: str) -> int:
    """Coerce ``value`` to a strictly positive integer.

    Args:
        value: An ``int`` or a string of ASCII digits.
        field: Field name used in the error message.

    Returns:
        The value as an ``int`` greater than zero.

    Raises:
        ValueError: With ``'<field> must be a positive integer'`` for bools,
            non-int/non-str values, non-digit strings, and numbers <= 0.
    """
    message = f'{field} must be a positive integer'
    # bool is a subclass of int; reject it so True never counts as 1.
    if isinstance(value, bool):
        raise ValueError(message)
    if isinstance(value, int):
        number = value
    elif isinstance(value, str) and value.isascii() and value.isdigit():
        # ASCII-only: str.isdigit() alone accepts characters such as
        # superscript digits that int() rejects with a different message.
        try:
            number = int(value)
        except ValueError:
            # e.g. digit strings above the interpreter's conversion limit.
            raise ValueError(message) from None
    else:
        raise ValueError(message)
    if number <= 0:
        raise ValueError(message)
    return number
250+
251+
252+
def _column_name(column: Any) -> str:
    """Return the name of a cursor description entry as a string.

    Uses the entry's ``name`` attribute when it is present and not None;
    otherwise falls back to the entry's first element.
    """
    attr_name = getattr(column, 'name', None)
    return str(column[0]) if attr_name is None else str(attr_name)
257+
258+
259+
def _row_to_dict(columns: list[str], row: Any) -> dict[str, Any]:
    """Convert one fetched row into a ``{column name: value}`` dict.

    Mapping rows are read by column name. Any other row is treated as a
    positional sequence and paired with ``columns`` in order, stopping at the
    shorter of the two.
    """
    if not isinstance(row, Mapping):
        return dict(zip(columns, row))

    by_name = {}
    for name in columns:
        by_name[name] = row[name]
    return by_name
263+
264+
265+
def _infer_type(rows: list[dict[str, Any]], column: str) -> str:
    """Infer a column's response type from its first non-None value.

    Returns:
        ``'boolean'``, ``'integer'``, ``'number'`` or ``'string'`` according to
        the first non-None value found, or ``'unknown'`` when every row holds
        None (or lacks the column).
    """
    for row in rows:
        sample = row.get(column)
        if sample is None:
            continue
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(sample, bool):
            return 'boolean'
        if isinstance(sample, int):
            return 'integer'
        if isinstance(sample, float):
            return 'number'
        return 'string'
    return 'unknown'
277+
278+
279+
def _error(code: str, message: str, retryable: bool = False) -> dict[str, Any]:
    """Build a ``FAILED`` response dict carrying the given error details."""
    details = {'code': code, 'message': message, 'retryable': retryable}
    return {'status': 'FAILED', 'error': details}

0 commit comments

Comments
 (0)