diff --git a/.github/workflows/config/labeler.yml b/.github/workflows/config/labeler.yml index e440a28f2f028..63f7249b4c2e1 100644 --- a/.github/workflows/config/labeler.yml +++ b/.github/workflows/config/labeler.yml @@ -890,10 +890,6 @@ integration/langchain: - changed-files: - any-glob-to-any-file: - langchain/**/* -integration/lparstats: -- changed-files: - - any-glob-to-any-file: - - lparstats/**/* integration/lastpass: - changed-files: - any-glob-to-any-file: @@ -918,6 +914,10 @@ integration/litellm: - changed-files: - any-glob-to-any-file: - litellm/**/* +integration/lparstats: +- changed-files: + - any-glob-to-any-file: + - lparstats/**/* integration/lustre: - changed-files: - any-glob-to-any-file: @@ -1002,10 +1002,6 @@ integration/nagios: - changed-files: - any-glob-to-any-file: - nagios/**/* -integration/nifi: -- changed-files: - - any-glob-to-any-file: - - nifi/**/* integration/network: - changed-files: - any-glob-to-any-file: @@ -1026,6 +1022,10 @@ integration/nginx_ingress_controller: - changed-files: - any-glob-to-any-file: - nginx_ingress_controller/**/* +integration/nifi: +- changed-files: + - any-glob-to-any-file: + - nifi/**/* integration/ntp: - changed-files: - any-glob-to-any-file: diff --git a/postgres/changelog.d/23499.added b/postgres/changelog.d/23499.added new file mode 100644 index 0000000000000..2c901bd4b2e33 --- /dev/null +++ b/postgres/changelog.d/23499.added @@ -0,0 +1 @@ +Add a remote query POC executor for Postgres. \ No newline at end of file diff --git a/postgres/datadog_checks/postgres/remote_query.py b/postgres/datadog_checks/postgres/remote_query.py new file mode 100644 index 0000000000000..11603a518a1b5 --- /dev/null +++ b/postgres/datadog_checks/postgres/remote_query.py @@ -0,0 +1,425 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under a 3-clause BSD style license (see LICENSE) + +from __future__ import annotations + +import json +import logging +import time +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Protocol + +from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, ValidationError, field_validator + +RemoteQueryCopySql = Literal[ + 'SELECT 1 AS value', + 'SELECT city, country FROM cities ORDER BY city', + "SELECT decode('00ff80', 'hex') AS payload", + "SELECT repeat('x', 1048576) AS payload", + "SELECT repeat('x', 2097152) AS payload", + "SELECT repeat('x', 4194304) AS payload", + "SELECT repeat('x', 8388608) AS payload", + "SELECT repeat('x', 16777216) AS payload", + "SELECT repeat('x', 33554432) AS payload", + "SELECT i, repeat('x', 1000) AS payload FROM generate_series(1, 3000) AS i", +] + +CopyStreamFormat = Literal['csv', 'binary'] +CopyStreamEmit = Callable[[str, str, bytes], None] + +if TYPE_CHECKING: + from datadog_checks.postgres import PostgreSql + +LOGGER = logging.getLogger(__name__) + + +class RemoteQueryTarget(BaseModel): + model_config = ConfigDict(extra='forbid', frozen=True) + + host: StrictStr = Field(min_length=1) + port: StrictInt = Field(default=5432, ge=1, le=65535) + dbname: StrictStr = Field(min_length=1) + + @field_validator('host') + @classmethod + def normalize_host(cls, value: str) -> str: + host = value.strip().lower() + if host.endswith('.'): + host = host[:-1] + if not host: + raise ValueError('host must be a non-empty string') + return host + + @field_validator('dbname') + @classmethod + def validate_dbname(cls, value: str) -> str: + if not value: + raise ValueError('dbname must be a non-empty string') + if value != value.strip(): + raise ValueError('dbname must not contain surrounding whitespace') + return value + + +class RemoteQueryCopyLimits(BaseModel): + """Validate byte-streaming limits for COPY export mode.""" + + model_config = ConfigDict(extra='forbid', frozen=True) + + chunk_bytes: StrictInt = Field(default=1_048_576, alias='chunkBytes', ge=1) + max_bytes: StrictInt = Field(default=64 * 1_048_576, alias='maxBytes', ge=1) + max_row_bytes: StrictInt = Field(default=8 * 1_048_576, alias='maxRowBytes', ge=1) + timeout_ms: StrictInt = Field(default=30_000, alias='timeoutMs', ge=1) + + +class RemoteQueryCopyRequest(BaseModel): + """Accept only explicit COPY byte-stream export requests.""" + + model_config = ConfigDict(extra='forbid', frozen=True) + + operation: Literal['copy_stream'] = Field(alias='operation') + target: RemoteQueryTarget + query: RemoteQueryCopySql + format: CopyStreamFormat = 'csv' + limits: RemoteQueryCopyLimits = Field(default_factory=RemoteQueryCopyLimits) + + +@dataclass(frozen=True) +class StaticPostgresCheckRegistry: + checks: Sequence['PostgreSql'] + + def iter_postgres_checks(self) -> Iterable['PostgreSql']: + return iter(self.checks) + + +@dataclass(frozen=True) +class _CopyStreamState: + sequence: int = 0 + chunks_emitted: int = 0 + bytes_emitted: int = 0 + + +@dataclass(frozen=True) +class CopyStreamEvent: + event_type: str + metadata: Mapping[str, Any] + payload: bytes = b'' + + +class _CopyStreamFailure(Exception): + def __init__(self, code: str, message: str, retryable: bool = False): + self.code = code + self.message = message + self.retryable = retryable + super().__init__(message) + + +class PostgresCheckRegistry(Protocol): + def iter_postgres_checks(self) -> Iterable['PostgreSql']: ... + + +def execute_agent_rpc_stream_copy( + request_json: str | bytes | bytearray, check: 'PostgreSql', emit: CopyStreamEmit +) -> None: + """Execute an explicit COPY byte-stream request and emit chunk events.""" + try: + request = json.loads(request_json) + except (TypeError, ValueError): + _emit_copy_event( + emit, + _stream_failed_event( + 'invalid_request', 'Invalid remote query request: request_json must be a valid JSON object.' + ), + ) + return + + if not isinstance(request, Mapping): + _emit_copy_event( + emit, + _stream_failed_event( + 'invalid_request', 'Invalid remote query request: request_json must be a JSON object.' + ), + ) + return + + events = iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check])) + try: + for event in events: + _emit_copy_event(emit, event) + except BaseException: + events.close() + raise + + +def iter_agent_rpc_stream_copy_events(request: Any, registry: PostgresCheckRegistry) -> Iterator[CopyStreamEvent]: + """Yield COPY byte-stream events for unit tests and callback adaptation.""" + started_at = time.monotonic() + try: + parsed_request = RemoteQueryCopyRequest.model_validate(request) + except ValidationError as e: + yield _stream_failed_event('invalid_request', _validation_message(e), elapsed_ms=_elapsed_ms(started_at)) + return + + target = parsed_request.target + matches = _resolve_matches(target, registry.iter_postgres_checks()) + LOGGER.debug('Remote query COPY stream target match count: %d', len(matches)) + if not matches: + yield _stream_failed_event( + 'target_not_found', + 'No loaded Postgres integration instance matched target selector.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + if len(matches) > 1: + yield _stream_failed_event( + 'target_ambiguous', + 'More than one loaded Postgres integration instance matched target selector.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + + yield from _iter_copy_stream_events(matches[0], parsed_request, started_at) + + +def normalize_target(target: Mapping[str, Any]) -> RemoteQueryTarget: + try: + return RemoteQueryTarget.model_validate(target) + except ValidationError as e: + raise ValueError(_validation_message(e)) from e + + +def _resolve_matches(target: RemoteQueryTarget, checks: Iterable['PostgreSql']) -> list['PostgreSql']: + return [check for check in checks if _target_from_check(check) == target] + + +def _target_from_check(check: 'PostgreSql') -> RemoteQueryTarget | None: + config = getattr(check, '_config', None) + if config is None: + return None + + try: + return RemoteQueryTarget(host=config.host, port=config.port, dbname=config.dbname) + except (AttributeError, ValidationError): + return None + + +def _iter_copy_stream_events( + check: 'PostgreSql', request: RemoteQueryCopyRequest, started_at: float +) -> Iterator[CopyStreamEvent]: + db_pool = getattr(check, 'db_pool', None) + if db_pool is None: + yield _stream_failed_event( + 'credentials_unavailable', + 'Matched Postgres check does not expose a connection pool.', + elapsed_ms=_elapsed_ms(started_at), + ) + return + if getattr(db_pool, 'is_closed', lambda: False)(): + yield _stream_failed_event( + 'target_unavailable', + 'Matched Postgres check connection pool is closed.', + retryable=False, + elapsed_ms=_elapsed_ms(started_at), + ) + return + + yield CopyStreamEvent( + 'metadata', + { + 'status': 'STARTED', + 'format': request.format, + 'operation': request.operation, + 'chunkBytes': request.limits.chunk_bytes, + 'maxBytes': request.limits.max_bytes, + 'maxRowBytes': request.limits.max_row_bytes, + }, + ) + + state = _CopyStreamState() + error: _CopyStreamFailure | None = None + try: + for event, next_state in _copy_stream_data_events(check, request, state, started_at): + state = next_state + yield event + except _CopyStreamFailure as e: + error = e + except RuntimeError: + error = _CopyStreamFailure( + 'target_unavailable', 'Matched Postgres check connection pool is unavailable.', retryable=False + ) + except Exception: + LOGGER.exception('Remote query COPY stream execution failed') + error = _CopyStreamFailure('query_failed', 'Remote query COPY stream execution failed.') + + if error is not None: + yield _stream_failed_event( + error.code, + error.message, + retryable=error.retryable, + stats=_copy_stream_stats(state, started_at, request.format), + ) + return + + yield CopyStreamEvent( + 'final', + {'status': 'SUCCEEDED', 'stats': _copy_stream_stats(state, started_at, request.format)}, + ) + + +def _copy_stream_data_events( + check: 'PostgreSql', request: RemoteQueryCopyRequest, state: _CopyStreamState, started_at: float +) -> Iterator[tuple[CopyStreamEvent, _CopyStreamState]]: + limits = request.limits + deadline = started_at + (limits.timeout_ms / 1000) + copy_sql = _copy_stdout_sql(request.query, request.format) + pending = bytearray() + + with check.db_pool.get_connection(request.target.dbname) as conn: + with conn.cursor() as cursor: + previous_statement_timeout = _set_statement_timeout(cursor, limits.timeout_ms) + try: + with cursor.copy(copy_sql) as copy: + for block in copy: + _raise_if_timed_out(deadline) + block_view = memoryview(block) + if len(block_view) > limits.max_row_bytes: + raise _CopyStreamFailure( + 'max_row_bytes_exceeded', + 'COPY stream row exceeded maxRowBytes; psycopg exposes COPY data at row granularity.', + ) + + offset = 0 + while offset < len(block_view): + _raise_if_timed_out(deadline) + remaining_allowed = limits.max_bytes - state.bytes_emitted - len(pending) + if remaining_allowed <= 0: + raise _CopyStreamFailure('max_bytes_exceeded', 'COPY stream exceeded maxBytes.') + + remaining_chunk = limits.chunk_bytes - len(pending) + take = min(remaining_chunk, remaining_allowed, len(block_view) - offset) + pending.extend(block_view[offset : offset + take]) + offset += take + + if len(pending) >= limits.chunk_bytes: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + + if offset < len(block_view) and state.bytes_emitted + len(pending) >= limits.max_bytes: + if pending: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + raise _CopyStreamFailure('max_bytes_exceeded', 'COPY stream exceeded maxBytes.') + + if pending: + event, state = _copy_data_event(pending, state) + pending.clear() + yield event, state + finally: + _restore_statement_timeout(cursor, previous_statement_timeout) + + +def _copy_stdout_sql(query: str, stream_format: CopyStreamFormat) -> str: + if stream_format == 'csv': + return f'COPY ({query}) TO STDOUT WITH (FORMAT CSV)' + if stream_format == 'binary': + return f'COPY ({query}) TO STDOUT WITH (FORMAT BINARY)' + raise _CopyStreamFailure('invalid_request', 'Unsupported COPY stream format.') + + +def _set_statement_timeout(cursor: Any, timeout_ms: int) -> str | None: + previous_statement_timeout = None + try: + cursor.execute('SHOW statement_timeout') + row = cursor.fetchone() + previous_statement_timeout = row[0] if row else None + cursor.execute('SET statement_timeout = %s', (timeout_ms,)) + except Exception: + LOGGER.debug('Unable to scope statement_timeout for remote query COPY stream', exc_info=True) + return previous_statement_timeout + + +def _restore_statement_timeout(cursor: Any, previous_statement_timeout: str | None) -> None: + if previous_statement_timeout is None: + return + try: + cursor.execute('SET statement_timeout = %s', (previous_statement_timeout,)) + except Exception: + LOGGER.debug('Unable to restore statement_timeout after remote query COPY stream', exc_info=True) + + +def _raise_if_timed_out(deadline: float) -> None: + if time.monotonic() > deadline: + raise _CopyStreamFailure('timeout', 'COPY stream exceeded timeoutMs.', retryable=True) + + +def _copy_data_event(data: bytearray, state: _CopyStreamState) -> tuple[CopyStreamEvent, _CopyStreamState]: + payload = bytes(data) + event = CopyStreamEvent( + 'data', + { + 'sequence': state.sequence, + 'offset': state.bytes_emitted, + 'bytes': len(payload), + }, + payload, + ) + next_state = _CopyStreamState( + sequence=state.sequence + 1, + chunks_emitted=state.chunks_emitted + 1, + bytes_emitted=state.bytes_emitted + len(payload), + ) + return event, next_state + + +def _copy_stream_stats(state: _CopyStreamState, started_at: float, stream_format: CopyStreamFormat) -> dict[str, Any]: + return { + 'format': stream_format, + 'bytesEmitted': state.bytes_emitted, + 'chunksEmitted': state.chunks_emitted, + 'elapsedMs': _elapsed_ms(started_at), + } + + +def _elapsed_ms(started_at: float) -> int: + return max(0, int((time.monotonic() - started_at) * 1000)) + + +def _stream_failed_event( + code: str, + message: str, + retryable: bool = False, + stats: Mapping[str, Any] | None = None, + elapsed_ms: int | None = None, +) -> CopyStreamEvent: + metadata = { + 'status': 'FAILED', + 'error': {'code': code, 'message': message, 'retryable': retryable}, + } + if stats is not None: + metadata['stats'] = dict(stats) + elif elapsed_ms is not None: + metadata['stats'] = {'elapsedMs': elapsed_ms} + return CopyStreamEvent('error', metadata) + + +def _emit_copy_event(emit: CopyStreamEmit, event: CopyStreamEvent) -> None: + emit(event.event_type, json.dumps(event.metadata, default=str), event.payload) + + +def _validation_message(error: ValidationError) -> str: + details = [] + for item in error.errors(include_input=False): + location = _validation_location(item.get('loc', ())) + message = item.get('msg', 'Invalid value') + if location: + details.append(f'{location}: {message}') + else: + details.append(message) + return 'Invalid remote query request: {}'.format('; '.join(details)) + + +def _validation_location(location: tuple[Any, ...]) -> str: + return '.'.join(str(part) for part in location) diff --git a/postgres/tests/test_remote_query.py b/postgres/tests/test_remote_query.py new file mode 100644 index 0000000000000..c6a22081cb3fe --- /dev/null +++ b/postgres/tests/test_remote_query.py @@ -0,0 +1,463 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under a 3-clause BSD style license (see LICENSE) + +import json +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest + +from datadog_checks.postgres.remote_query import ( + StaticPostgresCheckRegistry, + execute_agent_rpc_stream_copy, + iter_agent_rpc_stream_copy_events, + normalize_target, +) + + +class FakePool: + def __init__(self, rows=None, description=None, closed=False, copy_blocks=None): + self.rows = rows or [(1,)] + self.description = description or [SimpleNamespace(name='value')] + self.closed = closed + self.copy_blocks = copy_blocks or [] + self.requested_dbnames = [] + self.closed_copies = 0 + + def is_closed(self): + return self.closed + + @contextmanager + def get_connection(self, dbname): + self.requested_dbnames.append(dbname) + yield FakeConnection(self.rows, self.description, self.copy_blocks, self) + + +class FakeConnection: + def __init__(self, rows, description, copy_blocks, pool): + self.rows = rows + self.description = description + self.copy_blocks = copy_blocks + self.pool = pool + + @contextmanager + def cursor(self): + yield FakeCursor(self.rows, self.description, self.copy_blocks, self.pool) + + +class FakeCursor: + def __init__(self, rows, description, copy_blocks, pool): + self.rows = rows + self.description = description + self.copy_blocks = copy_blocks + self.pool = pool + self.executed = [] + + def execute(self, query, params=None): + self.executed.append((query, params)) + + def fetchone(self): + return ('0',) + + def copy(self, query): + self.executed.append((query, None)) + return FakeCopy(self.copy_blocks, self.pool) + + +class FakeCopy: + def __init__(self, blocks, pool): + self.blocks = blocks + self.pool = pool + + def __enter__(self): + return self + + def __exit__(self, *args): + self.pool.closed_copies += 1 + + def __iter__(self): + return iter(self.blocks) + + +def make_check(host='localhost', port=5432, dbname='datadog_test', pool=None, **metadata): + return SimpleNamespace( + _config=SimpleNamespace(host=host, port=port, dbname=dbname, **metadata), + db_pool=pool or FakePool(), + ) + + +def block_existing_query_helpers(check): + check.execute_query_raw = pytest.fail + check._run_query_scope = pytest.fail + check.data_observability = SimpleNamespace(run_job=pytest.fail) + return check + + +def valid_copy_request(host='LOCALHOST.', port=5432, dbname='datadog_test', **extra): + request = { + 'operation': 'copy_stream', + 'target': {'host': host, 'port': port, 'dbname': dbname}, + 'query': 'SELECT 1 AS value', + 'format': 'csv', + 'limits': {'chunkBytes': 8, 'maxBytes': 64, 'maxRowBytes': 32, 'timeoutMs': 5000}, + } + request.update(extra) + return request + + +class ExplodingRegistry: + def iter_postgres_checks(self): + pytest.fail('registry must not be iterated') + + +def collect_copy_events(request, check): + return list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + +def event_metadata(event): + return event.metadata + + +def event_payload(event): + return event.payload + + +def assert_failed_event(events, code, message_contains=None): + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == code + if message_contains is not None: + assert message_contains in event_metadata(events[-1])['error']['message'] + + +def test_normalize_target_trims_lowercases_host_and_removes_one_trailing_dot(): + target = normalize_target({'host': ' Example.INTERNAL. ', 'port': 5432, 'dbname': 'postgres'}) + + assert target.host == 'example.internal' + assert target.port == 5432 + assert target.dbname == 'postgres' + + +def test_normalize_target_defaults_missing_port_to_5432(): + target = normalize_target({'host': 'localhost', 'dbname': 'postgres'}) + + assert target.port == 5432 + + +@pytest.mark.parametrize('port', [True, '5432', 'abc', '0', 0, -1, 65536, None]) +def test_normalize_target_rejects_invalid_port_values(port): + with pytest.raises(ValueError): + normalize_target({'host': 'localhost', 'port': port, 'dbname': 'postgres'}) + + +@pytest.mark.parametrize( + 'target', + [ + {'host': '', 'port': 5432, 'dbname': 'postgres'}, + {'host': ' ', 'port': 5432, 'dbname': 'postgres'}, + {'host': 'localhost', 'port': 5432, 'dbname': ''}, + {'host': 'localhost', 'port': 5432, 'dbname': ' postgres '}, + ], +) +def test_normalize_target_rejects_empty_host_or_dbname(target): + with pytest.raises(ValueError): + normalize_target(target) + + +@pytest.mark.parametrize('field', ['extra', 'password']) +def test_copy_stream_rejects_unknown_request_fields_before_resolution(caplog, field): + request = valid_copy_request(**{field: 'SECRET_DO_NOT_LOG'}) + + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) + + assert_failed_event(events, 'invalid_request', field) + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + + +def test_copy_stream_rejects_unknown_target_fields_before_resolution(): + request = valid_copy_request() + request['target']['password'] = 'SECRET_DO_NOT_LOG' + + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) + + assert_failed_event(events, 'invalid_request', 'password') + assert 'SECRET_DO_NOT_LOG' not in str(events) + + +def test_copy_stream_rejects_unknown_limits_fields_before_resolution(): + request = valid_copy_request() + request['limits']['password'] = 'SECRET_DO_NOT_LOG' + + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) + + assert_failed_event(events, 'invalid_request', 'password') + assert 'SECRET_DO_NOT_LOG' not in str(events) + + +@pytest.mark.parametrize('field', ['chunkBytes', 'maxBytes', 'maxRowBytes', 'timeoutMs']) +def test_copy_stream_rejects_string_limit_values_before_resolution(field): + request = valid_copy_request() + request['limits'][field] = '10' + + events = list(iter_agent_rpc_stream_copy_events(request, ExplodingRegistry())) + + assert_failed_event(events, 'invalid_request', field) + + +def test_copy_stream_requires_explicit_operation_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request() + request.pop('operation') + + events = collect_copy_events(request, make_check(pool=pool)) + + assert_failed_event(events, 'invalid_request', 'operation') + assert pool.requested_dbnames == [] + + +@pytest.mark.parametrize('operation', ['query', 'execute', None]) +def test_copy_stream_rejects_non_copy_operation_before_pool_access(operation): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(operation=operation) + + events = collect_copy_events(request, make_check(pool=pool)) + + assert_failed_event(events, 'invalid_request', 'operation') + assert pool.requested_dbnames == [] + + +def test_copy_stream_rejects_non_copy_allowlisted_queries_before_pool_access(): + pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(query='SELECT current_database()') + + events = collect_copy_events(request, make_check(pool=pool)) + + assert_failed_event(events, 'invalid_request', 'query') + assert pool.requested_dbnames == [] + + +@pytest.mark.parametrize('size', [1048576, 2097152, 4194304, 8388608, 16777216, 33554432]) +def test_copy_stream_accepts_large_payload_proof_queries(size): + pool = FakePool(copy_blocks=[b'x' * 8]) + request = valid_copy_request(query=f"SELECT repeat('x', {size}) AS payload") + + events = collect_copy_events(request, make_check(pool=pool)) + + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert pool.requested_dbnames == ['datadog_test'] + + +def test_copy_stream_resolves_exact_host_port_dbname_from_check_config(): + pool = FakePool(copy_blocks=[b'1\n']) + check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) + + events = collect_copy_events(valid_copy_request(host='LOCALHOST.', port=5432), check) + + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert pool.requested_dbnames == ['datadog_test'] + + +def test_copy_stream_uses_only_supplied_live_check_for_target_matching(): + matching_pool = FakePool(copy_blocks=[b'1\n']) + non_matching_pool = FakePool(copy_blocks=[b'1\n']) + request = valid_copy_request(host='configured.internal') + + events = collect_copy_events(request, make_check(host='localhost', pool=non_matching_pool)) + + assert_failed_event(events, 'target_not_found') + assert non_matching_pool.requested_dbnames == [] + + events = collect_copy_events(request, make_check(host='configured.internal', pool=matching_pool)) + + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert matching_pool.requested_dbnames == ['datadog_test'] + + +def test_copy_stream_requires_dbname_match_even_when_host_and_port_match(): + pool = FakePool(copy_blocks=[b'1\n']) + check = make_check(host='localhost', port=5432, dbname='datadog_test', pool=pool) + + events = collect_copy_events(valid_copy_request(dbname='postgres'), check) + + assert_failed_event(events, 'target_not_found') + assert pool.requested_dbnames == [] + + +def test_copy_stream_ignores_metadata_identity_matches(): + pool = FakePool(copy_blocks=[b'1\n']) + check = make_check( + host='configured.internal', + port=5432, + dbname='datadog_test', + pool=pool, + reported_hostname='reported.internal', + database_identifier='reported.internal', + ) + + events = collect_copy_events(valid_copy_request(host='reported.internal'), check) + + assert_failed_event(events, 'target_not_found') + assert pool.requested_dbnames == [] + + +def test_copy_stream_fails_ambiguous_duplicate_configs(): + first_pool = FakePool(copy_blocks=[b'1\n']) + second_pool = FakePool(copy_blocks=[b'1\n']) + checks = [make_check(pool=first_pool), make_check(pool=second_pool)] + + events = list(iter_agent_rpc_stream_copy_events(valid_copy_request(), StaticPostgresCheckRegistry(checks))) + + assert_failed_event(events, 'target_ambiguous') + assert first_pool.requested_dbnames == [] + assert second_pool.requested_dbnames == [] + + +def test_copy_stream_uses_connection_pool_and_emits_chunked_copy_bytes(): + pool = FakePool(copy_blocks=[b'abc', b'defgh', b'ijklmnop', b'qr']) + check = block_existing_query_helpers(make_check(pool=pool)) + + events = collect_copy_events(valid_copy_request(), check) + + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['operation'] == 'copy_stream' + assert event_metadata(events[0])['format'] == 'csv' + data_events = [event for event in events if event.event_type == 'data'] + assert [event_metadata(event)['sequence'] for event in data_events] == [0, 1, 2] + assert [event_metadata(event)['offset'] for event in data_events] == [0, 8, 16] + assert [event_payload(event) for event in data_events] == [b'abcdefgh', b'ijklmnop', b'qr'] + assert [event_metadata(event)['bytes'] for event in data_events] == [8, 8, 2] + assert events[-1].event_type == 'final' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert event_metadata(events[-1])['stats']['bytesEmitted'] == 18 + assert event_metadata(events[-1])['stats']['chunksEmitted'] == 3 + assert pool.requested_dbnames == ['datadog_test'] + assert pool.closed_copies == 1 + + +def test_copy_stream_fixture_table_query_emits_copy_bytes(): + pool = FakePool(copy_blocks=[b'Beautiful city of lights,France\n', b'New York,USA\n']) + request = valid_copy_request(query='SELECT city, country FROM cities ORDER BY city') + + events = collect_copy_events(request, make_check(pool=pool)) + + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert b'Beautiful city of lights,France\n' in data + assert b'New York,USA\n' in data + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + + +def test_copy_stream_binary_format_preserves_arbitrary_bytes(): + arbitrary_bytes = b'PGCOPY\n\xff\r\n\x00\x00\xff\x80abc\n' + pool = FakePool(copy_blocks=[arbitrary_bytes]) + request = valid_copy_request( + query="SELECT decode('00ff80', 'hex') AS payload", + format='binary', + limits={'chunkBytes': 1024, 'maxBytes': 4096, 'maxRowBytes': 4096, 'timeoutMs': 5000}, + ) + + events = collect_copy_events(request, make_check(pool=pool)) + + data_events = [event for event in events if event.event_type == 'data'] + assert event_metadata(events[0])['format'] == 'binary' + assert len(data_events) == 1 + assert event_payload(data_events[0]) == arbitrary_bytes + assert isinstance(event_payload(data_events[0]), bytes) + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + + +def test_copy_stream_enforces_max_bytes_without_exceeding_limit(): + pool = FakePool(copy_blocks=[b'abcdefgh', b'ijklmnop']) + request = valid_copy_request(limits={'chunkBytes': 8, 'maxBytes': 10, 'maxRowBytes': 32, 'timeoutMs': 5000}) + + events = collect_copy_events(request, make_check(pool=pool)) + + data_events = [event for event in events if event.event_type == 'data'] + assert [event_payload(event) for event in data_events] == [b'abcdefgh', b'ij'] + assert sum(event_metadata(event)['bytes'] for event in data_events) == 10 + assert_failed_event(events, 'max_bytes_exceeded') + assert event_metadata(events[-1])['stats']['bytesEmitted'] == 10 + assert pool.closed_copies == 1 + + +def test_copy_stream_enforces_max_row_bytes_after_copy_block_arrives(): + pool = FakePool(copy_blocks=[b'abc', b'x' * 33]) + + events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + + assert [event_payload(event) for event in events if event.event_type == 'data'] == [] + assert_failed_event(events, 'max_row_bytes_exceeded', 'row granularity') + assert pool.closed_copies == 1 + + +def test_copy_stream_closed_pool_returns_target_unavailable_without_recreating_credentials(): + pool = FakePool(closed=True) + + events = collect_copy_events(valid_copy_request(), make_check(pool=pool)) + + assert_failed_event(events, 'target_unavailable') + assert pool.requested_dbnames == [] + + +def test_agent_rpc_stream_copy_adapts_iterator_to_binary_safe_callback(): + arbitrary_bytes = b'\x00\xff\x80abc\n' + pool = FakePool(copy_blocks=[arbitrary_bytes]) + events = [] + + execute_agent_rpc_stream_copy( + json.dumps(valid_copy_request()), make_check(pool=pool), lambda *event: events.append(event) + ) + + assert [event[0] for event in events] == ['metadata', 'data', 'final'] + assert json.loads(events[1][1])['bytes'] == len(arbitrary_bytes) + assert events[1][2] == arbitrary_bytes + assert isinstance(events[1][2], bytes) + assert json.loads(events[-1][1])['status'] == 'SUCCEEDED' + + +@pytest.mark.parametrize('request_json', ['{"password": "SECRET_DO_NOT_LOG"', b'\xff']) +def test_agent_rpc_stream_copy_rejects_malformed_json_without_echoing_input(caplog, request_json): + pool = FakePool(copy_blocks=[b'1\n']) + events = [] + + execute_agent_rpc_stream_copy(request_json, make_check(pool=pool), lambda *event: events.append(event)) + + metadata = json.loads(events[-1][1]) + assert events[-1][0] == 'error' + assert metadata['status'] == 'FAILED' + assert metadata['error']['code'] == 'invalid_request' + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert 'SECRET_DO_NOT_LOG' not in caplog.text + assert pool.requested_dbnames == [] + + +@pytest.mark.parametrize('request_json', ['[]', 'null', '"SECRET_DO_NOT_LOG"', '1']) +def test_agent_rpc_stream_copy_rejects_non_object_json_without_echoing_input(request_json): + pool = FakePool(copy_blocks=[b'1\n']) + events = [] + + execute_agent_rpc_stream_copy(request_json, make_check(pool=pool), lambda *event: events.append(event)) + + metadata = json.loads(events[-1][1]) + assert events[-1][0] == 'error' + assert metadata['status'] == 'FAILED' + assert metadata['error']['code'] == 'invalid_request' + assert 'JSON object' in metadata['error']['message'] + assert 'SECRET_DO_NOT_LOG' not in str(events) + assert pool.requested_dbnames == [] + + +def test_agent_rpc_stream_copy_closes_copy_when_callback_raises(): + pool = FakePool(copy_blocks=[b'12345678', b'abcdef']) + events = [] + + def emit(event_type, metadata_json, payload): + events.append((event_type, metadata_json, payload)) + if event_type == 'data': + raise RuntimeError('stop streaming') + + with pytest.raises(RuntimeError, match='stop streaming'): + execute_agent_rpc_stream_copy(json.dumps(valid_copy_request()), make_check(pool=pool), emit) + + assert [event[0] for event in events] == ['metadata', 'data'] + assert pool.closed_copies == 1 diff --git a/postgres/tests/test_remote_query_integration.py b/postgres/tests/test_remote_query_integration.py new file mode 100644 index 0000000000000..1c468e720b665 --- /dev/null +++ b/postgres/tests/test_remote_query_integration.py @@ -0,0 +1,127 @@ +# (C) Datadog, Inc. 2026-present +# All rights reserved +# Licensed under a 3-clause BSD style license (see LICENSE) + +import json + +import pytest + +from datadog_checks.postgres.remote_query import StaticPostgresCheckRegistry, iter_agent_rpc_stream_copy_events + + +def remote_query_copy_request( + pg_instance: dict[str, object], query: str, limits: dict[str, int], stream_format: str = 'csv' +) -> dict[str, object]: + return { + 'operation': 'copy_stream', + 'target': { + 'host': pg_instance['host'], + 'port': int(pg_instance['port']), + 'dbname': pg_instance['dbname'], + }, + 'query': query, + 'format': stream_format, + 'limits': limits, + } + + +def event_metadata(event): + return event.metadata + + +def event_payload(event): + return event.payload + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_select_1_reuses_integration_check_credentials(integration_check, pg_instance): + check = integration_check(pg_instance) + request = remote_query_copy_request( + pg_instance, + 'SELECT 1 AS value', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['operation'] == 'copy_stream' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert data == b'1\n' + assert event_metadata(events[-1])['stats']['bytesEmitted'] == len(data) + assert 'password' not in json.dumps(request).lower() + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_fixture_table_query_emits_csv_chunks(integration_check, pg_instance): + bob_instance = dict(pg_instance, username='bob', password='bob') + check = integration_check(bob_instance) + request = remote_query_copy_request( + bob_instance, + 'SELECT city, country FROM cities ORDER BY city', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + data_events = [event for event in events if event.event_type == 'data'] + data = b''.join(event_payload(event) for event in data_events) + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['format'] == 'csv' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert b'Beautiful city of lights,France\n' in data + assert b'New York,USA\n' in data + assert event_metadata(events[-1])['stats']['bytesEmitted'] == len(data) + assert all(event_metadata(event)['bytes'] <= 16 for event in data_events) + assert 'password' not in json.dumps(request).lower() + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_binary_format_preserves_non_text_bytes(integration_check, pg_instance): + check = integration_check(pg_instance) + request = remote_query_copy_request( + pg_instance, + "SELECT decode('00ff80', 'hex') AS payload", + {'chunkBytes': 1024, 'maxBytes': 4096, 'maxRowBytes': 4096, 'timeoutMs': 5000}, + stream_format='binary', + ) + + events = list(iter_agent_rpc_stream_copy_events(request, StaticPostgresCheckRegistry([check]))) + + data = b''.join(event_payload(event) for event in events if event.event_type == 'data') + assert events[0].event_type == 'metadata' + assert event_metadata(events[0])['format'] == 'binary' + assert event_metadata(events[-1])['status'] == 'SUCCEEDED' + assert b'PGCOPY\n\xff\r\n\x00' in data + assert b'\x00\xff\x80' in data + + +@pytest.mark.integration +@pytest.mark.usefixtures('dd_environment') +def test_agent_local_copy_stream_enforces_max_row_bytes_and_connection_remains_reusable(integration_check, pg_instance): + check = integration_check(pg_instance) + oversized_request = remote_query_copy_request( + pg_instance, + "SELECT repeat('x', 1048576) AS payload", + {'chunkBytes': 1024, 'maxBytes': 2 * 1048576, 'maxRowBytes': 1024, 'timeoutMs': 5000}, + ) + + events = list(iter_agent_rpc_stream_copy_events(oversized_request, StaticPostgresCheckRegistry([check]))) + + assert [event for event in events if event.event_type == 'data'] == [] + assert event_metadata(events[-1])['status'] == 'FAILED' + assert event_metadata(events[-1])['error']['code'] == 'max_row_bytes_exceeded' + + reusable_request = remote_query_copy_request( + pg_instance, + 'SELECT 1 AS value', + {'chunkBytes': 16, 'maxBytes': 1024, 'maxRowBytes': 128, 'timeoutMs': 5000}, + ) + reusable_events = list(iter_agent_rpc_stream_copy_events(reusable_request, StaticPostgresCheckRegistry([check]))) + reusable_data = b''.join(event_payload(event) for event in reusable_events if event.event_type == 'data') + assert event_metadata(reusable_events[-1])['status'] == 'SUCCEEDED' + assert reusable_data == b'1\n'