From 998be180066277b93b83eebf4ac8dea8f6f91baf Mon Sep 17 00:00:00 2001 From: nubtron Date: Tue, 28 Apr 2026 15:25:33 +0000 Subject: [PATCH 01/11] Add Postgres remote query proof executor --- postgres/changelog.d/23476.added | 1 + .../datadog_checks/postgres/remote_query.py | 302 ++++++++++++++++++ postgres/tests/test_remote_query.py | 243 ++++++++++++++ .../tests/test_remote_query_integration.py | 33 ++ 4 files changed, 579 insertions(+) create mode 100644 postgres/changelog.d/23476.added create mode 100644 postgres/datadog_checks/postgres/remote_query.py create mode 100644 postgres/tests/test_remote_query.py create mode 100644 postgres/tests/test_remote_query_integration.py diff --git a/postgres/changelog.d/23476.added b/postgres/changelog.d/23476.added new file mode 100644 index 0000000000000..2c901bd4b2e33 --- /dev/null +++ b/postgres/changelog.d/23476.added @@ -0,0 +1 @@ +Add a remote query POC executor for Postgres. \ No newline at end of file diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py new file mode 100644 index 0000000000000..506713ba843be --- /dev/null +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -0,0 +1,302 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under Simplified BSD License (see LICENSE) + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from datadog_checks.postgres import PostgreSql + +LOGGER = logging.getLogger(__name__) + +_ALLOWED_QUERY = 'SELECT 1 AS value' +_REQUEST_FIELDS = frozenset({'target', 'query', 'limits'}) +_TARGET_FIELDS = frozenset({'host', 'port', 'dbname'}) +_LIMIT_FIELDS = frozenset({'maxRows', 'maxBytes', 'timeoutMs'}) + + +class UnknownFieldsError(ValueError): + pass + + +@dataclass(frozen=True) +class RemoteQueryTarget: + host: str + port: int + dbname: str + + +@dataclass(frozen=True) +class RemoteQueryLimits: + max_rows: int = 10 + max_bytes: int = 1_048_576 + timeout_ms: int = 5_000 + + +@dataclass(frozen=True) +class RemoteQueryRequest: + target: RemoteQueryTarget + query: str + limits: RemoteQueryLimits + + +@dataclass(frozen=True) +class StaticPostgresCheckRegistry: + checks: Sequence['PostgreSql'] + + def iter_postgres_checks(self) -> Iterable['PostgreSql']: + return iter(self.checks) + + +class PostgresCheckRegistry(Protocol): + def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... + + +def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegistry) -> dict[str, Any]: + request_or_error = _parse_request(request) + if isinstance(request_or_error, dict): + return request_or_error + + parsed_request = request_or_error + target = parsed_request.target + limits = parsed_request.limits + + matches = _resolve_matches(target, registry.iter_postgres_checks()) + LOGGER.debug('Remote query target match count: %d', len(matches)) + if not matches: + return _error('target_not_found', 'No loaded Postgres integration instance matched target selector.') + if len(matches) > 1: + return _error('target_ambiguous', 'More than one loaded Postgres integration instance matched target selector.') + + return _execute_select_1(matches[0], target, limits) + + +def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: + _reject_unknown_fields(target, _TARGET_FIELDS, 'target') + + host = target.get('host') + if not isinstance(host, str) or not host.strip(): + raise ValueError('host must be a non-empty string') + + dbname = target.get('dbname') + if not isinstance(dbname, str) or not dbname: + raise ValueError('dbname must be a non-empty string') + if dbname != dbname.strip(): + raise ValueError('dbname must not contain surrounding whitespace') + + return RemoteQueryTarget(host=_normalize_host(host), port=_normalize_port(target.get('port', 5432)), dbname=dbname) + + +def _parse_request(value: Any) -> RemoteQueryRequest | dict[str, Any]: + if not isinstance(value, Mapping): + return _error('invalid_request', 'Remote query request must be a mapping.') + + unknown_fields_error = _unknown_fields_error(value, _REQUEST_FIELDS, 'request') + if unknown_fields_error is not None: + return unknown_fields_error + + target_or_error = _parse_target(value.get('target')) + if isinstance(target_or_error, dict): + return target_or_error + + if not _is_allowed_query(value.get('query')): + return _error('query_rejected', 'Only the canonical SELECT 1 proof query is allowed.') + + limits_or_error = _parse_limits(value.get('limits', {})) + if isinstance(limits_or_error, dict): + return limits_or_error + + return RemoteQueryRequest(target=target_or_error, query=_ALLOWED_QUERY, limits=limits_or_error) + + +def _parse_target(value: Any) -> RemoteQueryTarget | dict[str, Any]: + if not isinstance(value, Mapping): + return _error('invalid_selector', 'Target selector must be a mapping.') + + try: + return normalize_target(value) + except UnknownFieldsError as e: + return _error('invalid_request', str(e)) + except ValueError as e: + return _error('invalid_selector', str(e)) + + +def _parse_limits(value: Any) -> RemoteQueryLimits | dict[str, Any]: + if value is None: + value = {} + if not isinstance(value, Mapping): + return _error('invalid_request', 'Limits must be a mapping.') + + unknown_fields_error = _unknown_fields_error(value, _LIMIT_FIELDS, 'limits') + if unknown_fields_error is not None: + return unknown_fields_error + + try: + return RemoteQueryLimits( + max_rows=_positive_int(value.get('maxRows', 10), 'maxRows'), + max_bytes=_positive_int(value.get('maxBytes', 1_048_576), 'maxBytes'), + timeout_ms=_positive_int(value.get('timeoutMs', 5_000), 'timeoutMs'), + ) + except ValueError as e: + return _error('invalid_request', str(e)) + + +def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']: + matches = [] + for check in checks: + config = getattr(check, '_config', None) + if config is None: + continue + try: + candidate = RemoteQueryTarget( + host=_normalize_host(config.host), + port=_normalize_port(config.port), + dbname=config.dbname, + ) + except (AttributeError, ValueError): + continue + if candidate == target: + matches.append(check) + return matches + + +def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: RemoteQueryLimits) -> dict[str, Any]: + db_pool = getattr(check, 'db_pool', None) + if db_pool is None: + return _error('credentials_unavailable', 'Matched Postgres check does not expose a connection pool.') + if getattr(db_pool, 'is_closed', lambda: False)(): + return _error('target_unavailable', 'Matched Postgres check connection pool is closed.', retryable=False) + + try: + with db_pool.get_connection(target.dbname) as conn: + with conn.cursor() as cursor: + cursor.execute(_ALLOWED_QUERY) + if cursor.description is None: + return _error('query_failed', 'Query did not return a result set.') + columns = [_column_name(column) for column in cursor.description] + raw_rows = cursor.fetchmany(limits.max_rows + 1) + except RuntimeError: + return _error('target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False) + except Exception: + LOGGER.exception('Remote query execution failed') + return _error('query_failed', 'Remote query execution failed.') + + truncated = len(raw_rows) > limits.max_rows + rows = [_row_to_dict(columns, row) for row in raw_rows[: limits.max_rows]] + response_columns = [{'name': name, 'type': _infer_type(rows, name)} for name in columns] + bytes_returned = len(json.dumps({'columns': response_columns, 'rows': rows}, default=str).encode('utf-8')) + + return { + 'status': 'SUCCEEDED', + 'columns': response_columns, + 'rows': rows, + 'truncated': truncated, + 'stats': {'rowCount': len(rows), 'bytesReturned': bytes_returned}, + } + + +def _reject_unknown_fields(value: Mapping[str, Any], allowed_fields: frozenset[str], label: str) -> None: + unknown_fields = _unknown_field_names(value, allowed_fields) + if unknown_fields: + raise UnknownFieldsError(_unknown_fields_message(unknown_fields, label)) + + +def _unknown_fields_error( + value: Mapping[str, Any], allowed_fields: frozenset[str], label: str +) -> dict[str, Any] | None: + unknown_fields = _unknown_field_names(value, allowed_fields) + if unknown_fields: + return _error('invalid_request', _unknown_fields_message(unknown_fields, label)) + return None + + +def _unknown_field_names(value: Mapping[str, Any], allowed_fields: frozenset[str]) -> list[str]: + return sorted(str(field) for field in value if field not in allowed_fields) + + +def _unknown_fields_message(unknown_fields: list[str], label: str) -> str: + field_label = 'field' if len(unknown_fields) == 1 else 'fields' + return f"{label} contains unknown {field_label}: {', '.join(unknown_fields)}" + + +def _is_allowed_query(value: Any) -> bool: + if not isinstance(value, str): + return False + return value.strip().rstrip(';').strip() == _ALLOWED_QUERY + + +def _normalize_host(value: str) -> str: + host = value.strip().lower() + if host.endswith('.'): + host = host[:-1] + if not host: + raise ValueError('host must be a non-empty string') + return host + + +def _normalize_port(value: Any) -> int: + if isinstance(value, bool): + raise ValueError('port must be an integer') + if isinstance(value, int): + port = value + elif isinstance(value, str): + if not value.isdigit(): + raise ValueError('port must be an integer') + port = int(value) + else: + raise ValueError('port must be an integer') + + if port <= 0 or port > 65535: + raise ValueError('port must be between 1 and 65535') + return port + + +def _positive_int(value: Any, field: str) -> int: + if isinstance(value, bool): + raise ValueError(f'{field} must be a positive integer') + if isinstance(value, int): + number = value + elif isinstance(value, str) and value.isdigit(): + number = int(value) + else: + raise ValueError(f'{field} must be a positive integer') + if number <= 0: + raise ValueError(f'{field} must be a positive integer') + return number + + +def _column_name(column: Any) -> str: + name = getattr(column, 'name', None) + if name is not None: + return str(name) + return str(column[0]) + + +def _row_to_dict(columns: list[str], row: Any) -> dict[str, Any]: + if isinstance(row, Mapping): + return {column: row[column] for column in columns} + return dict(zip(columns, row)) + + +def _infer_type(rows: list[dict[str, Any]], column: str) -> str: + for row in rows: + value = row.get(column) + if isinstance(value, bool): + return 'boolean' + if isinstance(value, int): + return 'integer' + if isinstance(value, float): + return 'number' + if value is not None: + return 'string' + return 'unknown' + + +def _error(code: str, message: str, retryable: bool = False) -> dict[str, Any]: + return {'status': 'FAILED', 'error': {'code': code, 'message': message, 'retryable': retryable}} diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py new file mode 100644 index 0000000000000..e44258e58cdcf --- /dev/null +++ b/postgres/tests/test_remote_query.py @@ -0,0 +1,243 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under Simplified BSD License (see LICENSE) + +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest + +from datadog_checks.postgres.remote_query import ( + StaticPostgresCheckRegistry, + execute_remote_query, + normalize_target, +) + + +class FakePool: + def __init__(self, rows=None, description=None, closed=False): + self.rows = rows or [(1,)] + self.description = description or [SimpleNamespace(name='value')] + self.closed = closed + self.requested_dbnames = [] + + def is_closed(self): + return self.closed + + @contextmanager + def get_connection(self, dbname): + self.requested_dbnames.append(dbname) + yield FakeConnection(self.rows, self.description) + + +class FakeConnection: + def __init__(self, rows, description): + self.rows = rows + self.description = description + + @contextmanager + def cursor(self): + yield FakeCursor(self.rows, self.description) + + +class FakeCursor: + def __init__(self, rows, description): + self.rows = rows + self.description = description + self.executed = None + + def execute(self, query): + self.executed = query + + def fetchmany(self, size): + return self.rows[:size] + + +def make_check(host='localhost', port='5432', dbname='datadog_test', pool=None, **metadata): + check = SimpleNamespace( + _config=SimpleNamespace(host=host, port=port, dbname=dbname, **metadata), + db_pool=pool or FakePool(), + ) + check.execute_query_raw = pytest.fail + check._run_query_scope = pytest.fail + check.data_observability = SimpleNamespace(run_job=pytest.fail) + return check + + +def valid_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): + request = { + 'target': {'host': host, 'port': port, 'dbname': dbname}, + 'query': 'SELECT 1 AS value', + 'limits': {'maxRows': 10, 'maxBytes': 1048576, 'timeoutMs': 5000}, + } + request.update(extra) + return request + + +class ExplodingRegistry: + def iter_postgres_checks(self): + pytest.fail('registry must not be iterated') + + +def response_code(response): + return response['error']['code'] + + +def test_normalize_target_trims_lowercases_host_and_removes_one_trailing_dot(): + target = normalize_target({'host': ' Example.INTERNAL. ', 'port': '5432', 'dbname': 'postgres'}) + + assert target.host == 'example.internal' + assert target.port == 5432 + assert target.dbname == 'postgres' + + +def test_normalize_target_defaults_missing_port_to_5432(): + target = normalize_target({'host': 'localhost', 'dbname': 'postgres'}) + + assert target.port == 5432 + + +@pytest.mark.parametrize('port', [True, 'abc', '0', 0, -1, 65536, None]) +def test_normalize_target_rejects_invalid_port_values(port): + with pytest.raises(ValueError): + normalize_target({'host': 'localhost', 'port': port, 'dbname': 'postgres'}) + + +@pytest.mark.parametrize( + 'target', + [ + {'host': '', 'port': 5432, 'dbname': 'postgres'}, + {'host': ' ', 'port': 5432, 'dbname': 'postgres'}, + {'host': 'localhost', 'port': 5432, 'dbname': ''}, + {'host': 'localhost', 'port': 5432, 'dbname': ' postgres '}, + ], +) +def test_normalize_target_rejects_empty_host_or_dbname(target): + with pytest.raises(ValueError): + normalize_target(target) + + +@pytest.mark.parametrize('field', ['extra', 'password']) +def test_rejects_unknown_request_fields_before_resolution(caplog, field): + request = valid_request(**{field: 'SECRET_DO_NOT_LOG'}) + + response = execute_remote_query(request, ExplodingRegistry()) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'invalid_request' + assert field in response['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(response) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + + +def test_rejects_unknown_target_fields_before_resolution(): + request = valid_request() + request['target']['password'] = 'SECRET_DO_NOT_LOG' + + response = execute_remote_query(request, ExplodingRegistry()) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'invalid_request' + assert 'password' in response['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(response) + + +def test_rejects_unknown_limits_fields_before_resolution(): + request = valid_request() + request['limits']['password'] = 'SECRET_DO_NOT_LOG' + + response = execute_remote_query(request, ExplodingRegistry()) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'invalid_request' + assert 'password' in response['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(response) + + +def test_resolve_matches_exact_host_port_dbname_from_check_config(): + pool = FakePool() + check = make_check(host='localhost', port='5432', dbname='datadog_test', pool=pool) + + response = execute_remote_query(valid_request(host='LOCALHOST.', port='5432'), StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'SUCCEEDED' + assert response['rows'] == [{'value': 1}] + assert pool.requested_dbnames == ['datadog_test'] + + +def test_resolve_requires_dbname_match_even_when_host_and_port_match(): + pool = FakePool() + check = make_check(host='localhost', port='5432', dbname='datadog_test', pool=pool) + + response = execute_remote_query(valid_request(dbname='postgres'), StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'target_not_found' + assert pool.requested_dbnames == [] + + +def test_resolve_ignores_metadata_identity_matches(): + pool = FakePool() + check = make_check( + host='configured.internal', + port='5432', + dbname='datadog_test', + pool=pool, + reported_hostname='reported.internal', + database_identifier='reported.internal', + ) + + response = execute_remote_query(valid_request(host='reported.internal'), StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'target_not_found' + assert pool.requested_dbnames == [] + + +def test_resolve_fails_ambiguous_duplicate_configs(): + first_pool = FakePool() + second_pool = FakePool() + checks = [make_check(pool=first_pool), make_check(pool=second_pool)] + + response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry(checks)) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'target_ambiguous' + assert first_pool.requested_dbnames == [] + assert second_pool.requested_dbnames == [] + + +def test_execute_sets_truncated_when_more_than_max_rows_returned(): + pool = FakePool(rows=[(1,), (2,)]) + check = make_check(pool=pool) + request = valid_request() + request['limits']['maxRows'] = 1 + + response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'SUCCEEDED' + assert response['rows'] == [{'value': 1}] + assert response['truncated'] is True + assert response['stats']['rowCount'] == 1 + + +def test_execute_closed_pool_returns_target_unavailable_without_recreating_credentials(): + pool = FakePool(closed=True) + check = make_check(pool=pool) + + response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'target_unavailable' + assert pool.requested_dbnames == [] + + +def test_execute_rejects_non_canonical_query_before_pool_access(): + pool = FakePool() + request = valid_request(query='SELECT current_database()') + + response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)])) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'query_rejected' + assert pool.requested_dbnames == [] diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py new file mode 100644 index 0000000000000..173d5fd9fddaf --- /dev/null +++ b/postgres/tests/test_remote_query_integration.py @@ -0,0 +1,33 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under Simplified BSD License (see LICENSE) + +import json + +import pytest + +from datadog_checks.postgres.remote_query import StaticPostgresCheckRegistry, execute_remote_query + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): + check = integration_check(pg_instance) + request = { + 'target': { + 'host': pg_instance['host'], + 'port': int(pg_instance['port']), + 'dbname': pg_instance['dbname'], + }, + 'query': 'SELECT 1 AS value', + 'limits': {'maxRows': 10, 'maxBytes': 1048576, 'timeoutMs': 5000}, + } + + response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'SUCCEEDED' + assert response['columns'][0]['name'] == 'value' + assert response['rows'] == [{'value': 1}] + assert response['truncated'] is False + assert response['stats']['rowCount'] == 1 + assert 'password' not in json.dumps(request).lower() From aa72bd46a8bf9b4c13fe311953c75d369d1d3f6b Mon Sep 17 00:00:00 2001 From: nubtron Date: Tue, 28 Apr 2026 16:22:13 +0000 Subject: [PATCH 02/11] Simplify Postgres remote query typed request parsing --- .../datadog_checks/postgres/remote_query.py | 70 ++++++------------- postgres/tests/test_remote_query.py | 26 +++++-- 2 files changed, 39 insertions(+), 57 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index 506713ba843be..f0eadc01da50a 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -21,10 +21,6 @@ _LIMIT_FIELDS = frozenset({'maxRows', 'maxBytes', 'timeoutMs'}) -class UnknownFieldsError(ValueError): - pass - - @dataclass(frozen=True) class RemoteQueryTarget: host: str @@ -42,7 +38,6 @@ class RemoteQueryLimits: @dataclass(frozen=True) class RemoteQueryRequest: target: RemoteQueryTarget - query: str limits: RemoteQueryLimits @@ -78,8 +73,6 @@ def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegi def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: - _reject_unknown_fields(target, _TARGET_FIELDS, 'target') - host = target.get('host') if not isinstance(host, str) or not host.strip(): raise ValueError('host must be a non-empty string') @@ -90,7 +83,9 @@ def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: if dbname != dbname.strip(): raise ValueError('dbname must not contain surrounding whitespace') - return RemoteQueryTarget(host=_normalize_host(host), port=_normalize_port(target.get('port', 5432)), dbname=dbname) + return RemoteQueryTarget( + host=_normalize_host(host), port=_int_in_range(target.get('port', 5432), 'port', maximum=65535), dbname=dbname + ) def _parse_request(value: Any) -> RemoteQueryRequest | dict[str, Any]: @@ -112,17 +107,19 @@ def _parse_request(value: Any) -> RemoteQueryRequest | dict[str, Any]: if isinstance(limits_or_error, dict): return limits_or_error - return RemoteQueryRequest(target=target_or_error, query=_ALLOWED_QUERY, limits=limits_or_error) + return RemoteQueryRequest(target=target_or_error, limits=limits_or_error) def _parse_target(value: Any) -> RemoteQueryTarget | dict[str, Any]: if not isinstance(value, Mapping): return _error('invalid_selector', 'Target selector must be a mapping.') + unknown_fields_error = _unknown_fields_error(value, _TARGET_FIELDS, 'target') + if unknown_fields_error is not None: + return unknown_fields_error + try: return normalize_target(value) - except UnknownFieldsError as e: - return _error('invalid_request', str(e)) except ValueError as e: return _error('invalid_selector', str(e)) @@ -139,9 +136,9 @@ def _parse_limits(value: Any) -> RemoteQueryLimits | dict[str, Any]: try: return RemoteQueryLimits( - max_rows=_positive_int(value.get('maxRows', 10), 'maxRows'), - max_bytes=_positive_int(value.get('maxBytes', 1_048_576), 'maxBytes'), - timeout_ms=_positive_int(value.get('timeoutMs', 5_000), 'timeoutMs'), + max_rows=_int_in_range(value.get('maxRows', 10), 'maxRows'), + max_bytes=_int_in_range(value.get('maxBytes', 1_048_576), 'maxBytes'), + timeout_ms=_int_in_range(value.get('timeoutMs', 5_000), 'timeoutMs'), ) except ValueError as e: return _error('invalid_request', str(e)) @@ -156,7 +153,7 @@ def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) try: candidate = RemoteQueryTarget( host=_normalize_host(config.host), - port=_normalize_port(config.port), + port=_int_in_range(config.port, 'port', maximum=65535), dbname=config.dbname, ) except (AttributeError, ValueError): @@ -201,12 +198,6 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re } -def _reject_unknown_fields(value: Mapping[str, Any], allowed_fields: frozenset[str], label: str) -> None: - unknown_fields = _unknown_field_names(value, allowed_fields) - if unknown_fields: - raise UnknownFieldsError(_unknown_fields_message(unknown_fields, label)) - - def _unknown_fields_error( value: Mapping[str, Any], allowed_fields: frozenset[str], label: str ) -> dict[str, Any] | None: @@ -240,35 +231,14 @@ def _normalize_host(value: str) -> str: return host -def _normalize_port(value: Any) -> int: - if isinstance(value, bool): - raise ValueError('port must be an integer') - if isinstance(value, int): - port = value - elif isinstance(value, str): - if not value.isdigit(): - raise ValueError('port must be an integer') - port = int(value) - else: - raise ValueError('port must be an integer') - - if port <= 0 or port > 65535: - raise ValueError('port must be between 1 and 65535') - return port - - -def _positive_int(value: Any, field: str) -> int: - if isinstance(value, bool): - raise ValueError(f'{field} must be a positive integer') - if isinstance(value, int): - number = value - elif isinstance(value, str) and value.isdigit(): - number = int(value) - else: - raise ValueError(f'{field} must be a positive integer') - if number <= 0: - raise ValueError(f'{field} must be a positive integer') - return number +def _int_in_range(value: Any, field: str, *, minimum: int = 1, maximum: int | None = None) -> int: + if not isinstance(value, int) or isinstance(value, bool): + raise ValueError(f'{field} must be an integer') + if value < minimum or (maximum is not None and value > maximum): + if maximum is None: + raise ValueError(f'{field} must be greater than or equal to {minimum}') + raise ValueError(f'{field} must be between {minimum} and {maximum}') + return value def _column_name(column: Any) -> str: diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index e44258e58cdcf..b034b742c5d95 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -53,7 +53,7 @@ def fetchmany(self, size): return self.rows[:size] -def make_check(host='localhost', port='5432', dbname='datadog_test', pool=None, **metadata): +def make_check(host='localhost', port=5432, dbname='datadog_test', pool=None, **metadata): check = SimpleNamespace( _config=SimpleNamespace(host=host, port=port, dbname=dbname, **metadata), db_pool=pool or FakePool(), @@ -84,7 +84,7 @@ def response_code(response): def test_normalize_target_trims_lowercases_host_and_removes_one_trailing_dot(): - target = normalize_target({'host': ' Example.INTERNAL. ', 'port': '5432', 'dbname': 'postgres'}) + target = normalize_target({'host': ' Example.INTERNAL. ', 'port': 5432, 'dbname': 'postgres'}) assert target.host == 'example.internal' assert target.port == 5432 @@ -97,7 +97,7 @@ def test_normalize_target_defaults_missing_port_to_5432(): assert target.port == 5432 -@pytest.mark.parametrize('port', [True, 'abc', '0', 0, -1, 65536, None]) +@pytest.mark.parametrize('port', [True, '5432', 'abc', '0', 0, -1, 65536, None]) def test_normalize_target_rejects_invalid_port_values(port): with pytest.raises(ValueError): normalize_target({'host': 'localhost', 'port': port, 'dbname': 'postgres'}) @@ -154,11 +154,23 @@ def test_rejects_unknown_limits_fields_before_resolution(): assert 'SECRET_DO_NOT_LOG' not in str(response) +@pytest.mark.parametrize('field', ['maxRows', 'maxBytes', 'timeoutMs']) +def test_rejects_string_limit_values_before_resolution(field): + request = valid_request() + request['limits'][field] = '10' + + response = execute_remote_query(request, ExplodingRegistry()) + + assert response['status'] == 'FAILED' + assert response_code(response) == 'invalid_request' + assert field in response['error']['message'] + + def test_resolve_matches_exact_host_port_dbname_from_check_config(): pool = FakePool() - check = make_check(host='localhost', port='5432', dbname='datadog_test', pool=pool) + check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) - response = execute_remote_query(valid_request(host='LOCALHOST.', port='5432'), StaticPostgresCheckRegistry([check])) + response = execute_remote_query(valid_request(host='LOCALHOST.', port=5432), StaticPostgresCheckRegistry([check])) assert response['status'] == 'SUCCEEDED' assert response['rows'] == [{'value': 1}] @@ -167,7 +179,7 @@ def test_resolve_matches_exact_host_port_dbname_from_check_config(): def test_resolve_requires_dbname_match_even_when_host_and_port_match(): pool = FakePool() - check = make_check(host='localhost', port='5432', dbname='datadog_test', pool=pool) + check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) response = execute_remote_query(valid_request(dbname='postgres'), StaticPostgresCheckRegistry([check])) @@ -180,7 +192,7 @@ def test_resolve_ignores_metadata_identity_matches(): pool = FakePool() check = make_check( host='configured.internal', - port='5432', + port=5432, dbname='datadog_test', pool=pool, reported_hostname='reported.internal', From dda18ae29523cd895e67f2f704b3e34c223ba491 Mon Sep 17 00:00:00 2001 From: nubtron Date: Tue, 28 Apr 2026 16:28:00 +0000 Subject: [PATCH 03/11] Simplify Postgres remote query result shaping --- .../datadog_checks/postgres/remote_query.py | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index f0eadc01da50a..50309805528bf 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -176,7 +176,6 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re cursor.execute(_ALLOWED_QUERY) if cursor.description is None: return _error('query_failed', 'Query did not return a result set.') - columns = [_column_name(column) for column in cursor.description] raw_rows = cursor.fetchmany(limits.max_rows + 1) except RuntimeError: return _error('target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False) @@ -185,8 +184,8 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re return _error('query_failed', 'Remote query execution failed.') truncated = len(raw_rows) > limits.max_rows - rows = [_row_to_dict(columns, row) for row in raw_rows[: limits.max_rows]] - response_columns = [{'name': name, 'type': _infer_type(rows, name)} for name in columns] + rows = [{'value': row[0]} for row in raw_rows[: limits.max_rows]] + response_columns = [{'name': 'value', 'type': 'integer'}] bytes_returned = len(json.dumps({'columns': response_columns, 'rows': rows}, default=str).encode('utf-8')) return { @@ -241,32 +240,5 @@ def _int_in_range(value: Any, field: str, *, minimum: int = 1, maximum: int | No return value -def _column_name(column: Any) -> str: - name = getattr(column, 'name', None) - if name is not None: - return str(name) - return str(column[0]) - - -def _row_to_dict(columns: list[str], row: Any) -> dict[str, Any]: - if isinstance(row, Mapping): - return {column: row[column] for column in columns} - return dict(zip(columns, row)) - - -def _infer_type(rows: list[dict[str, Any]], column: str) -> str: - for row in rows: - value = row.get(column) - if isinstance(value, bool): - return 'boolean' - if isinstance(value, int): - return 'integer' - if isinstance(value, float): - return 'number' - if value is not None: - return 'string' - return 'unknown' - - def _error(code: str, message: str, retryable: bool = False) -> dict[str, Any]: return {'status': 'FAILED', 'error': {'code': code, 'message': message, 'retryable': retryable}} From c3390f5dd881d8c61a71a57a0a4a0d1073a9bdbf Mon Sep 17 00:00:00 2001 From: nubtron Date: Wed, 29 Apr 2026 14:16:48 +0000 Subject: [PATCH 04/11] Regenerate labeler config --- .github/workflows/config/labeler.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/config/labeler.yml b/.github/workflows/config/labeler.yml index e440a28f2f028..63f7249b4c2e1 100644 --- a/.github/workflows/config/labeler.yml +++ b/.github/workflows/config/labeler.yml @@ -890,10 +890,6 @@ integration/langchain: - changed-files: - any-glob-to-any-file: - langchain/**/* -integration/lparstats: -- changed-files: - - any-glob-to-any-file: - - lparstats/**/* integration/lastpass: - changed-files: - any-glob-to-any-file: @@ -918,6 +914,10 @@ integration/litellm: - changed-files: - any-glob-to-any-file: - litellm/**/* +integration/lparstats: +- changed-files: + - any-glob-to-any-file: + - lparstats/**/* integration/lustre: - changed-files: - any-glob-to-any-file: @@ -1002,10 +1002,6 @@ integration/nagios: - changed-files: - any-glob-to-any-file: - nagios/**/* -integration/nifi: -- changed-files: - - any-glob-to-any-file: - - nifi/**/* integration/network: - changed-files: - any-glob-to-any-file: @@ -1026,6 +1022,10 @@ integration/nginx_ingress_controller: - changed-files: - any-glob-to-any-file: - nginx_ingress_controller/**/* +integration/nifi: +- changed-files: + - any-glob-to-any-file: + - nifi/**/* integration/ntp: - changed-files: - any-glob-to-any-file: From cc5dbb48de8f5819789d24770fa2aef54f6e9eb3 Mon Sep 17 00:00:00 2001 From: nubtron Date: Wed, 29 Apr 2026 14:16:52 +0000 Subject: [PATCH 05/11] Update Postgres remote query license headers --- postgres/datadog_checks/postgres/remote_query.py | 2 +- postgres/tests/test_remote_query.py | 2 +- postgres/tests/test_remote_query_integration.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index 50309805528bf..10c2317316fab 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -1,6 +1,6 @@ # (C) Datadog, Inc. 2026-present # All rights reserved -# Licensed under Simplified BSD License (see LICENSE) +# Licensed under a 3-clause BSD style license (see LICENSE) from __future__ import annotations diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index b034b742c5d95..0ddda99c50761 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -1,6 +1,6 @@ # (C) Datadog, Inc. 2026-present # All rights reserved -# Licensed under Simplified BSD License (see LICENSE) +# Licensed under a 3-clause BSD style license (see LICENSE) from contextlib import contextmanager from types import SimpleNamespace diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py index 173d5fd9fddaf..2ad003cc2ffd0 100644 --- a/postgres/tests/test_remote_query_integration.py +++ b/postgres/tests/test_remote_query_integration.py @@ -1,6 +1,6 @@ # (C) Datadog, Inc. 2026-present # All rights reserved -# Licensed under Simplified BSD License (see LICENSE) +# Licensed under a 3-clause BSD style license (see LICENSE) import json From 92ac25caeb57c24419c8f04d560babf445b2096d Mon Sep 17 00:00:00 2001 From: nubtron Date: Thu, 30 Apr 2026 08:03:26 +0000 Subject: [PATCH 06/11] Fix Postgres remote query changelog number --- .../changelog.d/{23476.added => 23499.added} | 0 .../datadog_checks/postgres/remote_query.py | 169 +++++++----------- postgres/tests/test_remote_query.py | 8 +- 3 files changed, 66 insertions(+), 111 deletions(-) rename postgres/changelog.d/{23476.added => 23499.added} (100%) diff --git a/postgres/changelog.d/23476.added b/postgres/changelog.d/23499.added similarity index 100% rename from postgres/changelog.d/23476.added rename to postgres/changelog.d/23499.added diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index 10c2317316fab..c283103362a8d 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -8,7 +8,9 @@ import logging from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol + +from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator if TYPE_CHECKING: from datadog_checks.postgres import PostgreSql @@ -16,29 +18,49 @@ LOGGER = logging.getLogger(__name__) _ALLOWED_QUERY = 'SELECT 1 AS value' -_REQUEST_FIELDS = frozenset({'target', 'query', 'limits'}) -_TARGET_FIELDS = frozenset({'host', 'port', 'dbname'}) -_LIMIT_FIELDS = frozenset({'maxRows', 'maxBytes', 'timeoutMs'}) -@dataclass(frozen=True) -class RemoteQueryTarget: - host: str - port: int - dbname: str +class RemoteQueryTarget(BaseModel): + model_config = ConfigDict(extra='forbid', frozen=True) + host: StrictStr = Field(min_length=1) + port: StrictInt = Field(default=5432, ge=1, le=65535) + dbname: StrictStr = Field(min_length=1) -@dataclass(frozen=True) -class RemoteQueryLimits: - max_rows: int = 10 - max_bytes: int = 1_048_576 - timeout_ms: int = 5_000 + @field_validator('host') + @classmethod + def normalize_host(cls, value: str) -> str: + host = value.strip().lower() + if host.endswith('.'): + host = host[:-1] + if not host: + raise ValueError('host must be a non-empty string') + return host + @field_validator('dbname') + @classmethod + def validate_dbname(cls, value: str) -> str: + if not value: + raise ValueError('dbname must be a non-empty string') + if value != value.strip(): + raise ValueError('dbname must not contain surrounding whitespace') + return value + + +class RemoteQueryLimits(BaseModel): + model_config = ConfigDict(extra='forbid', frozen=True) + + max_rows: StrictInt = Field(default=10, alias='maxRows', ge=1) + max_bytes: StrictInt = Field(default=1_048_576, alias='maxBytes', ge=1) + timeout_ms: StrictInt = Field(default=5_000, alias='timeoutMs', ge=1) + + +class RemoteQueryRequest(BaseModel): + model_config = ConfigDict(extra='forbid', frozen=True) -@dataclass(frozen=True) -class RemoteQueryRequest: target: RemoteQueryTarget - limits: RemoteQueryLimits + query: Literal['SELECT 1 AS value'] + limits: RemoteQueryLimits = Field(default_factory=RemoteQueryLimits) @dataclass(frozen=True) @@ -53,12 +75,12 @@ class PostgresCheckRegistry(Protocol): def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... -def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegistry) -> dict[str, Any]: - request_or_error = _parse_request(request) - if isinstance(request_or_error, dict): - return request_or_error +def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[str, Any]: + try: + parsed_request = RemoteQueryRequest.model_validate(request) + except ValidationError as e: + return _validation_error(e) - parsed_request = request_or_error target = parsed_request.target limits = parsed_request.limits @@ -73,75 +95,10 @@ def execute_remote_query(request: Mapping[str, Any], registry: PostgresCheckRegi def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: - host = target.get('host') - if not isinstance(host, str) or not host.strip(): - raise ValueError('host must be a non-empty string') - - dbname = target.get('dbname') - if not isinstance(dbname, str) or not dbname: - raise ValueError('dbname must be a non-empty string') - if dbname != dbname.strip(): - raise ValueError('dbname must not contain surrounding whitespace') - - return RemoteQueryTarget( - host=_normalize_host(host), port=_int_in_range(target.get('port', 5432), 'port', maximum=65535), dbname=dbname - ) - - -def _parse_request(value: Any) -> RemoteQueryRequest | dict[str, Any]: - if not isinstance(value, Mapping): - return _error('invalid_request', 'Remote query request must be a mapping.') - - unknown_fields_error = _unknown_fields_error(value, _REQUEST_FIELDS, 'request') - if unknown_fields_error is not None: - return unknown_fields_error - - target_or_error = _parse_target(value.get('target')) - if isinstance(target_or_error, dict): - return target_or_error - - if not _is_allowed_query(value.get('query')): - return _error('query_rejected', 'Only the canonical SELECT 1 proof query is allowed.') - - limits_or_error = _parse_limits(value.get('limits', {})) - if isinstance(limits_or_error, dict): - return limits_or_error - - return RemoteQueryRequest(target=target_or_error, limits=limits_or_error) - - -def _parse_target(value: Any) -> RemoteQueryTarget | dict[str, Any]: - if not isinstance(value, Mapping): - return _error('invalid_selector', 'Target selector must be a mapping.') - - unknown_fields_error = _unknown_fields_error(value, _TARGET_FIELDS, 'target') - if unknown_fields_error is not None: - return unknown_fields_error - try: - return normalize_target(value) - except ValueError as e: - return _error('invalid_selector', str(e)) - - -def _parse_limits(value: Any) -> RemoteQueryLimits | dict[str, Any]: - if value is None: - value = {} - if not isinstance(value, Mapping): - return _error('invalid_request', 'Limits must be a mapping.') - - unknown_fields_error = _unknown_fields_error(value, _LIMIT_FIELDS, 'limits') - if unknown_fields_error is not None: - return unknown_fields_error - - try: - return RemoteQueryLimits( - max_rows=_int_in_range(value.get('maxRows', 10), 'maxRows'), - max_bytes=_int_in_range(value.get('maxBytes', 1_048_576), 'maxBytes'), - timeout_ms=_int_in_range(value.get('timeoutMs', 5_000), 'timeoutMs'), - ) - except ValueError as e: - return _error('invalid_request', str(e)) + return RemoteQueryTarget.model_validate(target) + except ValidationError as e: + raise ValueError(_validation_message(e)) from e def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']: @@ -197,28 +154,24 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re } -def _unknown_fields_error( - value: Mapping[str, Any], allowed_fields: frozenset[str], label: str -) -> dict[str, Any] | None: - unknown_fields = _unknown_field_names(value, allowed_fields) - if unknown_fields: - return _error('invalid_request', _unknown_fields_message(unknown_fields, label)) - return None - - -def _unknown_field_names(value: Mapping[str, Any], allowed_fields: frozenset[str]) -> list[str]: - return sorted(str(field) for field in value if field not in allowed_fields) +def _validation_error(error: ValidationError) -> dict[str, Any]: + return _error('invalid_request', _validation_message(error)) -def _unknown_fields_message(unknown_fields: list[str], label: str) -> str: - field_label = 'field' if len(unknown_fields) == 1 else 'fields' - return f"{label} contains unknown {field_label}: {', '.join(unknown_fields)}" +def _validation_message(error: ValidationError) -> str: + details = [] + for item in error.errors(include_input=False): + location = _validation_location(item.get('loc', ())) + message = item.get('msg', 'Invalid value') + if location: + details.append(f'{location}: {message}') + else: + details.append(message) + return 'Invalid remote query request: {}'.format('; '.join(details)) -def _is_allowed_query(value: Any) -> bool: - if not isinstance(value, str): - return False - return value.strip().rstrip(';').strip() == _ALLOWED_QUERY +def _validation_location(location: tuple[Any, ...]) -> str: + return '.'.join(str(part) for part in location) def _normalize_host(value: str) -> str: diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 0ddda99c50761..3d9831166c67d 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -244,12 +244,14 @@ def test_execute_closed_pool_returns_target_unavailable_without_recreating_crede assert pool.requested_dbnames == [] -def test_execute_rejects_non_canonical_query_before_pool_access(): +@pytest.mark.parametrize('query', ['SELECT current_database()', 'SELECT 1 AS value;', ' SELECT 1 AS value']) +def test_execute_rejects_non_canonical_query_before_pool_access(query): pool = FakePool() - request = valid_request(query='SELECT current_database()') + request = valid_request(query=query) response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)])) assert response['status'] == 'FAILED' - assert response_code(response) == 'query_rejected' + assert response_code(response) == 'invalid_request' + assert 'query' in response['error']['message'] assert pool.requested_dbnames == [] From f3b3d43707a6b2f21774a4b9d1d5d70241687e56 Mon Sep 17 00:00:00 2001 From: nubtron Date: Sat, 2 May 2026 18:17:55 +0000 Subject: [PATCH 07/11] Allow fixture table remote query --- .../datadog_checks/postgres/remote_query.py | 128 +++++++----- postgres/tests/test_remote_query.py | 182 ++++++++++++++---- .../tests/test_remote_query_integration.py | 36 +++- 3 files changed, 259 insertions(+), 87 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index c283103362a8d..b05734addd316 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -12,13 +12,13 @@ from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator +RemoteQuerySql = Literal['SELECT 1 AS value', 'SELECT city, country FROM cities ORDER BY city'] + if TYPE_CHECKING: from datadog_checks.postgres import PostgreSql LOGGER = logging.getLogger(__name__) -_ALLOWED_QUERY = 'SELECT 1 AS value' - class RemoteQueryTarget(BaseModel): model_config = ConfigDict(extra='forbid', frozen=True) @@ -48,6 +48,8 @@ def validate_dbname(cls, value: str) -> str: class RemoteQueryLimits(BaseModel): + """Validate the future-facing limits contract for the initial safe query slice.""" + model_config = ConfigDict(extra='forbid', frozen=True) max_rows: StrictInt = Field(default=10, alias='maxRows', ge=1) @@ -56,10 +58,12 @@ class RemoteQueryLimits(BaseModel): class RemoteQueryRequest(BaseModel): + """Accept only exact proof queries until broader SQL execution is implemented.""" + model_config = ConfigDict(extra='forbid', frozen=True) target: RemoteQueryTarget - query: Literal['SELECT 1 AS value'] + query: RemoteQuerySql limits: RemoteQueryLimits = Field(default_factory=RemoteQueryLimits) @@ -75,6 +79,20 @@ class PostgresCheckRegistry(Protocol): def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... +def execute_agent_rpc_json(request_json: str | bytes | bytearray, check: 'PostgreSql') -> str: + try: + request = json.loads(request_json) + except (TypeError, ValueError): + response = _error('invalid_request', 'Invalid remote query request: request_json must be a valid JSON object.') + else: + if not isinstance(request, Mapping): + response = _error('invalid_request', 'Invalid remote query request: request_json must be a JSON object.') + else: + response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) + + return json.dumps(response, default=str) + + def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[str, Any]: try: parsed_request = RemoteQueryRequest.model_validate(request) @@ -83,6 +101,7 @@ def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[ target = parsed_request.target limits = parsed_request.limits + query = parsed_request.query matches = _resolve_matches(target, registry.iter_postgres_checks()) LOGGER.debug('Remote query target match count: %d', len(matches)) @@ -91,7 +110,7 @@ def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[ if len(matches) > 1: return _error('target_ambiguous', 'More than one loaded Postgres integration instance matched target selector.') - return _execute_select_1(matches[0], target, limits) + return _execute_safe_query(matches[0], target, query, limits) def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: @@ -102,25 +121,23 @@ def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']: - matches = [] - for check in checks: - config = getattr(check, '_config', None) - if config is None: - continue - try: - candidate = RemoteQueryTarget( - host=_normalize_host(config.host), - port=_int_in_range(config.port, 'port', maximum=65535), - dbname=config.dbname, - ) - except (AttributeError, ValueError): - continue - if candidate == target: - matches.append(check) - return matches - - -def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: RemoteQueryLimits) -> dict[str, Any]: + return [check for check in checks if _target_from_check(check) == target] + + +def _target_from_check(check: 'PostgreSql') -> RemoteQueryTarget | None: + config = getattr(check, '_config', None) + if config is None: + return None + + try: + return RemoteQueryTarget(host=config.host, port=config.port, dbname=config.dbname) + except (AttributeError, ValidationError): + return None + + +def _execute_safe_query( + check: 'PostgreSql', target: RemoteQueryTarget, query: RemoteQuerySql, limits: RemoteQueryLimits +) -> dict[str, Any]: db_pool = getattr(check, 'db_pool', None) if db_pool is None: return _error('credentials_unavailable', 'Matched Postgres check does not expose a connection pool.') @@ -130,8 +147,9 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re try: with db_pool.get_connection(target.dbname) as conn: with conn.cursor() as cursor: - cursor.execute(_ALLOWED_QUERY) - if cursor.description is None: + cursor.execute(query) + description = cursor.description + if description is None: return _error('query_failed', 'Query did not return a result set.') raw_rows = cursor.fetchmany(limits.max_rows + 1) except RuntimeError: @@ -140,9 +158,10 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re LOGGER.exception('Remote query execution failed') return _error('query_failed', 'Remote query execution failed.') + # max_bytes and timeout_ms are validated for the API contract but enforced in a follow-up slice. truncated = len(raw_rows) > limits.max_rows - rows = [{'value': row[0]} for row in raw_rows[: limits.max_rows]] - response_columns = [{'name': 'value', 'type': 'integer'}] + response_columns = _response_columns(description, raw_rows) + rows = [_response_row(response_columns, row) for row in raw_rows[: limits.max_rows]] bytes_returned = len(json.dumps({'columns': response_columns, 'rows': rows}, default=str).encode('utf-8')) return { @@ -154,6 +173,42 @@ def _execute_select_1(check: 'PostgreSql', target: RemoteQueryTarget, limits: Re } +def _response_columns(description: Sequence[Any], rows: Sequence[Sequence[Any]]) -> list[dict[str, str]]: + return [ + {'name': _column_name(column), 'type': _column_type(index, rows)} for index, column in enumerate(description) + ] + + +def _column_name(column: Any) -> str: + name = getattr(column, 'name', None) + if name is not None: + return str(name) + return str(column[0]) + + +def _column_type(index: int, rows: Sequence[Sequence[Any]]) -> str: + for row in rows: + if row[index] is not None: + return _value_type(row[index]) + return 'unknown' + + +def _value_type(value: Any) -> str: + if isinstance(value, bool): + return 'boolean' + if isinstance(value, int): + return 'integer' + if isinstance(value, float): + return 'number' + if isinstance(value, str): + return 'string' + return type(value).__name__ + + +def _response_row(columns: Sequence[Mapping[str, str]], row: Sequence[Any]) -> dict[str, Any]: + return {column['name']: row[index] for index, column in enumerate(columns)} + + def _validation_error(error: ValidationError) -> dict[str, Any]: return _error('invalid_request', _validation_message(error)) @@ -174,24 +229,5 @@ def _validation_location(location: tuple[Any, ...]) -> str: return '.'.join(str(part) for part in location) -def _normalize_host(value: str) -> str: - host = value.strip().lower() - if host.endswith('.'): - host = host[:-1] - if not host: - raise ValueError('host must be a non-empty string') - return host - - -def _int_in_range(value: Any, field: str, *, minimum: int = 1, maximum: int | None = None) -> int: - if not isinstance(value, int) or isinstance(value, bool): - raise ValueError(f'{field} must be an integer') - if value < minimum or (maximum is not None and value > maximum): - if maximum is None: - raise ValueError(f'{field} must be greater than or equal to {minimum}') - raise ValueError(f'{field} must be between {minimum} and {maximum}') - return value - - def _error(code: str, message: str, retryable: bool = False) -> dict[str, Any]: return {'status': 'FAILED', 'error': {'code': code, 'message': message, 'retryable': retryable}} diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 3d9831166c67d..7e1352de2f50c 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -2,6 +2,7 @@ # All rights reserved # Licensed under a 3-clause BSD style license (see LICENSE) +import json from contextlib import contextmanager from types import SimpleNamespace @@ -9,6 +10,7 @@ from datadog_checks.postgres.remote_query import ( StaticPostgresCheckRegistry, + execute_agent_rpc_json, execute_remote_query, normalize_target, ) @@ -54,10 +56,13 @@ def fetchmany(self, size): def make_check(host='localhost', port=5432, dbname='datadog_test', pool=None, **metadata): - check = SimpleNamespace( + return SimpleNamespace( _config=SimpleNamespace(host=host, port=port, dbname=dbname, **metadata), db_pool=pool or FakePool(), ) + + +def block_existing_query_helpers(check): check.execute_query_raw = pytest.fail check._run_query_scope = pytest.fail check.data_observability = SimpleNamespace(run_job=pytest.fail) @@ -79,8 +84,18 @@ def iter_postgres_checks(self): pytest.fail('registry must not be iterated') -def response_code(response): - return response['error']['code'] +def assert_failed(response, code, message_contains=None): + assert response['status'] == 'FAILED' + assert response['error']['code'] == code + if message_contains is not None: + assert message_contains in response['error']['message'] + + +def execute_agent_rpc_response(request_json, check): + response_json = execute_agent_rpc_json(request_json, check) + + assert isinstance(response_json, str) + return json.loads(response_json) def test_normalize_target_trims_lowercases_host_and_removes_one_trailing_dot(): @@ -123,9 +138,7 @@ def test_rejects_unknown_request_fields_before_resolution(caplog, field): response = execute_remote_query(request, ExplodingRegistry()) - assert response['status'] == 'FAILED' - assert response_code(response) == 'invalid_request' - assert field in response['error']['message'] + assert_failed(response, 'invalid_request', field) assert 'SECRET_DO_NOT_LOG' not in str(response) assert 'SECRET_DO_NOT_LOG' not in caplog.text @@ -136,9 +149,7 @@ def test_rejects_unknown_target_fields_before_resolution(): response = execute_remote_query(request, ExplodingRegistry()) - assert response['status'] == 'FAILED' - assert response_code(response) == 'invalid_request' - assert 'password' in response['error']['message'] + assert_failed(response, 'invalid_request', 'password') assert 'SECRET_DO_NOT_LOG' not in str(response) @@ -148,9 +159,7 @@ def test_rejects_unknown_limits_fields_before_resolution(): response = execute_remote_query(request, ExplodingRegistry()) - assert response['status'] == 'FAILED' - assert response_code(response) == 'invalid_request' - assert 'password' in response['error']['message'] + assert_failed(response, 'invalid_request', 'password') assert 'SECRET_DO_NOT_LOG' not in str(response) @@ -161,9 +170,76 @@ def test_rejects_string_limit_values_before_resolution(field): response = execute_remote_query(request, ExplodingRegistry()) - assert response['status'] == 'FAILED' - assert response_code(response) == 'invalid_request' - assert field in response['error']['message'] + assert_failed(response, 'invalid_request', field) + + +@pytest.mark.parametrize( + 'request_json', + [ + json.dumps(valid_request()), + json.dumps(valid_request()).encode(), + bytearray(json.dumps(valid_request()), 'utf-8'), + ], +) +def test_agent_rpc_json_accepts_json_request_text_and_live_check(request_json): + pool = FakePool() + check = make_check(pool=pool) + + response = execute_agent_rpc_response(request_json, check) + + assert response['status'] == 'SUCCEEDED' + assert response['rows'] == [{'value': 1}] + assert pool.requested_dbnames == ['datadog_test'] + + +@pytest.mark.parametrize('request_json', ['{"password": "SECRET_DO_NOT_LOG"', b'\xff']) +def test_agent_rpc_json_rejects_malformed_json_without_echoing_input(caplog, request_json): + pool = FakePool() + + response = execute_agent_rpc_response(request_json, make_check(pool=pool)) + + assert_failed(response, 'invalid_request', 'request_json') + assert 'SECRET_DO_NOT_LOG' not in str(response) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + assert pool.requested_dbnames == [] + + +@pytest.mark.parametrize('request_json', ['[]', 'null', '"SECRET_DO_NOT_LOG"', '1']) +def test_agent_rpc_json_rejects_non_object_json_without_echoing_input(request_json): + pool = FakePool() + + response = execute_agent_rpc_response(request_json, make_check(pool=pool)) + + assert_failed(response, 'invalid_request', 'JSON object') + assert 'SECRET_DO_NOT_LOG' not in str(response) + assert pool.requested_dbnames == [] + + +def test_agent_rpc_json_reuses_strict_validation_for_request_shape(): + pool = FakePool() + request = valid_request(password='SECRET_DO_NOT_LOG') + + response = execute_agent_rpc_response(json.dumps(request), make_check(pool=pool)) + + assert_failed(response, 'invalid_request', 'password') + assert 'SECRET_DO_NOT_LOG' not in str(response) + assert pool.requested_dbnames == [] + + +def test_agent_rpc_json_uses_only_supplied_live_check_for_target_matching(): + matching_pool = FakePool() + non_matching_pool = FakePool() + request_json = json.dumps(valid_request(host='configured.internal')) + + response = execute_agent_rpc_response(request_json, make_check(host='localhost', pool=non_matching_pool)) + + assert_failed(response, 'target_not_found') + assert non_matching_pool.requested_dbnames == [] + + response = execute_agent_rpc_response(request_json, make_check(host='configured.internal', pool=matching_pool)) + + assert response['status'] == 'SUCCEEDED' + assert matching_pool.requested_dbnames == ['datadog_test'] def test_resolve_matches_exact_host_port_dbname_from_check_config(): @@ -173,18 +249,60 @@ def test_resolve_matches_exact_host_port_dbname_from_check_config(): response = execute_remote_query(valid_request(host='LOCALHOST.', port=5432), StaticPostgresCheckRegistry([check])) assert response['status'] == 'SUCCEEDED' + assert response['columns'] == [{'name': 'value', 'type': 'integer'}] assert response['rows'] == [{'value': 1}] assert pool.requested_dbnames == ['datadog_test'] +def test_execute_accepts_fixture_table_query_and_serializes_result_rows(): + pool = FakePool( + rows=[('Beautiful city of lights', 'France'), ('New York', 'USA')], + description=[SimpleNamespace(name='city'), SimpleNamespace(name='country')], + ) + check = make_check(pool=pool) + + response = execute_remote_query( + valid_request(query='SELECT city, country FROM cities ORDER BY city'), StaticPostgresCheckRegistry([check]) + ) + + assert response['status'] == 'SUCCEEDED' + assert response['columns'] == [{'name': 'city', 'type': 'string'}, {'name': 'country', 'type': 'string'}] + assert response['rows'] == [ + {'city': 'Beautiful city of lights', 'country': 'France'}, + {'city': 'New York', 'country': 'USA'}, + ] + assert response['truncated'] is False + assert response['stats']['rowCount'] == 2 + assert pool.requested_dbnames == ['datadog_test'] + + +@pytest.mark.parametrize( + 'query', + [ + 'SELECT current_database()', + 'SELECT 1 AS value;', + ' SELECT 1 AS value', + 'SELECT city, country FROM cities ORDER BY city;', + 'SELECT country, city FROM cities ORDER BY city', + ], +) +def test_execute_rejects_non_canonical_query_before_pool_access(query): + pool = FakePool() + request = valid_request(query=query) + + response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)])) + + assert_failed(response, 'invalid_request', 'query') + assert pool.requested_dbnames == [] + + def test_resolve_requires_dbname_match_even_when_host_and_port_match(): pool = FakePool() check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) response = execute_remote_query(valid_request(dbname='postgres'), StaticPostgresCheckRegistry([check])) - assert response['status'] == 'FAILED' - assert response_code(response) == 'target_not_found' + assert_failed(response, 'target_not_found') assert pool.requested_dbnames == [] @@ -201,8 +319,7 @@ def test_resolve_ignores_metadata_identity_matches(): response = execute_remote_query(valid_request(host='reported.internal'), StaticPostgresCheckRegistry([check])) - assert response['status'] == 'FAILED' - assert response_code(response) == 'target_not_found' + assert_failed(response, 'target_not_found') assert pool.requested_dbnames == [] @@ -213,8 +330,7 @@ def test_resolve_fails_ambiguous_duplicate_configs(): response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry(checks)) - assert response['status'] == 'FAILED' - assert response_code(response) == 'target_ambiguous' + assert_failed(response, 'target_ambiguous') assert first_pool.requested_dbnames == [] assert second_pool.requested_dbnames == [] @@ -233,25 +349,21 @@ def test_execute_sets_truncated_when_more_than_max_rows_returned(): assert response['stats']['rowCount'] == 1 -def test_execute_closed_pool_returns_target_unavailable_without_recreating_credentials(): - pool = FakePool(closed=True) - check = make_check(pool=pool) +def test_execute_uses_connection_pool_not_existing_query_helpers(): + pool = FakePool() + check = block_existing_query_helpers(make_check(pool=pool)) response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry([check])) - assert response['status'] == 'FAILED' - assert response_code(response) == 'target_unavailable' - assert pool.requested_dbnames == [] + assert response['status'] == 'SUCCEEDED' + assert pool.requested_dbnames == ['datadog_test'] -@pytest.mark.parametrize('query', ['SELECT current_database()', 'SELECT 1 AS value;', ' SELECT 1 AS value']) -def test_execute_rejects_non_canonical_query_before_pool_access(query): - pool = FakePool() - request = valid_request(query=query) +def test_execute_closed_pool_returns_target_unavailable_without_recreating_credentials(): + pool = FakePool(closed=True) + check = make_check(pool=pool) - response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)])) + response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry([check])) - assert response['status'] == 'FAILED' - assert response_code(response) == 'invalid_request' - assert 'query' in response['error']['message'] + assert_failed(response, 'target_unavailable') assert pool.requested_dbnames == [] diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py index 2ad003cc2ffd0..05cb065d67274 100644 --- a/postgres/tests/test_remote_query_integration.py +++ b/postgres/tests/test_remote_query_integration.py @@ -9,20 +9,24 @@ from datadog_checks.postgres.remote_query import StaticPostgresCheckRegistry, execute_remote_query -@pytest.mark.integration -@pytest.mark.usefixtures('dd_environment') -def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): - check = integration_check(pg_instance) - request = { +def remote_query_request(pg_instance: dict[str, object], query: str) -> dict[str, object]: + return { 'target': { 'host': pg_instance['host'], 'port': int(pg_instance['port']), 'dbname': pg_instance['dbname'], }, - 'query': 'SELECT 1 AS value', + 'query': query, 'limits': {'maxRows': 10, 'maxBytes': 1048576, 'timeoutMs': 5000}, } + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): + check = integration_check(pg_instance) + request = remote_query_request(pg_instance, 'SELECT 1 AS value') + response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) assert response['status'] == 'SUCCEEDED' @@ -31,3 +35,23 @@ def test_agent_local_executor_select_1_reuses_integration_check_credentials(inte assert response['truncated'] is False assert response['stats']['rowCount'] == 1 assert 'password' not in json.dumps(request).lower() + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_executor_fixture_table_query_returns_city_rows(integration_check, pg_instance): + bob_instance = dict(pg_instance, username='bob', password='bob') + check = integration_check(bob_instance) + request = remote_query_request(bob_instance, 'SELECT city, country FROM cities ORDER BY city') + + response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) + + assert response['status'] == 'SUCCEEDED' + assert response['columns'] == [{'name': 'city', 'type': 'string'}, {'name': 'country', 'type': 'string'}] + assert response['rows'] == [ + {'city': 'Beautiful city of lights', 'country': 'France'}, + {'city': 'New York', 'country': 'USA'}, + ] + assert response['truncated'] is False + assert response['stats']['rowCount'] == 2 + assert 'password' not in json.dumps(request).lower() From da2b81c34a43c9d1ab4b7cb301c426f107a7999f Mon Sep 17 00:00:00 2001 From: nubtron Date: Sat, 2 May 2026 18:43:16 +0000 Subject: [PATCH 08/11] Allow large remote query payload proofs --- .../datadog_checks/postgres/remote_query.py | 11 ++++++++++- postgres/tests/test_remote_query.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index b05734addd316..c63dc3ee374a0 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -12,7 +12,16 @@ from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator -RemoteQuerySql = Literal['SELECT 1 AS value', 'SELECT city, country FROM cities ORDER BY city'] +RemoteQuerySql = Literal[ + 'SELECT 1 AS value', + 'SELECT city, country FROM cities ORDER BY city', + "SELECT repeat('x', 1048576) AS payload", + "SELECT repeat('x', 2097152) AS payload", + "SELECT repeat('x', 4194304) AS payload", + "SELECT repeat('x', 8388608) AS payload", + "SELECT repeat('x', 16777216) AS payload", + "SELECT repeat('x', 33554432) AS payload", +] if TYPE_CHECKING: from datadog_checks.postgres import PostgreSql diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 7e1352de2f50c..12add6a3ff64e 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -276,6 +276,24 @@ def test_execute_accepts_fixture_table_query_and_serializes_result_rows(): assert pool.requested_dbnames == ['datadog_test'] +@pytest.mark.parametrize('size', [1048576, 2097152, 4194304, 8388608, 16777216, 33554432]) +def test_execute_accepts_large_payload_proof_queries_and_serializes_result_rows(size): + pool = FakePool(rows=[('x' * size,)], description=[SimpleNamespace(name='payload')]) + check = make_check(pool=pool) + + response = execute_remote_query( + valid_request(query=f"SELECT repeat('x', {size}) AS payload"), StaticPostgresCheckRegistry([check]) + ) + + assert response['status'] == 'SUCCEEDED' + assert response['columns'] == [{'name': 'payload', 'type': 'string'}] + assert len(response['rows']) == 1 + assert len(response['rows'][0]['payload']) == size + assert response['truncated'] is False + assert response['stats']['rowCount'] == 1 + assert pool.requested_dbnames == ['datadog_test'] + + @pytest.mark.parametrize( 'query', [ From 4f481b0722c8ee24b87798d9cc63554288c53366 Mon Sep 17 00:00:00 2001 From: nubtron Date: Sun, 3 May 2026 09:40:02 +0000 Subject: [PATCH 09/11] Add Postgres remote query COPY streaming executor --- .../datadog_checks/postgres/remote_query.py | 308 +++++++++++++++++- postgres/tests/test_remote_query.py | 189 ++++++++++- .../tests/test_remote_query_integration.py | 67 +++- 3 files changed, 554 insertions(+), 10 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index c63dc3ee374a0..57fb5ce51dde1 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -6,7 +6,8 @@ import json import logging -from collections.abc import Iterable, Mapping, Sequence +import time +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Protocol @@ -23,6 +24,22 @@ "SELECT repeat('x', 33554432) AS payload", ] +RemoteQueryCopySql = Literal[ + 'SELECT 1 AS value', + 'SELECT city, country FROM cities ORDER BY city', + "SELECT repeat('x', 1048576) AS payload", + "SELECT repeat('x', 2097152) AS payload", + "SELECT repeat('x', 4194304) AS payload", + "SELECT repeat('x', 8388608) AS payload", + "SELECT repeat('x', 16777216) AS payload", + "SELECT repeat('x', 33554432) AS payload", + "SELECT i, repeat('x', 1000) AS payload FROM generate_series(1, 3000) AS i", +] + +CopyStreamFormat = Literal['csv'] +CopyStreamEvent = Mapping[str, Any] +CopyStreamEmit = Callable[[CopyStreamEvent], None] + if TYPE_CHECKING: from datadog_checks.postgres import PostgreSql @@ -66,6 +83,17 @@ class RemoteQueryLimits(BaseModel): timeout_ms: StrictInt = Field(default=5_000, alias='timeoutMs', ge=1) +class RemoteQueryCopyLimits(BaseModel): + """Validate byte-streaming limits for COPY export mode.""" + + model_config = ConfigDict(extra='forbid', frozen=True) + + chunk_bytes: StrictInt = Field(default=1_048_576, alias='chunkBytes', ge=1) + max_bytes: StrictInt = Field(default=64 * 1_048_576, alias='maxBytes', ge=1) + max_row_bytes: StrictInt = Field(default=8 * 1_048_576, alias='maxRowBytes', ge=1) + timeout_ms: StrictInt = Field(default=30_000, alias='timeoutMs', ge=1) + + class RemoteQueryRequest(BaseModel): """Accept only exact proof queries until broader SQL execution is implemented.""" @@ -76,6 +104,18 @@ class RemoteQueryRequest(BaseModel): limits: RemoteQueryLimits = Field(default_factory=RemoteQueryLimits) +class RemoteQueryCopyRequest(BaseModel): + """Accept only explicit COPY byte-stream export requests.""" + + model_config = ConfigDict(extra='forbid', frozen=True) + + operation: Literal['copy_stream'] = Field(alias='operation') + target: RemoteQueryTarget + query: RemoteQueryCopySql + format: CopyStreamFormat = 'csv' + limits: RemoteQueryCopyLimits = Field(default_factory=RemoteQueryCopyLimits) + + @dataclass(frozen=True) class StaticPostgresCheckRegistry: checks: Sequence['PostgreSql'] @@ -84,6 +124,21 @@ def iter_postgres_checks(self) -> Iterable['PostgreSql']: return iter(self.checks) +@dataclass(frozen=True) +class _CopyStreamState: + sequence: int = 0 + chunks_emitted: int = 0 + bytes_emitted: int = 0 + + +class _CopyStreamFailure(Exception): + def __init__(self, code: str, message: str, retryable: bool = False): + self.code = code + self.message = message + self.retryable = retryable + super().__init__(message) + + class PostgresCheckRegistry(Protocol): def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... @@ -102,6 +157,65 @@ def execute_agent_rpc_json(request_json: str | bytes | bytearray, check: 'Postgr return json.dumps(response, default=str) +def execute_agent_rpc_stream_copy( + request_json: str | bytes | bytearray, check: 'PostgreSql', emit: CopyStreamEmit +) -> None: + """Execute an explicit COPY byte-stream request and emit chunk events.""" + try: + request = json.loads(request_json) + except (TypeError, ValueError): + emit( + _stream_failed_event( + 'invalid_request', 'Invalid remote query request: request_json must be a valid JSON object.' + ) + ) + return + + if not isinstance(request, Mapping): + emit( + _stream_failed_event('invalid_request', 'Invalid remote query request: request_json must be a JSON object.') + ) + return + + events = iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check])) + try: + for event in events: + emit(event) + except BaseException: + events.close() + raise + + +def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegistry) -> Iterator[dict[str, Any]]: + """Yield COPY byte-stream events for unit tests and callback adaptation.""" + started_at = time.monotonic() + try: + parsed_request = RemoteQueryCopyRequest.model_validate(request) + except ValidationError as e: + yield _stream_failed_event('invalid_request', _validation_message(e), elapsed_ms=_elapsed_ms(started_at)) + return + + target = parsed_request.target + matches = _resolve_matches(target, registry.iter_postgres_checks()) + LOGGER.debug('Remote query COPY stream target match count: %d', len(matches)) + if not matches: + yield _stream_failed_event( + 'target_not_found', + 'No loaded Postgres integration instance matched target selector.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + if len(matches) > 1: + yield _stream_failed_event( + 'target_ambiguous', + 'More than one loaded Postgres integration instance matched target selector.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + + yield from _iter_copy_stream_events(matches[0], parsed_request, started_at) + + def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[str, Any]: try: parsed_request = RemoteQueryRequest.model_validate(request) @@ -182,6 +296,198 @@ def _execute_safe_query( } +def _iter_copy_stream_events( + check: 'PostgreSql', request: RemoteQueryCopyRequest, started_at: float +) -> Iterator[dict[str, Any]]: + db_pool = getattr(check, 'db_pool', None) + if db_pool is None: + yield _stream_failed_event( + 'credentials_unavailable', + 'Matched Postgres check does not expose a connection pool.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + if getattr(db_pool, 'is_closed', lambda: False)(): + yield _stream_failed_event( + 'target_unavailable', + 'Matched Postgres check connection pool is closed.', + retryable=False, + elapsed_ms=_elapsed_ms(started_at), + ) + return + + yield { + 'type': 'metadata', + 'status': 'STARTED', + 'format': request.format, + 'operation': request.operation, + 'chunkBytes': request.limits.chunk_bytes, + 'maxBytes': request.limits.max_bytes, + 'maxRowBytes': request.limits.max_row_bytes, + } + + state = _CopyStreamState() + error: _CopyStreamFailure | None = None + try: + for event, next_state in _copy_stream_data_events(check, request, state, started_at): + state = next_state + yield event + except _CopyStreamFailure as e: + error = e + except RuntimeError: + error = _CopyStreamFailure( + 'target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False + ) + except Exception: + LOGGER.exception('Remote query COPY stream execution failed') + error = _CopyStreamFailure('query_failed', 'Remote query COPY stream execution failed.') + + if error is not None: + yield _stream_failed_event( + error.code, + error.message, + retryable=error.retryable, + stats=_copy_stream_stats(state, started_at, request.format), + ) + return + + yield { + 'type': 'final', + 'status': 'SUCCEEDED', + 'stats': _copy_stream_stats(state, started_at, request.format), + } + + +def _copy_stream_data_events( + check: 'PostgreSql', request: RemoteQueryCopyRequest, state: _CopyStreamState, started_at: float +) -> Iterator[tuple[dict[str, Any], _CopyStreamState]]: + limits = request.limits + deadline = started_at + (limits.timeout_ms / 1000) + copy_sql = _copy_stdout_sql(request.query, request.format) + pending = bytearray() + + with check.db_pool.get_connection(request.target.dbname) as conn: + with conn.cursor() as cursor: + previous_statement_timeout = _set_statement_timeout(cursor, limits.timeout_ms) + try: + with cursor.copy(copy_sql) as copy: + for block in copy: + _raise_if_timed_out(deadline) + block_view = memoryview(block) + if len(block_view) > limits.max_row_bytes: + raise _CopyStreamFailure( + 'max_row_bytes_exceeded', + 'COPY stream row exceeded maxRowBytes; psycopg exposes COPY data at row granularity.', + ) + + offset = 0 + while offset < len(block_view): + _raise_if_timed_out(deadline) + remaining_allowed = limits.max_bytes - state.bytes_emitted - len(pending) + if remaining_allowed <= 0: + raise _CopyStreamFailure('max_bytes_exceeded', 'COPY stream exceeded maxBytes.') + + remaining_chunk = limits.chunk_bytes - len(pending) + take = min(remaining_chunk, remaining_allowed, len(block_view) - offset) + pending.extend(block_view[offset : offset + take]) + offset += take + + if len(pending) >= limits.chunk_bytes: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + + if offset < len(block_view) and state.bytes_emitted + len(pending) >= limits.max_bytes: + if pending: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + raise _CopyStreamFailure('max_bytes_exceeded', 'COPY stream exceeded maxBytes.') + + if pending: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + finally: + _restore_statement_timeout(cursor, previous_statement_timeout) + + +def _copy_stdout_sql(query: str, stream_format: CopyStreamFormat) -> str: + # CSV is the only initial COPY byte-stream format. It is compact, standard, and preserves raw COPY bytes. + if stream_format != 'csv': + raise _CopyStreamFailure('invalid_request', 'Unsupported COPY stream format.') + return f'COPY ({query}) TO STDOUT WITH (FORMAT CSV)' + + +def _set_statement_timeout(cursor: Any, timeout_ms: int) -> str | None: + previous_statement_timeout = None + try: + cursor.execute('SHOW statement_timeout') + row = cursor.fetchone() + previous_statement_timeout = row[0] if row else None + cursor.execute('SET statement_timeout = %s', (timeout_ms,)) + except Exception: + LOGGER.debug('Unable to scope statement_timeout for remote query COPY stream', exc_info=True) + return previous_statement_timeout + + +def _restore_statement_timeout(cursor: Any, previous_statement_timeout: str | None) -> None: + if previous_statement_timeout is None: + return + try: + cursor.execute('SET statement_timeout = %s', (previous_statement_timeout,)) + except Exception: + LOGGER.debug('Unable to restore statement_timeout after remote query COPY stream', exc_info=True) + + +def _raise_if_timed_out(deadline: float) -> None: + if time.monotonic() > deadline: + raise _CopyStreamFailure('timeout', 'COPY stream exceeded timeoutMs.', retryable=True) + + +def _copy_data_event(data: bytearray, state: _CopyStreamState) -> tuple[dict[str, Any], _CopyStreamState]: + payload = bytes(data) + event = {'type': 'data', 'sequence': state.sequence, 'data': payload, 'bytes': len(payload)} + next_state = _CopyStreamState( + sequence=state.sequence + 1, + chunks_emitted=state.chunks_emitted + 1, + bytes_emitted=state.bytes_emitted + len(payload), + ) + return event, next_state + + +def _copy_stream_stats(state: _CopyStreamState, started_at: float, stream_format: CopyStreamFormat) -> dict[str, Any]: + return { + 'format': stream_format, + 'bytesEmitted': state.bytes_emitted, + 'chunksEmitted': state.chunks_emitted, + 'elapsedMs': _elapsed_ms(started_at), + } + + +def _elapsed_ms(started_at: float) -> int: + return max(0, int((time.monotonic() - started_at) * 1000)) + + +def _stream_failed_event( + code: str, + message: str, + retryable: bool = False, + stats: Mapping[str, Any] | None = None, + elapsed_ms: int | None = None, +) -> dict[str, Any]: + event = { + 'type': 'final', + 'status': 'FAILED', + 'error': {'code': code, 'message': message, 'retryable': retryable}, + } + if stats is not None: + event['stats'] = dict(stats) + elif elapsed_ms is not None: + event['stats'] = {'elapsedMs': elapsed_ms} + return event + + def _response_columns(description: Sequence[Any], rows: Sequence[Sequence[Any]]) -> list[dict[str, str]]: return [ {'name': _column_name(column), 'type': _column_type(index, rows)} for index, column in enumerate(description) diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 12add6a3ff64e..372f5aa64f7e1 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -11,17 +11,21 @@ from datadog_checks.postgres.remote_query import ( StaticPostgresCheckRegistry, execute_agent_rpc_json, + execute_agent_rpc_stream_copy, execute_remote_query, + iter_agent_rpc_stream_copy_events, normalize_target, ) class FakePool: - def __init__(self, rows=None, description=None, closed=False): + def __init__(self, rows=None, description=None, closed=False, copy_blocks=None): self.rows = rows or [(1,)] self.description = description or [SimpleNamespace(name='value')] self.closed = closed + self.copy_blocks = copy_blocks or [] self.requested_dbnames = [] + self.closed_copies = 0 def is_closed(self): return self.closed @@ -29,31 +33,57 @@ def is_closed(self): @contextmanager def get_connection(self, dbname): self.requested_dbnames.append(dbname) - yield FakeConnection(self.rows, self.description) + yield FakeConnection(self.rows, self.description, self.copy_blocks, self) class FakeConnection: - def __init__(self, rows, description): + def __init__(self, rows, description, copy_blocks, pool): self.rows = rows self.description = description + self.copy_blocks = copy_blocks + self.pool = pool @contextmanager def cursor(self): - yield FakeCursor(self.rows, self.description) + yield FakeCursor(self.rows, self.description, self.copy_blocks, self.pool) class FakeCursor: - def __init__(self, rows, description): + def __init__(self, rows, description, copy_blocks, pool): self.rows = rows self.description = description - self.executed = None + self.copy_blocks = copy_blocks + self.pool = pool + self.executed = [] - def execute(self, query): - self.executed = query + def execute(self, query, params=None): + self.executed.append((query, params)) + + def fetchone(self): + return ('0',) def fetchmany(self, size): return self.rows[:size] + def copy(self, query): + self.executed.append((query, None)) + return FakeCopy(self.copy_blocks, self.pool) + + +class FakeCopy: + def __init__(self, blocks, pool): + self.blocks = blocks + self.pool = pool + + def __enter__(self): + return self + + def __exit__(self, *args): + self.pool.closed_copies += 1 + + def __iter__(self): + return iter(self.blocks) + def make_check(host='localhost', port=5432, dbname='datadog_test', pool=None, **metadata): return SimpleNamespace( @@ -385,3 +415,146 @@ def test_execute_closed_pool_returns_target_unavailable_without_recreating_crede assert_failed(response, 'target_unavailable') assert pool.requested_dbnames == [] + + +def valid_copy_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): + request = { + 'operation': 'copy_stream', + 'target': {'host': host, 'port': port, 'dbname': dbname}, + 'query': 'SELECT 1 AS value', + 'format': 'csv', + 'limits': {'chunkBytes': 8, 'maxBytes': 64, 'maxRowBytes': 32, 'timeoutMs': 5000}, + } + request.update(extra) + return request + + +def collect_copy_events(request, check): + return list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + +def test_copy_stream_requires_explicit_operation_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request() + request.pop('operation') + + events = collect_copy_events(request, make_check(pool=pool)) + + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'invalid_request' + assert 'operation' in events[-1]['error']['message'] + assert pool.requested_dbnames == [] + + +def test_copy_stream_rejects_unknown_fields_without_echoing_secrets(caplog): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(password='SECRET_DO_NOT_LOG') + + events = collect_copy_events(request, make_check(pool=pool)) + + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'invalid_request' + assert 'password' in events[-1]['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + assert pool.requested_dbnames == [] + + +def test_copy_stream_rejects_non_copy_allowlisted_queries_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(query='SELECT current_database()') + + events = collect_copy_events(request, make_check(pool=pool)) + + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'invalid_request' + assert 'query' in events[-1]['error']['message'] + assert pool.requested_dbnames == [] + + +def test_copy_stream_uses_connection_pool_and_emits_chunked_copy_bytes(): + pool = FakePool(copy_blocks=[b'abc', b'defgh', b'ijklmnop', b'qr']) + check = block_existing_query_helpers(make_check(pool=pool)) + + events = collect_copy_events(valid_copy_request(), check) + + assert events[0]['type'] == 'metadata' + assert events[0]['operation'] == 'copy_stream' + assert events[0]['format'] == 'csv' + data_events = [event for event in events if event['type'] == 'data'] + assert [event['sequence'] for event in data_events] == [0, 1, 2] + assert [event['data'] for event in data_events] == [b'abcdefgh', b'ijklmnop', b'qr'] + assert [event['bytes'] for event in data_events] == [8, 8, 2] + assert events[-1]['type'] == 'final' + assert events[-1]['status'] == 'SUCCEEDED' + assert events[-1]['stats']['bytesEmitted'] == 18 + assert events[-1]['stats']['chunksEmitted'] == 3 + assert pool.requested_dbnames == ['datadog_test'] + assert pool.closed_copies == 1 + + +def test_copy_stream_enforces_max_bytes_without_exceeding_limit(): + pool = FakePool(copy_blocks=[b'abcdefgh', b'ijklmnop']) + request = valid_copy_request(limits={'chunkBytes': 8, 'maxBytes': 10, 'maxRowBytes': 32, 'timeoutMs': 5000}) + + events = collect_copy_events(request, make_check(pool=pool)) + + data_events = [event for event in events if event['type'] == 'data'] + assert [event['data'] for event in data_events] == [b'abcdefgh', b'ij'] + assert sum(event['bytes'] for event in data_events) == 10 + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'max_bytes_exceeded' + assert events[-1]['stats']['bytesEmitted'] == 10 + assert pool.closed_copies == 1 + + +def test_copy_stream_enforces_max_row_bytes_after_copy_block_arrives(): + pool = FakePool(copy_blocks=[b'abc', b'x' * 33]) + + events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + + assert [event['data'] for event in events if event['type'] == 'data'] == [] + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'max_row_bytes_exceeded' + assert 'row granularity' in events[-1]['error']['message'] + assert pool.closed_copies == 1 + + +def test_agent_rpc_stream_copy_adapts_iterator_to_callback(): + pool = FakePool(copy_blocks=[b'1\n']) + events = [] + + execute_agent_rpc_stream_copy(json.dumps(valid_copy_request()), make_check(pool=pool), events.append) + + assert [event['type'] for event in events] == ['metadata', 'data', 'final'] + assert events[1]['data'] == b'1\n' + assert events[-1]['status'] == 'SUCCEEDED' + + +def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(caplog): + pool = FakePool(copy_blocks=[b'1\n']) + events = [] + + execute_agent_rpc_stream_copy('{"password": "SECRET_DO_NOT_LOG"', make_check(pool=pool), events.append) + + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'invalid_request' + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + assert pool.requested_dbnames == [] + + +def test_agent_rpc_stream_copy_closes_copy_when_callback_raises(): + pool = FakePool(copy_blocks=[b'12345678', b'abcdef']) + events = [] + + def emit(event): + events.append(event) + if event['type'] == 'data': + raise RuntimeError('stop streaming') + + with pytest.raises(RuntimeError, match='stop streaming'): + execute_agent_rpc_stream_copy(json.dumps(valid_copy_request()), make_check(pool=pool), emit) + + assert [event['type'] for event in events] == ['metadata', 'data'] + assert pool.closed_copies == 1 diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py index 05cb065d67274..cbc10a377eeb1 100644 --- a/postgres/tests/test_remote_query_integration.py +++ b/postgres/tests/test_remote_query_integration.py @@ -6,7 +6,11 @@ import pytest -from datadog_checks.postgres.remote_query import StaticPostgresCheckRegistry, execute_remote_query +from datadog_checks.postgres.remote_query import ( + StaticPostgresCheckRegistry, + execute_remote_query, + iter_agent_rpc_stream_copy_events, +) def remote_query_request(pg_instance: dict[str, object], query: str) -> dict[str, object]: @@ -21,6 +25,20 @@ def remote_query_request(pg_instance: dict[str, object], query: str) -> dict[str } +def remote_query_copy_request(pg_instance: dict[str, object], query: str, limits: dict[str, int]) -> dict[str, object]: + return { + 'operation': 'copy_stream', + 'target': { + 'host': pg_instance['host'], + 'port': int(pg_instance['port']), + 'dbname': pg_instance['dbname'], + }, + 'query': query, + 'format': 'csv', + 'limits': limits, + } + + @pytest.mark.integration @pytest.mark.usefixtures('dd_environment') def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): @@ -55,3 +73,50 @@ def test_agent_local_executor_fixture_table_query_returns_city_rows(integration_ assert response['truncated'] is False assert response['stats']['rowCount'] == 2 assert 'password' not in json.dumps(request).lower() + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_fixture_table_query_emits_csv_chunks(integration_check, pg_instance): + bob_instance = dict(pg_instance, username='bob', password='bob') + check = integration_check(bob_instance) + request = remote_query_copy_request( + bob_instance, + 'SELECT city, country FROM cities ORDER BY city', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + data = b''.join(event['data'] for event in events if event['type'] == 'data') + assert events[0]['type'] == 'metadata' + assert events[0]['format'] == 'csv' + assert events[-1]['status'] == 'SUCCEEDED' + assert b'Beautiful city of lights,France\n' in data + assert b'New York,USA\n' in data + assert events[-1]['stats']['bytesEmitted'] == len(data) + assert all(event['bytes'] <= 16 for event in events if event['type'] == 'data') + assert 'password' not in json.dumps(request).lower() + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_enforces_max_row_bytes_and_reuses_connection(integration_check, pg_instance): + check = integration_check(pg_instance) + request = remote_query_copy_request( + pg_instance, + "SELECT repeat('x', 1048576) AS payload", + {'chunkBytes': 1024, 'maxBytes': 2 * 1048576, 'maxRowBytes': 1024, 'timeoutMs': 5000}, + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + assert [event for event in events if event['type'] == 'data'] == [] + assert events[-1]['status'] == 'FAILED' + assert events[-1]['error']['code'] == 'max_row_bytes_exceeded' + + response = execute_remote_query( + remote_query_request(pg_instance, 'SELECT 1 AS value'), StaticPostgresCheckRegistry([check]) + ) + assert response['status'] == 'SUCCEEDED' + assert response['rows'] == [{'value': 1}] From 497027c9ee05329612705b79a34b753ff66fb54c Mon Sep 17 00:00:00 2001 From: nubtron Date: Sun, 3 May 2026 13:56:32 +0000 Subject: [PATCH 10/11] Make Postgres remote query COPY stream binary-safe --- .../datadog_checks/postgres/remote_query.py | 98 +++++++++------ postgres/tests/test_remote_query.py | 116 +++++++++++------- .../tests/test_remote_query_integration.py | 54 ++++++-- 3 files changed, 177 insertions(+), 91 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index 57fb5ce51dde1..9a8d7decde16b 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -27,6 +27,7 @@ RemoteQueryCopySql = Literal[ 'SELECT 1 AS value', 'SELECT city, country FROM cities ORDER BY city', + "SELECT decode('00ff80', 'hex') AS payload", "SELECT repeat('x', 1048576) AS payload", "SELECT repeat('x', 2097152) AS payload", "SELECT repeat('x', 4194304) AS payload", @@ -36,9 +37,8 @@ "SELECT i, repeat('x', 1000) AS payload FROM generate_series(1, 3000) AS i", ] -CopyStreamFormat = Literal['csv'] -CopyStreamEvent = Mapping[str, Any] -CopyStreamEmit = Callable[[CopyStreamEvent], None] +CopyStreamFormat = Literal['csv', 'binary'] +CopyStreamEmit = Callable[[str, str, bytes], None] if TYPE_CHECKING: from datadog_checks.postgres import PostgreSql @@ -131,6 +131,13 @@ class _CopyStreamState: bytes_emitted: int = 0 +@dataclass(frozen=True) +class CopyStreamEvent: + event_type: str + metadata: Mapping[str, Any] + payload: bytes = b'' + + class _CopyStreamFailure(Exception): def __init__(self, code: str, message: str, retryable: bool = False): self.code = code @@ -164,29 +171,33 @@ def execute_agent_rpc_stream_copy( try: request = json.loads(request_json) except (TypeError, ValueError): - emit( + _emit_copy_event( + emit, _stream_failed_event( 'invalid_request', 'Invalid remote query request: request_json must be a valid JSON object.' - ) + ), ) return if not isinstance(request, Mapping): - emit( - _stream_failed_event('invalid_request', 'Invalid remote query request: request_json must be a JSON object.') + _emit_copy_event( + emit, + _stream_failed_event( + 'invalid_request', 'Invalid remote query request: request_json must be a JSON object.' + ), ) return events = iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check])) try: for event in events: - emit(event) + _emit_copy_event(emit, event) except BaseException: events.close() raise -def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegistry) -> Iterator[dict[str, Any]]: +def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegistry) -> Iterator[CopyStreamEvent]: """Yield COPY byte-stream events for unit tests and callback adaptation.""" started_at = time.monotonic() try: @@ -298,7 +309,7 @@ def _execute_safe_query( def _iter_copy_stream_events( check: 'PostgreSql', request: RemoteQueryCopyRequest, started_at: float -) -> Iterator[dict[str, Any]]: +) -> Iterator[CopyStreamEvent]: db_pool = getattr(check, 'db_pool', None) if db_pool is None: yield _stream_failed_event( @@ -316,15 +327,17 @@ def _iter_copy_stream_events( ) return - yield { - 'type': 'metadata', - 'status': 'STARTED', - 'format': request.format, - 'operation': request.operation, - 'chunkBytes': request.limits.chunk_bytes, - 'maxBytes': request.limits.max_bytes, - 'maxRowBytes': request.limits.max_row_bytes, - } + yield CopyStreamEvent( + 'metadata', + { + 'status': 'STARTED', + 'format': request.format, + 'operation': request.operation, + 'chunkBytes': request.limits.chunk_bytes, + 'maxBytes': request.limits.max_bytes, + 'maxRowBytes': request.limits.max_row_bytes, + }, + ) state = _CopyStreamState() error: _CopyStreamFailure | None = None @@ -351,16 +364,15 @@ def _iter_copy_stream_events( ) return - yield { - 'type': 'final', - 'status': 'SUCCEEDED', - 'stats': _copy_stream_stats(state, started_at, request.format), - } + yield CopyStreamEvent( + 'final', + {'status': 'SUCCEEDED', 'stats': _copy_stream_stats(state, started_at, request.format)}, + ) def _copy_stream_data_events( check: 'PostgreSql', request: RemoteQueryCopyRequest, state: _CopyStreamState, started_at: float -) -> Iterator[tuple[dict[str, Any], _CopyStreamState]]: +) -> Iterator[tuple[CopyStreamEvent, _CopyStreamState]]: limits = request.limits deadline = started_at + (limits.timeout_ms / 1000) copy_sql = _copy_stdout_sql(request.query, request.format) @@ -413,10 +425,11 @@ def _copy_stream_data_events( def _copy_stdout_sql(query: str, stream_format: CopyStreamFormat) -> str: - # CSV is the only initial COPY byte-stream format. It is compact, standard, and preserves raw COPY bytes. - if stream_format != 'csv': - raise _CopyStreamFailure('invalid_request', 'Unsupported COPY stream format.') - return f'COPY ({query}) TO STDOUT WITH (FORMAT CSV)' + if stream_format == 'csv': + return f'COPY ({query}) TO STDOUT WITH (FORMAT CSV)' + if stream_format == 'binary': + return f'COPY ({query}) TO STDOUT WITH (FORMAT BINARY)' + raise _CopyStreamFailure('invalid_request', 'Unsupported COPY stream format.') def _set_statement_timeout(cursor: Any, timeout_ms: int) -> str | None: @@ -445,9 +458,17 @@ def _raise_if_timed_out(deadline: float) -> None: raise _CopyStreamFailure('timeout', 'COPY stream exceeded timeoutMs.', retryable=True) -def _copy_data_event(data: bytearray, state: _CopyStreamState) -> tuple[dict[str, Any], _CopyStreamState]: +def _copy_data_event(data: bytearray, state: _CopyStreamState) -> tuple[CopyStreamEvent, _CopyStreamState]: payload = bytes(data) - event = {'type': 'data', 'sequence': state.sequence, 'data': payload, 'bytes': len(payload)} + event = CopyStreamEvent( + 'data', + { + 'sequence': state.sequence, + 'offset': state.bytes_emitted, + 'bytes': len(payload), + }, + payload, + ) next_state = _CopyStreamState( sequence=state.sequence + 1, chunks_emitted=state.chunks_emitted + 1, @@ -475,17 +496,20 @@ def _stream_failed_event( retryable: bool = False, stats: Mapping[str, Any] | None = None, elapsed_ms: int | None = None, -) -> dict[str, Any]: - event = { - 'type': 'final', +) -> CopyStreamEvent: + metadata = { 'status': 'FAILED', 'error': {'code': code, 'message': message, 'retryable': retryable}, } if stats is not None: - event['stats'] = dict(stats) + metadata['stats'] = dict(stats) elif elapsed_ms is not None: - event['stats'] = {'elapsedMs': elapsed_ms} - return event + metadata['stats'] = {'elapsedMs': elapsed_ms} + return CopyStreamEvent('error', metadata) + + +def _emit_copy_event(emit: CopyStreamEmit, event: CopyStreamEvent) -> None: + emit(event.event_type, json.dumps(event.metadata, default=str), event.payload) def _response_columns(description: Sequence[Any], rows: Sequence[Sequence[Any]]) -> list[dict[str, str]]: diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 372f5aa64f7e1..910c30285cbec 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -433,6 +433,14 @@ def collect_copy_events(request, check): return list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) +def event_metadata(event): + return event.metadata + + +def event_payload(event): + return event.payload + + def test_copy_stream_requires_explicit_operation_before_pool_access(): pool = FakePool(copy_blocks=[b'1\n']) request = valid_copy_request() @@ -440,9 +448,9 @@ def test_copy_stream_requires_explicit_operation_before_pool_access(): events = collect_copy_events(request, make_check(pool=pool)) - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'invalid_request' - assert 'operation' in events[-1]['error']['message'] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'invalid_request' + assert 'operation' in event_metadata(events[-1])['error']['message'] assert pool.requested_dbnames == [] @@ -452,9 +460,9 @@ def test_copy_stream_rejects_unknown_fields_without_echoing_secrets(caplog): events = collect_copy_events(request, make_check(pool=pool)) - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'invalid_request' - assert 'password' in events[-1]['error']['message'] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'invalid_request' + assert 'password' in event_metadata(events[-1])['error']['message'] assert 'SECRET_DO_NOT_LOG' not in str(events) assert 'SECRET_DO_NOT_LOG' not in caplog.text assert pool.requested_dbnames == [] @@ -466,9 +474,9 @@ def test_copy_stream_rejects_non_copy_allowlisted_queries_before_pool_access(): events = collect_copy_events(request, make_check(pool=pool)) - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'invalid_request' - assert 'query' in events[-1]['error']['message'] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'invalid_request' + assert 'query' in event_metadata(events[-1])['error']['message'] assert pool.requested_dbnames == [] @@ -478,33 +486,46 @@ def test_copy_stream_uses_connection_pool_and_emits_chunked_copy_bytes(): events = collect_copy_events(valid_copy_request(), check) - assert events[0]['type'] == 'metadata' - assert events[0]['operation'] == 'copy_stream' - assert events[0]['format'] == 'csv' - data_events = [event for event in events if event['type'] == 'data'] - assert [event['sequence'] for event in data_events] == [0, 1, 2] - assert [event['data'] for event in data_events] == [b'abcdefgh', b'ijklmnop', b'qr'] - assert [event['bytes'] for event in data_events] == [8, 8, 2] - assert events[-1]['type'] == 'final' - assert events[-1]['status'] == 'SUCCEEDED' - assert events[-1]['stats']['bytesEmitted'] == 18 - assert events[-1]['stats']['chunksEmitted'] == 3 + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['operation'] == 'copy_stream' + assert event_metadata(events[0])['format'] == 'csv' + data_events = [event for event in events if event.event_type == 'data'] + assert [event_metadata(event)['sequence'] for event in data_events] == [0, 1, 2] + assert [event_metadata(event)['offset'] for event in data_events] == [0, 8, 16] + assert [event_payload(event) for event in data_events] == [b'abcdefgh', b'ijklmnop', b'qr'] + assert [event_metadata(event)['bytes'] for event in data_events] == [8, 8, 2] + assert events[-1].event_type == 'final' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert event_metadata(events[-1])['stats']['bytesEmitted'] == 18 + assert event_metadata(events[-1])['stats']['chunksEmitted'] == 3 assert pool.requested_dbnames == ['datadog_test'] assert pool.closed_copies == 1 +def test_copy_stream_data_event_payload_preserves_arbitrary_bytes(): + arbitrary_bytes = b'\x00\xff\x80abc\n' + pool = FakePool(copy_blocks=[arbitrary_bytes]) + + events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + + data_events = [event for event in events if event.event_type == 'data'] + assert len(data_events) == 1 + assert event_payload(data_events[0]) == arbitrary_bytes + assert isinstance(event_payload(data_events[0]), bytes) + + def test_copy_stream_enforces_max_bytes_without_exceeding_limit(): pool = FakePool(copy_blocks=[b'abcdefgh', b'ijklmnop']) request = valid_copy_request(limits={'chunkBytes': 8, 'maxBytes': 10, 'maxRowBytes': 32, 'timeoutMs': 5000}) events = collect_copy_events(request, make_check(pool=pool)) - data_events = [event for event in events if event['type'] == 'data'] - assert [event['data'] for event in data_events] == [b'abcdefgh', b'ij'] - assert sum(event['bytes'] for event in data_events) == 10 - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'max_bytes_exceeded' - assert events[-1]['stats']['bytesEmitted'] == 10 + data_events = [event for event in events if event.event_type == 'data'] + assert [event_payload(event) for event in data_events] == [b'abcdefgh', b'ij'] + assert sum(event_metadata(event)['bytes'] for event in data_events) == 10 + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'max_bytes_exceeded' + assert event_metadata(events[-1])['stats']['bytesEmitted'] == 10 assert pool.closed_copies == 1 @@ -513,32 +534,41 @@ def test_copy_stream_enforces_max_row_bytes_after_copy_block_arrives(): events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) - assert [event['data'] for event in events if event['type'] == 'data'] == [] - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'max_row_bytes_exceeded' - assert 'row granularity' in events[-1]['error']['message'] + assert [event_payload(event) for event in events if event.event_type == 'data'] == [] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'max_row_bytes_exceeded' + assert 'row granularity' in event_metadata(events[-1])['error']['message'] assert pool.closed_copies == 1 -def test_agent_rpc_stream_copy_adapts_iterator_to_callback(): - pool = FakePool(copy_blocks=[b'1\n']) +def test_agent_rpc_stream_copy_adapts_iterator_to_binary_safe_callback(): + arbitrary_bytes = b'\x00\xff\x80abc\n' + pool = FakePool(copy_blocks=[arbitrary_bytes]) events = [] - execute_agent_rpc_stream_copy(json.dumps(valid_copy_request()), make_check(pool=pool), events.append) + execute_agent_rpc_stream_copy( + json.dumps(valid_copy_request()), make_check(pool=pool), lambda *event: events.append(event) + ) - assert [event['type'] for event in events] == ['metadata', 'data', 'final'] - assert events[1]['data'] == b'1\n' - assert events[-1]['status'] == 'SUCCEEDED' + assert [event[0] for event in events] == ['metadata', 'data', 'final'] + assert json.loads(events[1][1])['bytes'] == len(arbitrary_bytes) + assert events[1][2] == arbitrary_bytes + assert isinstance(events[1][2], bytes) + assert json.loads(events[-1][1])['status'] == 'SUCCEEDED' def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(caplog): pool = FakePool(copy_blocks=[b'1\n']) events = [] - execute_agent_rpc_stream_copy('{"password": "SECRET_DO_NOT_LOG"', make_check(pool=pool), events.append) + execute_agent_rpc_stream_copy( + '{"password": "SECRET_DO_NOT_LOG"', make_check(pool=pool), lambda *event: events.append(event) + ) - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'invalid_request' + metadata = json.loads(events[-1][1]) + assert events[-1][0] == 'error' + assert metadata['status'] == 'FAILED' + assert metadata['error']['code'] == 'invalid_request' assert 'SECRET_DO_NOT_LOG' not in str(events) assert 'SECRET_DO_NOT_LOG' not in caplog.text assert pool.requested_dbnames == [] @@ -548,13 +578,13 @@ def test_agent_rpc_stream_copy_closes_copy_when_callback_raises(): pool = FakePool(copy_blocks=[b'12345678', b'abcdef']) events = [] - def emit(event): - events.append(event) - if event['type'] == 'data': + def emit(event_type, metadata_json, payload): + events.append((event_type, metadata_json, payload)) + if event_type == 'data': raise RuntimeError('stop streaming') with pytest.raises(RuntimeError, match='stop streaming'): execute_agent_rpc_stream_copy(json.dumps(valid_copy_request()), make_check(pool=pool), emit) - assert [event['type'] for event in events] == ['metadata', 'data'] + assert [event[0] for event in events] == ['metadata', 'data'] assert pool.closed_copies == 1 diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py index cbc10a377eeb1..78dd12f64179c 100644 --- a/postgres/tests/test_remote_query_integration.py +++ b/postgres/tests/test_remote_query_integration.py @@ -25,7 +25,9 @@ def remote_query_request(pg_instance: dict[str, object], query: str) -> dict[str } -def remote_query_copy_request(pg_instance: dict[str, object], query: str, limits: dict[str, int]) -> dict[str, object]: +def remote_query_copy_request( + pg_instance: dict[str, object], query: str, limits: dict[str, int], stream_format: str = 'csv' +) -> dict[str, object]: return { 'operation': 'copy_stream', 'target': { @@ -34,11 +36,19 @@ def remote_query_copy_request(pg_instance: dict[str, object], query: str, limits 'dbname': pg_instance['dbname'], }, 'query': query, - 'format': 'csv', + 'format': stream_format, 'limits': limits, } +def event_metadata(event): + return event.metadata + + +def event_payload(event): + return event.payload + + @pytest.mark.integration @pytest.mark.usefixtures('dd_environment') def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): @@ -88,17 +98,39 @@ def test_agent_local_copy_stream_fixture_table_query_emits_csv_chunks(integratio events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) - data = b''.join(event['data'] for event in events if event['type'] == 'data') - assert events[0]['type'] == 'metadata' - assert events[0]['format'] == 'csv' - assert events[-1]['status'] == 'SUCCEEDED' + data_events = [event for event in events if event.event_type == 'data'] + data = b''.join(event_payload(event) for event in data_events) + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['format'] == 'csv' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' assert b'Beautiful city of lights,France\n' in data assert b'New York,USA\n' in data - assert events[-1]['stats']['bytesEmitted'] == len(data) - assert all(event['bytes'] <= 16 for event in events if event['type'] == 'data') + assert event_metadata(events[-1])['stats']['bytesEmitted'] == len(data) + assert all(event_metadata(event)['bytes'] <= 16 for event in data_events) assert 'password' not in json.dumps(request).lower() +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_binary_format_preserves_non_text_bytes(integration_check, pg_instance): + check = integration_check(pg_instance) + request = remote_query_copy_request( + pg_instance, + "SELECT decode('00ff80', 'hex') AS payload", + {'chunkBytes': 1024, 'maxBytes': 4096, 'maxRowBytes': 4096, 'timeoutMs': 5000}, + stream_format='binary', + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['format'] == 'binary' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert b'PGCOPY\n\xff\r\n\x00' in data + assert b'\x00\xff\x80' in data + + @pytest.mark.integration @pytest.mark.usefixtures('dd_environment') def test_agent_local_copy_stream_enforces_max_row_bytes_and_reuses_connection(integration_check, pg_instance): @@ -111,9 +143,9 @@ def test_agent_local_copy_stream_enforces_max_row_bytes_and_reuses_connection(in events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) - assert [event for event in events if event['type'] == 'data'] == [] - assert events[-1]['status'] == 'FAILED' - assert events[-1]['error']['code'] == 'max_row_bytes_exceeded' + assert [event for event in events if event.event_type == 'data'] == [] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'max_row_bytes_exceeded' response = execute_remote_query( remote_query_request(pg_instance, 'SELECT 1 AS value'), StaticPostgresCheckRegistry([check]) From 1742f6f97b4bff4b8505cd1ea2246415e6b2d52a Mon Sep 17 00:00:00 2001 From: nubtron Date: Sun, 3 May 2026 14:46:44 +0000 Subject: [PATCH 11/11] Remove inline Postgres remote query mode --- .../datadog_checks/postgres/remote_query.py | 147 ------- postgres/tests/test_remote_query.py | 409 ++++++------------ .../tests/test_remote_query_integration.py | 77 ++-- 3 files changed, 166 insertions(+), 467 deletions(-) diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py index 9a8d7decde16b..11603a518a1b5 100644 --- a/postgres/datadog_checks/postgres/remote_query.py +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -13,17 +13,6 @@ from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator -RemoteQuerySql = Literal[ - 'SELECT 1 AS value', - 'SELECT city, country FROM cities ORDER BY city', - "SELECT repeat('x', 1048576) AS payload", - "SELECT repeat('x', 2097152) AS payload", - "SELECT repeat('x', 4194304) AS payload", - "SELECT repeat('x', 8388608) AS payload", - "SELECT repeat('x', 16777216) AS payload", - "SELECT repeat('x', 33554432) AS payload", -] - RemoteQueryCopySql = Literal[ 'SELECT 1 AS value', 'SELECT city, country FROM cities ORDER BY city', @@ -73,16 +62,6 @@ def validate_dbname(cls, value: str) -> str: return value -class RemoteQueryLimits(BaseModel): - """Validate the future-facing limits contract for the initial safe query slice.""" - - model_config = ConfigDict(extra='forbid', frozen=True) - - max_rows: StrictInt = Field(default=10, alias='maxRows', ge=1) - max_bytes: StrictInt = Field(default=1_048_576, alias='maxBytes', ge=1) - timeout_ms: StrictInt = Field(default=5_000, alias='timeoutMs', ge=1) - - class RemoteQueryCopyLimits(BaseModel): """Validate byte-streaming limits for COPY export mode.""" @@ -94,16 +73,6 @@ class RemoteQueryCopyLimits(BaseModel): timeout_ms: StrictInt = Field(default=30_000, alias='timeoutMs', ge=1) -class RemoteQueryRequest(BaseModel): - """Accept only exact proof queries until broader SQL execution is implemented.""" - - model_config = ConfigDict(extra='forbid', frozen=True) - - target: RemoteQueryTarget - query: RemoteQuerySql - limits: RemoteQueryLimits = Field(default_factory=RemoteQueryLimits) - - class RemoteQueryCopyRequest(BaseModel): """Accept only explicit COPY byte-stream export requests.""" @@ -150,20 +119,6 @@ class PostgresCheckRegistry(Protocol): def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... -def execute_agent_rpc_json(request_json: str | bytes | bytearray, check: 'PostgreSql') -> str: - try: - request = json.loads(request_json) - except (TypeError, ValueError): - response = _error('invalid_request', 'Invalid remote query request: request_json must be a valid JSON object.') - else: - if not isinstance(request, Mapping): - response = _error('invalid_request', 'Invalid remote query request: request_json must be a JSON object.') - else: - response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) - - return json.dumps(response, default=str) - - def execute_agent_rpc_stream_copy( request_json: str | bytes | bytearray, check: 'PostgreSql', emit: CopyStreamEmit ) -> None: @@ -227,26 +182,6 @@ def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegis yield from _iter_copy_stream_events(matches[0], parsed_request, started_at) -def execute_remote_query(request: Any, registry: PostgresCheckRegistry) -> dict[str, Any]: - try: - parsed_request = RemoteQueryRequest.model_validate(request) - except ValidationError as e: - return _validation_error(e) - - target = parsed_request.target - limits = parsed_request.limits - query = parsed_request.query - - matches = _resolve_matches(target, registry.iter_postgres_checks()) - LOGGER.debug('Remote query target match count: %d', len(matches)) - if not matches: - return _error('target_not_found', 'No loaded Postgres integration instance matched target selector.') - if len(matches) > 1: - return _error('target_ambiguous', 'More than one loaded Postgres integration instance matched target selector.') - - return _execute_safe_query(matches[0], target, query, limits) - - def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: try: return RemoteQueryTarget.model_validate(target) @@ -269,44 +204,6 @@ def _target_from_check(check: 'PostgreSql') -> RemoteQueryTarget | None: return None -def _execute_safe_query( - check: 'PostgreSql', target: RemoteQueryTarget, query: RemoteQuerySql, limits: RemoteQueryLimits -) -> dict[str, Any]: - db_pool = getattr(check, 'db_pool', None) - if db_pool is None: - return _error('credentials_unavailable', 'Matched Postgres check does not expose a connection pool.') - if getattr(db_pool, 'is_closed', lambda: False)(): - return _error('target_unavailable', 'Matched Postgres check connection pool is closed.', retryable=False) - - try: - with db_pool.get_connection(target.dbname) as conn: - with conn.cursor() as cursor: - cursor.execute(query) - description = cursor.description - if description is None: - return _error('query_failed', 'Query did not return a result set.') - raw_rows = cursor.fetchmany(limits.max_rows + 1) - except RuntimeError: - return _error('target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False) - except Exception: - LOGGER.exception('Remote query execution failed') - return _error('query_failed', 'Remote query execution failed.') - - # max_bytes and timeout_ms are validated for the API contract but enforced in a follow-up slice. - truncated = len(raw_rows) > limits.max_rows - response_columns = _response_columns(description, raw_rows) - rows = [_response_row(response_columns, row) for row in raw_rows[: limits.max_rows]] - bytes_returned = len(json.dumps({'columns': response_columns, 'rows': rows}, default=str).encode('utf-8')) - - return { - 'status': 'SUCCEEDED', - 'columns': response_columns, - 'rows': rows, - 'truncated': truncated, - 'stats': {'rowCount': len(rows), 'bytesReturned': bytes_returned}, - } - - def _iter_copy_stream_events( check: 'PostgreSql', request: RemoteQueryCopyRequest, started_at: float ) -> Iterator[CopyStreamEvent]: @@ -512,46 +409,6 @@ def _emit_copy_event(emit: CopyStreamEmit, event: CopyStreamEvent) -> None: emit(event.event_type, json.dumps(event.metadata, default=str), event.payload) -def _response_columns(description: Sequence[Any], rows: Sequence[Sequence[Any]]) -> list[dict[str, str]]: - return [ - {'name': _column_name(column), 'type': _column_type(index, rows)} for index, column in enumerate(description) - ] - - -def _column_name(column: Any) -> str: - name = getattr(column, 'name', None) - if name is not None: - return str(name) - return str(column[0]) - - -def _column_type(index: int, rows: Sequence[Sequence[Any]]) -> str: - for row in rows: - if row[index] is not None: - return _value_type(row[index]) - return 'unknown' - - -def _value_type(value: Any) -> str: - if isinstance(value, bool): - return 'boolean' - if isinstance(value, int): - return 'integer' - if isinstance(value, float): - return 'number' - if isinstance(value, str): - return 'string' - return type(value).__name__ - - -def _response_row(columns: Sequence[Mapping[str, str]], row: Sequence[Any]) -> dict[str, Any]: - return {column['name']: row[index] for index, column in enumerate(columns)} - - -def _validation_error(error: ValidationError) -> dict[str, Any]: - return _error('invalid_request', _validation_message(error)) - - def _validation_message(error: ValidationError) -> str: details = [] for item in error.errors(include_input=False): @@ -566,7 +423,3 @@ def _validation_message(error: ValidationError) -> str: def _validation_location(location: tuple[Any, ...]) -> str: return '.'.join(str(part) for part in location) - - -def _error(code: str, message: str, retryable: bool = False) -> dict[str, Any]: - return {'status': 'FAILED', 'error': {'code': code, 'message': message, 'retryable': retryable}} diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py index 910c30285cbec..c6a22081cb3fe 100644 --- a/postgres/tests/test_remote_query.py +++ b/postgres/tests/test_remote_query.py @@ -10,9 +10,7 @@ from datadog_checks.postgres.remote_query import ( StaticPostgresCheckRegistry, - execute_agent_rpc_json, execute_agent_rpc_stream_copy, - execute_remote_query, iter_agent_rpc_stream_copy_events, normalize_target, ) @@ -62,9 +60,6 @@ def execute(self, query, params=None): def fetchone(self): return ('0',) - def fetchmany(self, size): - return self.rows[:size] - def copy(self, query): self.executed.append((query, None)) return FakeCopy(self.copy_blocks, self.pool) @@ -99,11 +94,13 @@ def block_existing_query_helpers(check): return check -def valid_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): +def valid_copy_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): request = { + 'operation': 'copy_stream', 'target': {'host': host, 'port': port, 'dbname': dbname}, 'query': 'SELECT 1 AS value', - 'limits': {'maxRows': 10, 'maxBytes': 1048576, 'timeoutMs': 5000}, + 'format': 'csv', + 'limits': {'chunkBytes': 8, 'maxBytes': 64, 'maxRowBytes': 32, 'timeoutMs': 5000}, } request.update(extra) return request @@ -114,18 +111,23 @@ def iter_postgres_checks(self): pytest.fail('registry must not be iterated') -def assert_failed(response, code, message_contains=None): - assert response['status'] == 'FAILED' - assert response['error']['code'] == code - if message_contains is not None: - assert message_contains in response['error']['message'] +def collect_copy_events(request, check): + return list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + +def event_metadata(event): + return event.metadata -def execute_agent_rpc_response(request_json, check): - response_json = execute_agent_rpc_json(request_json, check) +def event_payload(event): + return event.payload + - assert isinstance(response_json, str) - return json.loads(response_json) +def assert_failed_event(events, code, message_contains=None): + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == code + if message_contains is not None: + assert message_contains in event_metadata(events[-1])['error']['message'] def test_normalize_target_trims_lowercases_host_and_removes_one_trailing_dot(): @@ -163,199 +165,127 @@ def test_normalize_target_rejects_empty_host_or_dbname(target): @pytest.mark.parametrize('field', ['extra', 'password']) -def test_rejects_unknown_request_fields_before_resolution(caplog, field): - request = valid_request(**{field: 'SECRET_DO_NOT_LOG'}) +def test_copy_stream_rejects_unknown_request_fields_before_resolution(caplog, field): + request = valid_copy_request(**{field: 'SECRET_DO_NOT_LOG'}) - response = execute_remote_query(request, ExplodingRegistry()) + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) - assert_failed(response, 'invalid_request', field) - assert 'SECRET_DO_NOT_LOG' not in str(response) + assert_failed_event(events, 'invalid_request', field) + assert 'SECRET_DO_NOT_LOG' not in str(events) assert 'SECRET_DO_NOT_LOG' not in caplog.text -def test_rejects_unknown_target_fields_before_resolution(): - request = valid_request() +def test_copy_stream_rejects_unknown_target_fields_before_resolution(): + request = valid_copy_request() request['target']['password'] = 'SECRET_DO_NOT_LOG' - response = execute_remote_query(request, ExplodingRegistry()) + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) - assert_failed(response, 'invalid_request', 'password') - assert 'SECRET_DO_NOT_LOG' not in str(response) + assert_failed_event(events, 'invalid_request', 'password') + assert 'SECRET_DO_NOT_LOG' not in str(events) -def test_rejects_unknown_limits_fields_before_resolution(): - request = valid_request() +def test_copy_stream_rejects_unknown_limits_fields_before_resolution(): + request = valid_copy_request() request['limits']['password'] = 'SECRET_DO_NOT_LOG' - response = execute_remote_query(request, ExplodingRegistry()) + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) - assert_failed(response, 'invalid_request', 'password') - assert 'SECRET_DO_NOT_LOG' not in str(response) + assert_failed_event(events, 'invalid_request', 'password') + assert 'SECRET_DO_NOT_LOG' not in str(events) -@pytest.mark.parametrize('field', ['maxRows', 'maxBytes', 'timeoutMs']) -def test_rejects_string_limit_values_before_resolution(field): - request = valid_request() +@pytest.mark.parametrize('field', ['chunkBytes', 'maxBytes', 'maxRowBytes', 'timeoutMs']) +def test_copy_stream_rejects_string_limit_values_before_resolution(field): + request = valid_copy_request() request['limits'][field] = '10' - response = execute_remote_query(request, ExplodingRegistry()) + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) - assert_failed(response, 'invalid_request', field) + assert_failed_event(events, 'invalid_request', field) -@pytest.mark.parametrize( - 'request_json', - [ - json.dumps(valid_request()), - json.dumps(valid_request()).encode(), - bytearray(json.dumps(valid_request()), 'utf-8'), - ], -) -def test_agent_rpc_json_accepts_json_request_text_and_live_check(request_json): - pool = FakePool() - check = make_check(pool=pool) - - response = execute_agent_rpc_response(request_json, check) - - assert response['status'] == 'SUCCEEDED' - assert response['rows'] == [{'value': 1}] - assert pool.requested_dbnames == ['datadog_test'] - - -@pytest.mark.parametrize('request_json', ['{"password": "SECRET_DO_NOT_LOG"', b'\xff']) -def test_agent_rpc_json_rejects_malformed_json_without_echoing_input(caplog, request_json): - pool = FakePool() +def test_copy_stream_requires_explicit_operation_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request() + request.pop('operation') - response = execute_agent_rpc_response(request_json, make_check(pool=pool)) + events = collect_copy_events(request, make_check(pool=pool)) - assert_failed(response, 'invalid_request', 'request_json') - assert 'SECRET_DO_NOT_LOG' not in str(response) - assert 'SECRET_DO_NOT_LOG' not in caplog.text + assert_failed_event(events, 'invalid_request', 'operation') assert pool.requested_dbnames == [] -@pytest.mark.parametrize('request_json', ['[]', 'null', '"SECRET_DO_NOT_LOG"', '1']) -def test_agent_rpc_json_rejects_non_object_json_without_echoing_input(request_json): - pool = FakePool() +@pytest.mark.parametrize('operation', ['query', 'execute', None]) +def test_copy_stream_rejects_non_copy_operation_before_pool_access(operation): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(operation=operation) - response = execute_agent_rpc_response(request_json, make_check(pool=pool)) + events = collect_copy_events(request, make_check(pool=pool)) - assert_failed(response, 'invalid_request', 'JSON object') - assert 'SECRET_DO_NOT_LOG' not in str(response) + assert_failed_event(events, 'invalid_request', 'operation') assert pool.requested_dbnames == [] -def test_agent_rpc_json_reuses_strict_validation_for_request_shape(): - pool = FakePool() - request = valid_request(password='SECRET_DO_NOT_LOG') +def test_copy_stream_rejects_non_copy_allowlisted_queries_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(query='SELECT current_database()') - response = execute_agent_rpc_response(json.dumps(request), make_check(pool=pool)) + events = collect_copy_events(request, make_check(pool=pool)) - assert_failed(response, 'invalid_request', 'password') - assert 'SECRET_DO_NOT_LOG' not in str(response) + assert_failed_event(events, 'invalid_request', 'query') assert pool.requested_dbnames == [] -def test_agent_rpc_json_uses_only_supplied_live_check_for_target_matching(): - matching_pool = FakePool() - non_matching_pool = FakePool() - request_json = json.dumps(valid_request(host='configured.internal')) - - response = execute_agent_rpc_response(request_json, make_check(host='localhost', pool=non_matching_pool)) - - assert_failed(response, 'target_not_found') - assert non_matching_pool.requested_dbnames == [] - - response = execute_agent_rpc_response(request_json, make_check(host='configured.internal', pool=matching_pool)) - - assert response['status'] == 'SUCCEEDED' - assert matching_pool.requested_dbnames == ['datadog_test'] - - -def test_resolve_matches_exact_host_port_dbname_from_check_config(): - pool = FakePool() - check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) +@pytest.mark.parametrize('size', [1048576, 2097152, 4194304, 8388608, 16777216, 33554432]) +def test_copy_stream_accepts_large_payload_proof_queries(size): + pool = FakePool(copy_blocks=[b'x' * 8]) + request = valid_copy_request(query=f"SELECT repeat('x', {size}) AS payload") - response = execute_remote_query(valid_request(host='LOCALHOST.', port=5432), StaticPostgresCheckRegistry([check])) + events = collect_copy_events(request, make_check(pool=pool)) - assert response['status'] == 'SUCCEEDED' - assert response['columns'] == [{'name': 'value', 'type': 'integer'}] - assert response['rows'] == [{'value': 1}] + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' assert pool.requested_dbnames == ['datadog_test'] -def test_execute_accepts_fixture_table_query_and_serializes_result_rows(): - pool = FakePool( - rows=[('Beautiful city of lights', 'France'), ('New York', 'USA')], - description=[SimpleNamespace(name='city'), SimpleNamespace(name='country')], - ) - check = make_check(pool=pool) +def test_copy_stream_resolves_exact_host_port_dbname_from_check_config(): + pool = FakePool(copy_blocks=[b'1\n']) + check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) - response = execute_remote_query( - valid_request(query='SELECT city, country FROM cities ORDER BY city'), StaticPostgresCheckRegistry([check]) - ) + events = collect_copy_events(valid_copy_request(host='LOCALHOST.', port=5432), check) - assert response['status'] == 'SUCCEEDED' - assert response['columns'] == [{'name': 'city', 'type': 'string'}, {'name': 'country', 'type': 'string'}] - assert response['rows'] == [ - {'city': 'Beautiful city of lights', 'country': 'France'}, - {'city': 'New York', 'country': 'USA'}, - ] - assert response['truncated'] is False - assert response['stats']['rowCount'] == 2 + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' assert pool.requested_dbnames == ['datadog_test'] -@pytest.mark.parametrize('size', [1048576, 2097152, 4194304, 8388608, 16777216, 33554432]) -def test_execute_accepts_large_payload_proof_queries_and_serializes_result_rows(size): - pool = FakePool(rows=[('x' * size,)], description=[SimpleNamespace(name='payload')]) - check = make_check(pool=pool) +def test_copy_stream_uses_only_supplied_live_check_for_target_matching(): + matching_pool = FakePool(copy_blocks=[b'1\n']) + non_matching_pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(host='configured.internal') - response = execute_remote_query( - valid_request(query=f"SELECT repeat('x', {size}) AS payload"), StaticPostgresCheckRegistry([check]) - ) + events = collect_copy_events(request, make_check(host='localhost', pool=non_matching_pool)) - assert response['status'] == 'SUCCEEDED' - assert response['columns'] == [{'name': 'payload', 'type': 'string'}] - assert len(response['rows']) == 1 - assert len(response['rows'][0]['payload']) == size - assert response['truncated'] is False - assert response['stats']['rowCount'] == 1 - assert pool.requested_dbnames == ['datadog_test'] - - -@pytest.mark.parametrize( - 'query', - [ - 'SELECT current_database()', - 'SELECT 1 AS value;', - ' SELECT 1 AS value', - 'SELECT city, country FROM cities ORDER BY city;', - 'SELECT country, city FROM cities ORDER BY city', - ], -) -def test_execute_rejects_non_canonical_query_before_pool_access(query): - pool = FakePool() - request = valid_request(query=query) + assert_failed_event(events, 'target_not_found') + assert non_matching_pool.requested_dbnames == [] - response = execute_remote_query(request, StaticPostgresCheckRegistry([make_check(pool=pool)])) + events = collect_copy_events(request, make_check(host='configured.internal', pool=matching_pool)) - assert_failed(response, 'invalid_request', 'query') - assert pool.requested_dbnames == [] + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert matching_pool.requested_dbnames == ['datadog_test'] -def test_resolve_requires_dbname_match_even_when_host_and_port_match(): - pool = FakePool() +def test_copy_stream_requires_dbname_match_even_when_host_and_port_match(): + pool = FakePool(copy_blocks=[b'1\n']) check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) - response = execute_remote_query(valid_request(dbname='postgres'), StaticPostgresCheckRegistry([check])) + events = collect_copy_events(valid_copy_request(dbname='postgres'), check) - assert_failed(response, 'target_not_found') + assert_failed_event(events, 'target_not_found') assert pool.requested_dbnames == [] -def test_resolve_ignores_metadata_identity_matches(): - pool = FakePool() +def test_copy_stream_ignores_metadata_identity_matches(): + pool = FakePool(copy_blocks=[b'1\n']) check = make_check( host='configured.internal', port=5432, @@ -365,121 +295,24 @@ def test_resolve_ignores_metadata_identity_matches(): database_identifier='reported.internal', ) - response = execute_remote_query(valid_request(host='reported.internal'), StaticPostgresCheckRegistry([check])) + events = collect_copy_events(valid_copy_request(host='reported.internal'), check) - assert_failed(response, 'target_not_found') + assert_failed_event(events, 'target_not_found') assert pool.requested_dbnames == [] -def test_resolve_fails_ambiguous_duplicate_configs(): - first_pool = FakePool() - second_pool = FakePool() +def test_copy_stream_fails_ambiguous_duplicate_configs(): + first_pool = FakePool(copy_blocks=[b'1\n']) + second_pool = FakePool(copy_blocks=[b'1\n']) checks = [make_check(pool=first_pool), make_check(pool=second_pool)] - response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry(checks)) + events = list(iter_agent_rpc_stream_copy_events(valid_copy_request(), StaticPostgresCheckRegistry(checks))) - assert_failed(response, 'target_ambiguous') + assert_failed_event(events, 'target_ambiguous') assert first_pool.requested_dbnames == [] assert second_pool.requested_dbnames == [] -def test_execute_sets_truncated_when_more_than_max_rows_returned(): - pool = FakePool(rows=[(1,), (2,)]) - check = make_check(pool=pool) - request = valid_request() - request['limits']['maxRows'] = 1 - - response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) - - assert response['status'] == 'SUCCEEDED' - assert response['rows'] == [{'value': 1}] - assert response['truncated'] is True - assert response['stats']['rowCount'] == 1 - - -def test_execute_uses_connection_pool_not_existing_query_helpers(): - pool = FakePool() - check = block_existing_query_helpers(make_check(pool=pool)) - - response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry([check])) - - assert response['status'] == 'SUCCEEDED' - assert pool.requested_dbnames == ['datadog_test'] - - -def test_execute_closed_pool_returns_target_unavailable_without_recreating_credentials(): - pool = FakePool(closed=True) - check = make_check(pool=pool) - - response = execute_remote_query(valid_request(), StaticPostgresCheckRegistry([check])) - - assert_failed(response, 'target_unavailable') - assert pool.requested_dbnames == [] - - -def valid_copy_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): - request = { - 'operation': 'copy_stream', - 'target': {'host': host, 'port': port, 'dbname': dbname}, - 'query': 'SELECT 1 AS value', - 'format': 'csv', - 'limits': {'chunkBytes': 8, 'maxBytes': 64, 'maxRowBytes': 32, 'timeoutMs': 5000}, - } - request.update(extra) - return request - - -def collect_copy_events(request, check): - return list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) - - -def event_metadata(event): - return event.metadata - - -def event_payload(event): - return event.payload - - -def test_copy_stream_requires_explicit_operation_before_pool_access(): - pool = FakePool(copy_blocks=[b'1\n']) - request = valid_copy_request() - request.pop('operation') - - events = collect_copy_events(request, make_check(pool=pool)) - - assert event_metadata(events[-1])['status'] == 'FAILED' - assert event_metadata(events[-1])['error']['code'] == 'invalid_request' - assert 'operation' in event_metadata(events[-1])['error']['message'] - assert pool.requested_dbnames == [] - - -def test_copy_stream_rejects_unknown_fields_without_echoing_secrets(caplog): - pool = FakePool(copy_blocks=[b'1\n']) - request = valid_copy_request(password='SECRET_DO_NOT_LOG') - - events = collect_copy_events(request, make_check(pool=pool)) - - assert event_metadata(events[-1])['status'] == 'FAILED' - assert event_metadata(events[-1])['error']['code'] == 'invalid_request' - assert 'password' in event_metadata(events[-1])['error']['message'] - assert 'SECRET_DO_NOT_LOG' not in str(events) - assert 'SECRET_DO_NOT_LOG' not in caplog.text - assert pool.requested_dbnames == [] - - -def test_copy_stream_rejects_non_copy_allowlisted_queries_before_pool_access(): - pool = FakePool(copy_blocks=[b'1\n']) - request = valid_copy_request(query='SELECT current_database()') - - events = collect_copy_events(request, make_check(pool=pool)) - - assert event_metadata(events[-1])['status'] == 'FAILED' - assert event_metadata(events[-1])['error']['code'] == 'invalid_request' - assert 'query' in event_metadata(events[-1])['error']['message'] - assert pool.requested_dbnames == [] - - def test_copy_stream_uses_connection_pool_and_emits_chunked_copy_bytes(): pool = FakePool(copy_blocks=[b'abc', b'defgh', b'ijklmnop', b'qr']) check = block_existing_query_helpers(make_check(pool=pool)) @@ -502,16 +335,35 @@ def test_copy_stream_uses_connection_pool_and_emits_chunked_copy_bytes(): assert pool.closed_copies == 1 -def test_copy_stream_data_event_payload_preserves_arbitrary_bytes(): - arbitrary_bytes = b'\x00\xff\x80abc\n' +def test_copy_stream_fixture_table_query_emits_copy_bytes(): + pool = FakePool(copy_blocks=[b'Beautiful city of lights,France\n', b'New York,USA\n']) + request = valid_copy_request(query='SELECT city, country FROM cities ORDER BY city') + + events = collect_copy_events(request, make_check(pool=pool)) + + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert b'Beautiful city of lights,France\n' in data + assert b'New York,USA\n' in data + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + + +def test_copy_stream_binary_format_preserves_arbitrary_bytes(): + arbitrary_bytes = b'PGCOPY\n\xff\r\n\x00\x00\xff\x80abc\n' pool = FakePool(copy_blocks=[arbitrary_bytes]) + request = valid_copy_request( + query="SELECT decode('00ff80', 'hex') AS payload", + format='binary', + limits={'chunkBytes': 1024, 'maxBytes': 4096, 'maxRowBytes': 4096, 'timeoutMs': 5000}, + ) - events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + events = collect_copy_events(request, make_check(pool=pool)) data_events = [event for event in events if event.event_type == 'data'] + assert event_metadata(events[0])['format'] == 'binary' assert len(data_events) == 1 assert event_payload(data_events[0]) == arbitrary_bytes assert isinstance(event_payload(data_events[0]), bytes) + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' def test_copy_stream_enforces_max_bytes_without_exceeding_limit(): @@ -523,8 +375,7 @@ def test_copy_stream_enforces_max_bytes_without_exceeding_limit(): data_events = [event for event in events if event.event_type == 'data'] assert [event_payload(event) for event in data_events] == [b'abcdefgh', b'ij'] assert sum(event_metadata(event)['bytes'] for event in data_events) == 10 - assert event_metadata(events[-1])['status'] == 'FAILED' - assert event_metadata(events[-1])['error']['code'] == 'max_bytes_exceeded' + assert_failed_event(events, 'max_bytes_exceeded') assert event_metadata(events[-1])['stats']['bytesEmitted'] == 10 assert pool.closed_copies == 1 @@ -535,12 +386,19 @@ def test_copy_stream_enforces_max_row_bytes_after_copy_block_arrives(): events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) assert [event_payload(event) for event in events if event.event_type == 'data'] == [] - assert event_metadata(events[-1])['status'] == 'FAILED' - assert event_metadata(events[-1])['error']['code'] == 'max_row_bytes_exceeded' - assert 'row granularity' in event_metadata(events[-1])['error']['message'] + assert_failed_event(events, 'max_row_bytes_exceeded', 'row granularity') assert pool.closed_copies == 1 +def test_copy_stream_closed_pool_returns_target_unavailable_without_recreating_credentials(): + pool = FakePool(closed=True) + + events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + + assert_failed_event(events, 'target_unavailable') + assert pool.requested_dbnames == [] + + def test_agent_rpc_stream_copy_adapts_iterator_to_binary_safe_callback(): arbitrary_bytes = b'\x00\xff\x80abc\n' pool = FakePool(copy_blocks=[arbitrary_bytes]) @@ -557,13 +415,12 @@ def test_agent_rpc_stream_copy_adapts_iterator_to_binary_safe_callback(): assert json.loads(events[-1][1])['status'] == 'SUCCEEDED' -def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(caplog): +@pytest.mark.parametrize('request_json', ['{"password": "SECRET_DO_NOT_LOG"', b'\xff']) +def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(caplog, request_json): pool = FakePool(copy_blocks=[b'1\n']) events = [] - execute_agent_rpc_stream_copy( - '{"password": "SECRET_DO_NOT_LOG"', make_check(pool=pool), lambda *event: events.append(event) - ) + execute_agent_rpc_stream_copy(request_json, make_check(pool=pool), lambda *event: events.append(event)) metadata = json.loads(events[-1][1]) assert events[-1][0] == 'error' @@ -574,6 +431,22 @@ def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(capl assert pool.requested_dbnames == [] +@pytest.mark.parametrize('request_json', ['[]', 'null', '"SECRET_DO_NOT_LOG"', '1']) +def test_agent_rpc_stream_copy_rejects_non_object_json_without_echoing_input(request_json): + pool = FakePool(copy_blocks=[b'1\n']) + events = [] + + execute_agent_rpc_stream_copy(request_json, make_check(pool=pool), lambda *event: events.append(event)) + + metadata = json.loads(events[-1][1]) + assert events[-1][0] == 'error' + assert metadata['status'] == 'FAILED' + assert metadata['error']['code'] == 'invalid_request' + assert 'JSON object' in metadata['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert pool.requested_dbnames == [] + + def test_agent_rpc_stream_copy_closes_copy_when_callback_raises(): pool = FakePool(copy_blocks=[b'12345678', b'abcdef']) events = [] diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py index 78dd12f64179c..1c468e720b665 100644 --- a/postgres/tests/test_remote_query_integration.py +++ b/postgres/tests/test_remote_query_integration.py @@ -6,23 +6,7 @@ import pytest -from datadog_checks.postgres.remote_query import ( - StaticPostgresCheckRegistry, - execute_remote_query, - iter_agent_rpc_stream_copy_events, -) - - -def remote_query_request(pg_instance: dict[str, object], query: str) -> dict[str, object]: - return { - 'target': { - 'host': pg_instance['host'], - 'port': int(pg_instance['port']), - 'dbname': pg_instance['dbname'], - }, - 'query': query, - 'limits': {'maxRows': 10, 'maxBytes': 1048576, 'timeoutMs': 5000}, - } +from datadog_checks.postgres.remote_query import StaticPostgresCheckRegistry, iter_agent_rpc_stream_copy_events def remote_query_copy_request( @@ -51,37 +35,22 @@ def event_payload(event): @pytest.mark.integration @pytest.mark.usefixtures('dd_environment') -def test_agent_local_executor_select_1_reuses_integration_check_credentials(integration_check, pg_instance): +def test_agent_local_copy_stream_select_1_reuses_integration_check_credentials(integration_check, pg_instance): check = integration_check(pg_instance) - request = remote_query_request(pg_instance, 'SELECT 1 AS value') - - response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) - - assert response['status'] == 'SUCCEEDED' - assert response['columns'][0]['name'] == 'value' - assert response['rows'] == [{'value': 1}] - assert response['truncated'] is False - assert response['stats']['rowCount'] == 1 - assert 'password' not in json.dumps(request).lower() + request = remote_query_copy_request( + pg_instance, + 'SELECT 1 AS value', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, + ) + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) -@pytest.mark.integration -@pytest.mark.usefixtures('dd_environment') -def test_agent_local_executor_fixture_table_query_returns_city_rows(integration_check, pg_instance): - bob_instance = dict(pg_instance, username='bob', password='bob') - check = integration_check(bob_instance) - request = remote_query_request(bob_instance, 'SELECT city, country FROM cities ORDER BY city') - - response = execute_remote_query(request, StaticPostgresCheckRegistry([check])) - - assert response['status'] == 'SUCCEEDED' - assert response['columns'] == [{'name': 'city', 'type': 'string'}, {'name': 'country', 'type': 'string'}] - assert response['rows'] == [ - {'city': 'Beautiful city of lights', 'country': 'France'}, - {'city': 'New York', 'country': 'USA'}, - ] - assert response['truncated'] is False - assert response['stats']['rowCount'] == 2 + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['operation'] == 'copy_stream' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert data == b'1\n' + assert event_metadata(events[-1])['stats']['bytesEmitted'] == len(data) assert 'password' not in json.dumps(request).lower() @@ -133,22 +102,26 @@ def test_agent_local_copy_stream_binary_format_preserves_non_text_bytes(integrat @pytest.mark.integration @pytest.mark.usefixtures('dd_environment') -def test_agent_local_copy_stream_enforces_max_row_bytes_and_reuses_connection(integration_check, pg_instance): +def test_agent_local_copy_stream_enforces_max_row_bytes_and_connection_remains_reusable(integration_check, pg_instance): check = integration_check(pg_instance) - request = remote_query_copy_request( + oversized_request = remote_query_copy_request( pg_instance, "SELECT repeat('x', 1048576) AS payload", {'chunkBytes': 1024, 'maxBytes': 2 * 1048576, 'maxRowBytes': 1024, 'timeoutMs': 5000}, ) - events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + events = list(iter_agent_rpc_stream_copy_events(oversized_request, StaticPostgresCheckRegistry([check]))) assert [event for event in events if event.event_type == 'data'] == [] assert event_metadata(events[-1])['status'] == 'FAILED' assert event_metadata(events[-1])['error']['code'] == 'max_row_bytes_exceeded' - response = execute_remote_query( - remote_query_request(pg_instance, 'SELECT 1 AS value'), StaticPostgresCheckRegistry([check]) + reusable_request = remote_query_copy_request( + pg_instance, + 'SELECT 1 AS value', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, ) - assert response['status'] == 'SUCCEEDED' - assert response['rows'] == [{'value': 1}] + reusable_events = list(iter_agent_rpc_stream_copy_events(reusable_request, StaticPostgresCheckRegistry([check]))) + reusable_data = b''.join(event_payload(event) for event in reusable_events if event.event_type == 'data') + assert event_metadata(reusable_events[-1])['status'] == 'SUCCEEDED' + assert reusable_data == b'1\n'