diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 318bb70..0e7b478 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,18 +31,22 @@ uv run ty check drift/ tests/ # Type check ### Unit Tests ```bash -uv run python -m unittest discover -s tests/unit -v +uv run pytest tests/unit/ -v # Run a specific test file -uv run python -m unittest tests.unit.test_json_schema_helper -v -uv run python -m unittest tests.unit.test_adapters -v +uv run pytest tests/unit/test_json_schema_helper.py -v +uv run pytest tests/unit/test_adapters.py -v + +# Run a specific test class or function +uv run pytest tests/unit/test_metrics.py::TestMetricsCollector -v +uv run pytest tests/unit/test_metrics.py::TestMetricsCollector::test_record_spans_exported -v ``` ### Integration Tests ```bash # Flask/FastAPI integration tests -timeout 30 uv run python -m unittest discover -s tests/integration -v +timeout 30 uv run pytest tests/integration/ -v ``` ### E2E Tests diff --git a/pyproject.toml b/pyproject.toml index fd0d551..e684e03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,8 @@ dev = [ "python-jsonpath>=0.10", "ruff>=0.8.0", "ty>=0.0.1a7", - "pytest>=8.0.0", + "pytest>=8.0.0,<9.0.0", + "pytest-mock>=3.15.0", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8fcb5d3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,46 @@ +"""Pytest configuration and fixtures for Drift Python SDK tests.""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Generator +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from drift.core.metrics import MetricsCollector + from drift.core.tracing.adapters import InMemorySpanAdapter + + +@pytest.fixture +def temp_dir() -> Generator[Path, None, None]: + """Create a temporary directory for test files.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def 
original_cwd() -> Generator[str, None, None]: + """Save and restore the current working directory.""" + cwd = os.getcwd() + yield cwd + os.chdir(cwd) + + +@pytest.fixture +def in_memory_adapter() -> InMemorySpanAdapter: + """Create a fresh InMemorySpanAdapter for testing.""" + from drift.core.tracing.adapters import InMemorySpanAdapter + + return InMemorySpanAdapter() + + +@pytest.fixture +def metrics_collector() -> MetricsCollector: + """Create a fresh MetricsCollector for testing.""" + from drift.core.metrics import MetricsCollector + + return MetricsCollector() diff --git a/tests/integration/test_fastapi_basic.py b/tests/integration/test_fastapi_basic.py index d42228d..20a170c 100644 --- a/tests/integration/test_fastapi_basic.py +++ b/tests/integration/test_fastapi_basic.py @@ -7,289 +7,297 @@ import os import sys import time -import unittest from pathlib import Path +import pytest +import requests + # Set up environment before importing drift os.environ["TUSK_DRIFT_MODE"] = "RECORD" sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -import requests - from drift import TuskDrift from drift.core.types import SpanKind, StatusCode -class TestFastAPIBasicSpanCapture(unittest.TestCase): - """Test basic FastAPI request/response span capture.""" +@pytest.fixture(scope="module") +def fastapi_app_and_adapter(): + """Set up the SDK, FastAPI app, and adapter once for all tests in module.""" + from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + + sdk = TuskDrift.get_instance() + if not TuskDrift._initialized: + sdk = TuskDrift.initialize() + + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + # FastAPI is auto-instrumented by SDK initialization + from fastapi import FastAPI + from pydantic import BaseModel + + app = FastAPI() + + class EchoRequest(BaseModel): + message: str | None = None + count: int | None = None - @classmethod - def setUpClass(cls): - """Set up the SDK and FastAPI app once for all 
tests.""" - from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + @app.get("/health") + def health(): + return {"status": "healthy", "timestamp": time.time()} - cls.sdk = TuskDrift.get_instance() - if not TuskDrift._initialized: - cls.sdk = TuskDrift.initialize() + @app.get("/greet/{name}") + def greet(name: str, greeting: str = "Hello"): + return {"message": f"{greeting}, {name}!", "name": name} - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) + @app.post("/echo") + def echo(data: EchoRequest): + return {"echoed": data.model_dump(), "received_at": time.time()} - # FastAPI is auto-instrumented by SDK initialization - # Import FastAPI after SDK is set up - from fastapi import FastAPI - from pydantic import BaseModel + @app.get("/error") + def error(): + return {"error": "Something went wrong"}, 500 - cls.app = FastAPI() + @app.get("/headers") + def headers(): + return {"info": "Headers endpoint"} - class EchoRequest(BaseModel): - message: str | None = None - count: int | None = None + sdk.mark_app_as_ready() - @cls.app.get("/health") - def health(): - return {"status": "healthy", "timestamp": time.time()} + # Start FastAPI server in background + from tests.utils.fastapi_test_server import FastAPITestServer - @cls.app.get("/greet/{name}") - def greet(name: str, greeting: str = "Hello"): - return {"message": f"{greeting}, {name}!", "name": name} + server = FastAPITestServer(app=app) + server.start() - @cls.app.post("/echo") - def echo(data: EchoRequest): - return {"echoed": data.model_dump(), "received_at": time.time()} + yield {"sdk": sdk, "adapter": adapter, "app": app, "server": server, "base_url": server.base_url} - @cls.app.get("/error") - def error(): - return {"error": "Something went wrong"}, 500 + server.stop() - @cls.app.get("/headers") - def headers(): - return {"info": "Headers endpoint"} - cls.sdk.mark_app_as_ready() +@pytest.fixture +def adapter(fastapi_app_and_adapter): + """Get adapter 
and clear it before each test.""" + adapter = fastapi_app_and_adapter["adapter"] + adapter.clear() + return adapter - # Start FastAPI server in background - from tests.utils.fastapi_test_server import FastAPITestServer - cls.server = FastAPITestServer(app=cls.app) - cls.server.start() - cls.base_url = cls.server.base_url +@pytest.fixture +def base_url(fastapi_app_and_adapter): + """Get base URL for requests.""" + return fastapi_app_and_adapter["base_url"] - @classmethod - def tearDownClass(cls): - """Clean up after all tests.""" - cls.server.stop() - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() +def wait_for_spans(timeout: float = 0.5): + """Wait for spans to be processed.""" + time.sleep(timeout) - def wait_for_spans(self, timeout: float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) - def test_captures_get_request_span(self): +class TestFastAPIBasicSpanCapture: + """Test basic FastAPI request/response span capture.""" + + def test_captures_get_request_span(self, adapter, base_url): """Test that GET requests create spans.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 - # Find the server span (inbound request) server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] - self.assertIn("/health", span.name) - self.assertEqual(span.status.code, StatusCode.OK) + assert "/health" in span.name + assert span.status.code == StatusCode.OK - def test_captures_get_with_path_params(self): + def test_captures_get_with_path_params(self, adapter, base_url): """Test that path parameters are captured.""" 
- response = requests.get(f"{self.base_url}/greet/World") + response = requests.get(f"{base_url}/greet/World") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] input_val = span.input_value - self.assertIsInstance(input_val, dict) + assert isinstance(input_val, dict) - def test_captures_get_with_query_params(self): + def test_captures_get_with_query_params(self, adapter, base_url): """Test that query parameters are captured.""" - response = requests.get(f"{self.base_url}/greet/World?greeting=Hi") + response = requests.get(f"{base_url}/greet/World?greeting=Hi") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 - def test_captures_post_request_span(self): + def test_captures_post_request_span(self, adapter, base_url): """Test that POST requests create spans.""" payload = {"message": "Hello", "count": 42} - response = requests.post(f"{self.base_url}/echo", json=payload) + response = requests.post(f"{base_url}/echo", json=payload) - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] input_val = span.input_value - self.assertIsInstance(input_val, dict) - 
self.assertEqual(input_val.get("method"), "POST") + assert isinstance(input_val, dict) + assert input_val.get("method") == "POST" - def test_span_has_trace_id(self): + def test_span_has_trace_id(self, adapter, base_url): """Test that spans have valid trace IDs.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - self.assertIsNotNone(span.trace_id) - self.assertEqual(len(span.trace_id), 32) + assert span.trace_id is not None + assert len(span.trace_id) == 32 - def test_span_has_span_id(self): + def test_span_has_span_id(self, adapter, base_url): """Test that spans have valid span IDs.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - self.assertIsNotNone(span.span_id) - self.assertEqual(len(span.span_id), 16) + assert span.span_id is not None + assert len(span.span_id) == 16 - def test_span_has_timing_info(self): + def test_span_has_timing_info(self, adapter, base_url): """Test that spans have timing information.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - 
self.assertIsNotNone(span.timestamp) - self.assertIsNotNone(span.duration) + assert span.timestamp is not None + assert span.duration is not None total_nanos = span.duration.seconds * 1_000_000_000 + span.duration.nanos - self.assertGreater(total_nanos, 0) + assert total_nanos > 0 - def test_captures_request_body(self): + def test_captures_request_body(self, adapter, base_url): """Test that request body is captured for POST requests.""" payload = {"test_key": "test_value"} - response = requests.post(f"{self.base_url}/echo", json=payload) + response = requests.post(f"{base_url}/echo", json=payload) - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] input_val = span.input_value - self.assertIsInstance(input_val, dict) - # Body should be captured (may be base64 encoded) - self.assertIn("body", input_val) + assert isinstance(input_val, dict) + assert "body" in input_val -class TestFastAPIMultipleRequests(unittest.TestCase): - """Test multiple FastAPI requests create separate spans.""" +@pytest.fixture(scope="module") +def fastapi_multi_app_and_adapter(): + """Set up FastAPI app for multiple request tests.""" + from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + + sdk = TuskDrift.get_instance() + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + from fastapi import FastAPI - @classmethod - def setUpClass(cls): - """Set up the SDK and FastAPI app once for all tests.""" - from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + app = FastAPI() - cls.sdk = TuskDrift.get_instance() - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) + 
@app.get("/endpoint1") + def endpoint1(): + return {"endpoint": 1} - from fastapi import FastAPI + @app.get("/endpoint2") + def endpoint2(): + return {"endpoint": 2} - cls.app = FastAPI() + from tests.utils.fastapi_test_server import FastAPITestServer - @cls.app.get("/endpoint1") - def endpoint1(): - return {"endpoint": 1} + server = FastAPITestServer(app=app) + server.start() - @cls.app.get("/endpoint2") - def endpoint2(): - return {"endpoint": 2} + yield {"sdk": sdk, "adapter": adapter, "app": app, "server": server, "base_url": server.base_url} - from tests.utils.fastapi_test_server import FastAPITestServer + server.stop() - cls.server = FastAPITestServer(app=cls.app) - cls.server.start() - cls.base_url = cls.server.base_url - @classmethod - def tearDownClass(cls): - """Clean up after all tests.""" - cls.server.stop() +@pytest.fixture +def multi_adapter(fastapi_multi_app_and_adapter): + """Get adapter and clear it before each test.""" + adapter = fastapi_multi_app_and_adapter["adapter"] + adapter.clear() + return adapter - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() - def wait_for_spans(self, timeout: float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) +@pytest.fixture +def multi_base_url(fastapi_multi_app_and_adapter): + """Get base URL for requests.""" + return fastapi_multi_app_and_adapter["base_url"] - def test_multiple_requests_create_separate_spans(self): + +class TestFastAPIMultipleRequests: + """Test multiple FastAPI requests create separate spans.""" + + def test_multiple_requests_create_separate_spans(self, multi_adapter, multi_base_url): """Test that multiple requests create separate spans.""" - response1 = requests.get(f"{self.base_url}/endpoint1") - response2 = requests.get(f"{self.base_url}/endpoint2") + response1 = requests.get(f"{multi_base_url}/endpoint1") + response2 = requests.get(f"{multi_base_url}/endpoint2") - self.assertEqual(response1.status_code, 200) - 
self.assertEqual(response2.status_code, 200) - self.wait_for_spans() + assert response1.status_code == 200 + assert response2.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = multi_adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 2) + assert len(server_spans) >= 2 span_ids = [s.span_id for s in server_spans] - self.assertEqual(len(span_ids), len(set(span_ids))) + assert len(span_ids) == len(set(span_ids)) - def test_multiple_requests_have_different_trace_ids(self): + def test_multiple_requests_have_different_trace_ids(self, multi_adapter, multi_base_url): """Test that independent requests have different trace IDs.""" - response1 = requests.get(f"{self.base_url}/endpoint1") - response2 = requests.get(f"{self.base_url}/endpoint2") + response1 = requests.get(f"{multi_base_url}/endpoint1") + response2 = requests.get(f"{multi_base_url}/endpoint2") - self.assertEqual(response1.status_code, 200) - self.assertEqual(response2.status_code, 200) - self.wait_for_spans() + assert response1.status_code == 200 + assert response2.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = multi_adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 2) + assert len(server_spans) >= 2 trace_ids = [s.trace_id for s in server_spans] - self.assertEqual(len(trace_ids), len(set(trace_ids))) - - -if __name__ == "__main__": - unittest.main() + assert len(trace_ids) == len(set(trace_ids)) diff --git a/tests/integration/test_fastapi_replay.py b/tests/integration/test_fastapi_replay.py index c3265e5..a8aa3ec 100644 --- a/tests/integration/test_fastapi_replay.py +++ b/tests/integration/test_fastapi_replay.py @@ -4,9 +4,11 @@ import socket import sys import time -import unittest from pathlib import Path +import pytest +import requests + # Set replay mode before importing 
drift os.environ["TUSK_DRIFT_MODE"] = "REPLAY" @@ -25,136 +27,136 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -import requests - from drift import TuskDrift from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter from drift.core.types import SpanKind -class TestFastAPIReplayMode(unittest.TestCase): - """Test FastAPI instrumentation in REPLAY mode.""" +@pytest.fixture(scope="module") +def fastapi_replay_app(): + """Set up SDK and FastAPI app once for all tests.""" + sdk = TuskDrift.initialize() + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + from fastapi import FastAPI + from pydantic import BaseModel + + app = FastAPI() - @classmethod - def setUpClass(cls): - """Set up SDK and FastAPI app once for all tests.""" - cls.sdk = TuskDrift.initialize() - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) + @app.get("/health") + def health(): + return {"status": "healthy"} - # Create FastAPI app - from fastapi import FastAPI - from pydantic import BaseModel + @app.get("/user/{name}") + def get_user(name: str): + return {"user": name, "id": 123} - cls.app = FastAPI() + class EchoRequest(BaseModel): + message: str - @cls.app.get("/health") - def health(): - return {"status": "healthy"} + @app.post("/echo") + def echo(data: EchoRequest): + return {"echoed": data.model_dump()} - @cls.app.get("/user/{name}") - def get_user(name: str): - return {"user": name, "id": 123} + sdk.mark_app_as_ready() - class EchoRequest(BaseModel): - message: str + # Start FastAPI server in background + from tests.utils.fastapi_test_server import FastAPITestServer - @cls.app.post("/echo") - def echo(data: EchoRequest): - return {"echoed": data.model_dump()} + server = FastAPITestServer(app=app) + server.start() - cls.sdk.mark_app_as_ready() + yield {"sdk": sdk, "adapter": adapter, "app": app, "server": server, "base_url": server.base_url} - # Start FastAPI server in background - from 
tests.utils.fastapi_test_server import FastAPITestServer + server.stop() + try: + test_socket.close() + if Path(socket_path).exists(): + os.unlink(socket_path) + except Exception: + pass - cls.server = FastAPITestServer(app=cls.app) - cls.server.start() - cls.base_url = cls.server.base_url - @classmethod - def tearDownClass(cls): - """Clean up after all tests.""" - cls.server.stop() - try: - test_socket.close() - if Path(socket_path).exists(): - os.unlink(socket_path) - except Exception: - pass +@pytest.fixture +def adapter(fastapi_replay_app): + """Get adapter and clear it before each test.""" + adapter = fastapi_replay_app["adapter"] + adapter.clear() + return adapter - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() - def wait_for_spans(self, timeout: float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) +@pytest.fixture +def base_url(fastapi_replay_app): + """Get base URL for requests.""" + return fastapi_replay_app["base_url"] - def test_request_without_trace_id_header(self): + +def wait_for_spans(timeout: float = 0.5): + """Wait for spans to be processed.""" + time.sleep(timeout) + + +class TestFastAPIReplayMode: + """Test FastAPI instrumentation in REPLAY mode.""" + + def test_request_without_trace_id_header(self, adapter, base_url): """Test that requests without trace ID don't create SERVER spans.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - # No SERVER span should be created in replay mode + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertEqual(len(server_spans), 0) + assert len(server_spans) == 0 - def test_request_with_trace_id_header(self): + def test_request_with_trace_id_header(self, adapter, base_url): """Test that 
requests with trace ID create SERVER spans.""" - response = requests.get(f"{self.base_url}/user/alice", headers={"x-td-trace-id": "test-trace-123"}) + response = requests.get(f"{base_url}/user/alice", headers={"x-td-trace-id": "test-trace-123"}) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = response.json() - self.assertEqual(data["user"], "alice") - self.assertEqual(data["id"], 123) + assert data["user"] == "alice" + assert data["id"] == 123 - self.wait_for_spans() + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 - # Find SERVER span server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] - # Verify trace ID matches the one from header - self.assertEqual(span.trace_id, "test-trace-123") - self.assertIn("/user", span.name) + assert span.trace_id == "test-trace-123" + assert "/user" in span.name - def test_post_request_with_trace_id(self): + def test_post_request_with_trace_id(self, adapter, base_url): """Test that POST requests work in replay mode.""" response = requests.post( - f"{self.base_url}/echo", + f"{base_url}/echo", json={"message": "test"}, headers={"x-td-trace-id": "post-trace-456"}, ) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = response.json() - self.assertEqual(data["echoed"]["message"], "test") + assert data["echoed"]["message"] == "test" - self.wait_for_spans() + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] if len(server_spans) > 0: span = server_spans[0] - self.assertEqual(span.trace_id, "post-trace-456") + assert span.trace_id == "post-trace-456" - def test_case_insensitive_headers(self): + def test_case_insensitive_headers(self, adapter, 
base_url): """Test that trace ID header is case-insensitive.""" - response = requests.get(f"{self.base_url}/health", headers={"x-td-trace-id": "lowercase-trace"}) - self.assertEqual(response.status_code, 200) - - response = requests.get(f"{self.base_url}/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) - self.assertEqual(response.status_code, 200) - + response = requests.get(f"{base_url}/health", headers={"x-td-trace-id": "lowercase-trace"}) + assert response.status_code == 200 -if __name__ == "__main__": - unittest.main() + response = requests.get(f"{base_url}/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) + assert response.status_code == 200 diff --git a/tests/integration/test_flask_basic.py b/tests/integration/test_flask_basic.py index 548a46a..bf9569a 100644 --- a/tests/integration/test_flask_basic.py +++ b/tests/integration/test_flask_basic.py @@ -7,317 +7,315 @@ import os import sys import time -import unittest from pathlib import Path +import pytest +import requests + # Set up environment before importing drift os.environ["TUSK_DRIFT_MODE"] = "RECORD" sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -import requests - from drift import TuskDrift from drift.core.types import SpanKind, StatusCode -class TestFlaskBasicSpanCapture(unittest.TestCase): - """Test basic Flask request/response span capture.""" +@pytest.fixture(scope="module") +def flask_app_and_adapter(): + """Set up the SDK, Flask app, and adapter once for all tests in module.""" + from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + + sdk = TuskDrift.initialize() + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + # Flask is auto-instrumented by SDK initialization + from flask import Flask, jsonify, request + + app = Flask(__name__) + + @app.route("/health") + def health(): + return jsonify({"status": "healthy", "timestamp": time.time()}) + + @app.route("/greet/") + def greet(name: str): + greeting = 
request.args.get("greeting", "Hello") + return jsonify({"message": f"{greeting}, {name}!", "name": name}) - @classmethod - def setUpClass(cls): - """Set up the SDK and Flask app once for all tests.""" - from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + @app.route("/echo", methods=["POST"]) + def echo(): + data = request.get_json() + return jsonify({"echoed": data, "received_at": time.time()}) - cls.sdk = TuskDrift.initialize() - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) + @app.route("/error") + def error(): + return jsonify({"error": "Something went wrong"}), 500 - # Flask is auto-instrumented by SDK initialization - # Import Flask after SDK is set up - from flask import Flask, jsonify, request + @app.route("/headers") + def headers(): + return jsonify( + { + "user_agent": request.headers.get("User-Agent"), + "custom_header": request.headers.get("X-Custom-Header"), + } + ) - cls.app = Flask(__name__) + sdk.mark_app_as_ready() - @cls.app.route("/health") - def health(): - return jsonify({"status": "healthy", "timestamp": time.time()}) + # Start Flask server in background + from tests.utils.flask_test_server import FlaskTestServer - @cls.app.route("/greet/") - def greet(name: str): - greeting = request.args.get("greeting", "Hello") - return jsonify({"message": f"{greeting}, {name}!", "name": name}) + server = FlaskTestServer(app=app) + server.start() - @cls.app.route("/echo", methods=["POST"]) - def echo(): - data = request.get_json() - return jsonify({"echoed": data, "received_at": time.time()}) + yield {"sdk": sdk, "adapter": adapter, "app": app, "server": server, "base_url": server.base_url} - @cls.app.route("/error") - def error(): - return jsonify({"error": "Something went wrong"}), 500 + server.stop() - @cls.app.route("/headers") - def headers(): - # Echo back some headers for testing - return jsonify( - { - "user_agent": request.headers.get("User-Agent"), - "custom_header": 
request.headers.get("X-Custom-Header"), - } - ) - cls.sdk.mark_app_as_ready() +@pytest.fixture +def adapter(flask_app_and_adapter): + """Get adapter and clear it before each test.""" + adapter = flask_app_and_adapter["adapter"] + adapter.clear() + return adapter - # Start Flask server in background - from tests.utils.flask_test_server import FlaskTestServer - cls.server = FlaskTestServer(app=cls.app) - cls.server.start() - cls.base_url = cls.server.base_url +@pytest.fixture +def base_url(flask_app_and_adapter): + """Get base URL for requests.""" + return flask_app_and_adapter["base_url"] - @classmethod - def tearDownClass(cls): - """Clean up after all tests.""" - cls.server.stop() - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() +def wait_for_spans(timeout: float = 0.5): + """Wait for spans to be processed.""" + time.sleep(timeout) - def wait_for_spans(self, timeout: float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) - def test_captures_get_request_span(self): +class TestFlaskBasicSpanCapture: + """Test basic Flask request/response span capture.""" + + def test_captures_get_request_span(self, adapter, base_url): """Test that GET requests create spans.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 - # Find the server span (inbound request) server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] - self.assertIn("/health", span.name) - self.assertEqual(span.status.code, StatusCode.OK) + assert "/health" in span.name + assert span.status.code == StatusCode.OK - def 
test_captures_get_with_path_params(self): + def test_captures_get_with_path_params(self, adapter, base_url): """Test that path parameters are captured.""" - response = requests.get(f"{self.base_url}/greet/World") + response = requests.get(f"{base_url}/greet/World") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] - # Check input value contains the path input_val = span.input_value - self.assertIsInstance(input_val, dict) - # Path should contain the actual value + assert isinstance(input_val, dict) target = input_val.get("target") or input_val.get("path") or input_val.get("url") - self.assertIn("World", str(target)) + assert "World" in str(target) - def test_captures_get_with_query_params(self): + def test_captures_get_with_query_params(self, adapter, base_url): """Test that query parameters are captured.""" - response = requests.get(f"{self.base_url}/greet/World?greeting=Hi") + response = requests.get(f"{base_url}/greet/World?greeting=Hi") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 - def test_captures_post_request_span(self): + def test_captures_post_request_span(self, adapter, base_url): """Test that POST requests create spans.""" payload = {"message": "Hello", "count": 42} - response = requests.post(f"{self.base_url}/echo", json=payload) + response = requests.post(f"{base_url}/echo", json=payload) - self.assertEqual(response.status_code, 200) - 
self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] input_val = span.input_value - self.assertIsInstance(input_val, dict) - self.assertEqual(input_val.get("method"), "POST") + assert isinstance(input_val, dict) + assert input_val.get("method") == "POST" - def test_captures_error_response(self): + def test_captures_error_response(self, adapter, base_url): """Test that error responses are captured correctly.""" - response = requests.get(f"{self.base_url}/error") + response = requests.get(f"{base_url}/error") - self.assertEqual(response.status_code, 500) - self.wait_for_spans() + assert response.status_code == 500 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] output_val = span.output_value - self.assertIsInstance(output_val, dict) - # Status code should be 500 + assert isinstance(output_val, dict) status_code = output_val.get("statusCode") or output_val.get("status_code") - self.assertEqual(status_code, 500) + assert status_code == 500 - def test_captures_request_headers(self): + def test_captures_request_headers(self, adapter, base_url): """Test that request headers are captured.""" headers = {"X-Custom-Header": "custom-value"} - response = requests.get(f"{self.base_url}/headers", headers=headers) + response = requests.get(f"{base_url}/headers", headers=headers) - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind 
== SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] input_val = span.input_value - self.assertIsInstance(input_val, dict) - # Headers should be captured - self.assertIn("headers", input_val) + assert isinstance(input_val, dict) + assert "headers" in input_val - def test_span_has_trace_id(self): + def test_span_has_trace_id(self, adapter, base_url): """Test that spans have valid trace IDs.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - self.assertIsNotNone(span.trace_id) - self.assertEqual(len(span.trace_id), 32) # 32 hex chars + assert span.trace_id is not None + assert len(span.trace_id) == 32 - def test_span_has_span_id(self): + def test_span_has_span_id(self, adapter, base_url): """Test that spans have valid span IDs.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - self.assertIsNotNone(span.span_id) - self.assertEqual(len(span.span_id), 16) # 16 hex chars + assert span.span_id is not None + assert len(span.span_id) == 16 - def test_span_has_timing_info(self): + def test_span_has_timing_info(self, adapter, base_url): """Test that spans have timing information.""" - response = requests.get(f"{self.base_url}/health") + response = requests.get(f"{base_url}/health") - self.assertEqual(response.status_code, 200) - 
self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 span = spans[0] - self.assertIsNotNone(span.timestamp) - self.assertIsNotNone(span.duration) - # Duration should be positive + assert span.timestamp is not None + assert span.duration is not None total_nanos = span.duration.seconds * 1_000_000_000 + span.duration.nanos - self.assertGreater(total_nanos, 0) + assert total_nanos > 0 -class TestFlaskMultipleRequests(unittest.TestCase): - """Test multiple Flask requests create separate spans.""" +@pytest.fixture(scope="module") +def flask_multi_app_and_adapter(): + """Set up Flask app for multiple request tests.""" + from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + + sdk = TuskDrift.get_instance() + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + from flask import Flask, jsonify - @classmethod - def setUpClass(cls): - """Set up the SDK and Flask app once for all tests.""" - from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter + app = Flask(__name__) - cls.sdk = TuskDrift.get_instance() - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) + @app.route("/endpoint1") + def endpoint1(): + return jsonify({"endpoint": 1}) - # Import Flask after instrumentation is set up - from flask import Flask, jsonify + @app.route("/endpoint2") + def endpoint2(): + return jsonify({"endpoint": 2}) - cls.app = Flask(__name__) + from tests.utils.flask_test_server import FlaskTestServer - @cls.app.route("/endpoint1") - def endpoint1(): - return jsonify({"endpoint": 1}) + server = FlaskTestServer(app=app) + server.start() - @cls.app.route("/endpoint2") - def endpoint2(): - return jsonify({"endpoint": 2}) + yield {"sdk": sdk, "adapter": adapter, "app": app, "server": server, "base_url": server.base_url} - 
# Start Flask server in background - from tests.utils.flask_test_server import FlaskTestServer + server.stop() - cls.server = FlaskTestServer(app=cls.app) - cls.server.start() - cls.base_url = cls.server.base_url - @classmethod - def tearDownClass(cls): - """Clean up after all tests.""" - cls.server.stop() +@pytest.fixture +def multi_adapter(flask_multi_app_and_adapter): + """Get adapter and clear it before each test.""" + adapter = flask_multi_app_and_adapter["adapter"] + adapter.clear() + return adapter - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() - def wait_for_spans(self, timeout: float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) +@pytest.fixture +def multi_base_url(flask_multi_app_and_adapter): + """Get base URL for requests.""" + return flask_multi_app_and_adapter["base_url"] - def test_multiple_requests_create_separate_spans(self): + +class TestFlaskMultipleRequests: + """Test multiple Flask requests create separate spans.""" + + def test_multiple_requests_create_separate_spans(self, multi_adapter, multi_base_url): """Test that multiple requests create separate spans.""" - response1 = requests.get(f"{self.base_url}/endpoint1") - response2 = requests.get(f"{self.base_url}/endpoint2") + response1 = requests.get(f"{multi_base_url}/endpoint1") + response2 = requests.get(f"{multi_base_url}/endpoint2") - self.assertEqual(response1.status_code, 200) - self.assertEqual(response2.status_code, 200) - self.wait_for_spans() + assert response1.status_code == 200 + assert response2.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = multi_adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - # Should have at least 2 server spans - self.assertGreaterEqual(len(server_spans), 2) + assert len(server_spans) >= 2 - # Spans should have different span IDs span_ids = [s.span_id for s in server_spans] - self.assertEqual(len(span_ids), 
len(set(span_ids))) + assert len(span_ids) == len(set(span_ids)) - def test_multiple_requests_have_different_trace_ids(self): + def test_multiple_requests_have_different_trace_ids(self, multi_adapter, multi_base_url): """Test that independent requests have different trace IDs.""" - response1 = requests.get(f"{self.base_url}/endpoint1") - response2 = requests.get(f"{self.base_url}/endpoint2") + response1 = requests.get(f"{multi_base_url}/endpoint1") + response2 = requests.get(f"{multi_base_url}/endpoint2") - self.assertEqual(response1.status_code, 200) - self.assertEqual(response2.status_code, 200) - self.wait_for_spans() + assert response1.status_code == 200 + assert response2.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = multi_adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 2) + assert len(server_spans) >= 2 - # Separate requests should have different trace IDs trace_ids = [s.trace_id for s in server_spans] - self.assertEqual(len(trace_ids), len(set(trace_ids))) - - -if __name__ == "__main__": - unittest.main() + assert len(trace_ids) == len(set(trace_ids)) diff --git a/tests/integration/test_flask_replay.py b/tests/integration/test_flask_replay.py index 8de3503..5374e72 100644 --- a/tests/integration/test_flask_replay.py +++ b/tests/integration/test_flask_replay.py @@ -3,9 +3,10 @@ import os import socket import time -import unittest from pathlib import Path +import pytest + # Set replay mode before importing drift os.environ["TUSK_DRIFT_MODE"] = "REPLAY" @@ -29,116 +30,118 @@ from drift.core.types import SpanKind -class TestFlaskReplayMode(unittest.TestCase): +@pytest.fixture(scope="module") +def flask_replay_app(): + """Set up SDK and Flask app once for all tests.""" + sdk = TuskDrift.initialize() + adapter = InMemorySpanAdapter() + register_in_memory_adapter(adapter) + + app = Flask(__name__) + + @app.route("/health") + def health(): + 
return jsonify({"status": "healthy"})
+
+    @app.route("/user/<name>")
+    def get_user(name: str):
+        return jsonify({"user": name, "id": 123})
+
+    @app.route("/echo", methods=["POST"])
+    def echo():
+        from flask import request
+
+        data = request.get_json()
+        return jsonify({"echoed": data})
+
+    sdk.mark_app_as_ready()
+    client = app.test_client()
+
+    yield {"sdk": sdk, "adapter": adapter, "app": app, "client": client}
+
+    # Cleanup socket
+    try:
+        test_socket.close()
+        if Path(socket_path).exists():
+            os.unlink(socket_path)
+    except Exception:
+        pass
+
+
+@pytest.fixture
+def adapter(flask_replay_app):
+    """Get adapter and clear it before each test."""
+    adapter = flask_replay_app["adapter"]
+    adapter.clear()
+    return adapter
+
+
+@pytest.fixture
+def client(flask_replay_app):
+    """Get test client."""
+    return flask_replay_app["client"]
+
+
+def wait_for_spans(timeout: float = 0.5):
+    """Wait for spans to be processed."""
+    time.sleep(timeout)
+
+
+class TestFlaskReplayMode:
     """Test Flask instrumentation in REPLAY mode."""
 
-    @classmethod
-    def setUpClass(cls):
-        """Set up SDK and Flask app once for all tests."""
-        cls.sdk = TuskDrift.initialize()
-        cls.adapter = InMemorySpanAdapter()
-        register_in_memory_adapter(cls.adapter)
-
-        # Create Flask app
-        cls.app = Flask(__name__)
-
-        @cls.app.route("/health")
-        def health():
-            return jsonify({"status": "healthy"})
-
-        @cls.app.route("/user/<name>")
-        def get_user(name: str):
-            return jsonify({"user": name, "id": 123})
-
-        @cls.app.route("/echo", methods=["POST"])
-        def echo():
-            from flask import request
-
-            data = request.get_json()
-            return jsonify({"echoed": data})
-
-        cls.sdk.mark_app_as_ready()
-        cls.client = cls.app.test_client()
-
-    @classmethod
-    def tearDownClass(cls):
-        """Clean up socket."""
-        try:
-            test_socket.close()
-            if Path(socket_path).exists():
-                os.unlink(socket_path)
-        except Exception:
-            pass
-
-    def setUp(self):
-        """Clear spans before each test."""
-        self.adapter.clear()
-
-    def wait_for_spans(self, timeout:
float = 0.5): - """Wait for spans to be processed.""" - time.sleep(timeout) - - def test_request_without_trace_id_header(self): + def test_request_without_trace_id_header(self, adapter, client): """Test that requests without trace ID don't create spans.""" - response = self.client.get("/health") + response = client.get("/health") - self.assertEqual(response.status_code, 200) - self.wait_for_spans() + assert response.status_code == 200 + wait_for_spans() - spans = self.adapter.get_all_spans() - # No span should be created without trace ID in replay mode - self.assertEqual(len(spans), 0) + spans = adapter.get_all_spans() + assert len(spans) == 0 - def test_request_with_trace_id_header(self): + def test_request_with_trace_id_header(self, adapter, client): """Test that requests with trace ID create SERVER spans.""" - response = self.client.get("/user/alice", headers={"x-td-trace-id": "test-trace-123"}) + response = client.get("/user/alice", headers={"x-td-trace-id": "test-trace-123"}) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = response.get_json() - self.assertEqual(data["user"], "alice") - self.assertEqual(data["id"], 123) + assert data["user"] == "alice" + assert data["id"] == 123 - self.wait_for_spans() + wait_for_spans() - spans = self.adapter.get_all_spans() - self.assertGreaterEqual(len(spans), 1) + spans = adapter.get_all_spans() + assert len(spans) >= 1 - # Find SERVER span server_spans = [s for s in spans if s.kind == SpanKind.SERVER] - self.assertGreaterEqual(len(server_spans), 1) + assert len(server_spans) >= 1 span = server_spans[0] - # Verify trace ID matches the one from header - self.assertEqual(span.trace_id, "test-trace-123") - self.assertIn("/user", span.name) + assert span.trace_id == "test-trace-123" + assert "/user" in span.name - def test_post_request_with_trace_id(self): + def test_post_request_with_trace_id(self, adapter, client): """Test that POST requests work in replay mode.""" - response = 
self.client.post("/echo", json={"message": "test"}, headers={"x-td-trace-id": "post-trace-456"}) + response = client.post("/echo", json={"message": "test"}, headers={"x-td-trace-id": "post-trace-456"}) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = response.get_json() - self.assertEqual(data["echoed"]["message"], "test") + assert data["echoed"]["message"] == "test" - self.wait_for_spans() + wait_for_spans() - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() server_spans = [s for s in spans if s.kind == SpanKind.SERVER] if len(server_spans) > 0: span = server_spans[0] - self.assertEqual(span.trace_id, "post-trace-456") + assert span.trace_id == "post-trace-456" - def test_case_insensitive_headers(self): + def test_case_insensitive_headers(self, adapter, client): """Test that trace ID header is case-insensitive.""" - # Try lowercase - response = self.client.get("/health", headers={"x-td-trace-id": "lowercase-trace"}) - self.assertEqual(response.status_code, 200) - - # Try uppercase - response = self.client.get("/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) - self.assertEqual(response.status_code, 200) - + response = client.get("/health", headers={"x-td-trace-id": "lowercase-trace"}) + assert response.status_code == 200 -if __name__ == "__main__": - unittest.main() + response = client.get("/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) + assert response.status_code == 200 diff --git a/tests/integration/test_flask_transforms.py b/tests/integration/test_flask_transforms.py index 8cfd241..7e9e523 100644 --- a/tests/integration/test_flask_transforms.py +++ b/tests/integration/test_flask_transforms.py @@ -10,9 +10,10 @@ import os import sys -import unittest from pathlib import Path +import pytest + # Set up environment before importing drift os.environ["TUSK_DRIFT_MODE"] = "RECORD" @@ -22,52 +23,35 @@ TRANSFORMS_IMPLEMENTED = False # Set to True once transforms are implemented 
-@unittest.skipUnless(TRANSFORMS_IMPLEMENTED, "Transforms not yet implemented in Python SDK") -class TestFlaskMaskTransform(unittest.TestCase): +@pytest.mark.skipif(not TRANSFORMS_IMPLEMENTED, reason="Transforms not yet implemented in Python SDK") +class TestFlaskMaskTransform: """Test mask transform for Flask requests. Mask transform replaces sensitive values with mask characters (e.g., '*') while preserving the value length. """ - @classmethod - def setUpClass(cls): + @pytest.fixture(scope="class", autouse=True) + def setup_transforms(self): """Set up the SDK with mask transforms configured.""" # TODO: Configure transforms when implemented - # transforms = { - # "http": [ - # { - # "matcher": { - # "direction": "outbound", - # "headerName": "X-API-Key", - # }, - # "action": {"type": "mask", "maskChar": "*"}, - # }, - # ], - # } pass def test_masks_api_key_in_outbound_request_headers(self): """API key header should be masked with asterisks.""" - # TODO: Implement when transforms are available - # 1. Start Flask server with endpoint that makes outbound request with X-API-Key header - # 2. Call the endpoint - # 3. Wait for spans - # 4. Find outbound span - # 5. 
Assert X-API-Key header value is all asterisks - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_masks_authorization_header(self): """Authorization header should be masked.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_mask_preserves_value_length(self): """Masked value should have the same length as original.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") -@unittest.skipUnless(TRANSFORMS_IMPLEMENTED, "Transforms not yet implemented in Python SDK") -class TestFlaskRedactTransform(unittest.TestCase): +@pytest.mark.skipif(not TRANSFORMS_IMPLEMENTED, reason="Transforms not yet implemented in Python SDK") +class TestFlaskRedactTransform: """Test redact transform for Flask requests. Redact transform removes sensitive fields entirely from the span data. @@ -75,23 +59,19 @@ class TestFlaskRedactTransform(unittest.TestCase): def test_redacts_password_field_from_request_body(self): """Password field should be removed from request body.""" - # TODO: Implement when transforms are available - # 1. Configure redact transform for "password" field - # 2. Make POST request with password in body - # 3. 
Assert password field is not present in span input_value - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_redacts_nested_field(self): """Nested fields should be redactable.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_redacts_multiple_fields(self): """Multiple fields should be redactable in same request.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") -@unittest.skipUnless(TRANSFORMS_IMPLEMENTED, "Transforms not yet implemented in Python SDK") -class TestFlaskReplaceTransform(unittest.TestCase): +@pytest.mark.skipif(not TRANSFORMS_IMPLEMENTED, reason="Transforms not yet implemented in Python SDK") +class TestFlaskReplaceTransform: """Test replace transform for Flask requests. Replace transform substitutes sensitive values with placeholder values. @@ -99,15 +79,15 @@ class TestFlaskReplaceTransform(unittest.TestCase): def test_replaces_email_with_placeholder(self): """Email field should be replaced with placeholder.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_replaces_ssn_with_placeholder(self): """SSN field should be replaced with placeholder.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") -@unittest.skipUnless(TRANSFORMS_IMPLEMENTED, "Transforms not yet implemented in Python SDK") -class TestFlaskDropTransform(unittest.TestCase): +@pytest.mark.skipif(not TRANSFORMS_IMPLEMENTED, reason="Transforms not yet implemented in Python SDK") +class TestFlaskDropTransform: """Test drop transform for Flask requests. Drop transform removes entire spans matching certain criteria. 
@@ -115,50 +95,41 @@ class TestFlaskDropTransform(unittest.TestCase): def test_drops_admin_endpoint_spans(self): """Spans for admin endpoints should be dropped entirely.""" - # TODO: Implement when transforms are available - # 1. Configure drop transform for "/admin/*" paths - # 2. Make request to /admin/users - # 3. Assert no spans exist for that request - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_drops_health_check_spans(self): """Health check endpoint spans should be droppable.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_drop_by_header_value(self): """Spans should be droppable by header value match.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_dropped_spans_dont_affect_other_spans(self): """Dropping one span should not affect other spans in same trace.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") -@unittest.skipUnless(TRANSFORMS_IMPLEMENTED, "Transforms not yet implemented in Python SDK") -class TestFlaskMultipleTransforms(unittest.TestCase): +@pytest.mark.skipif(not TRANSFORMS_IMPLEMENTED, reason="Transforms not yet implemented in Python SDK") +class TestFlaskMultipleTransforms: """Test combining multiple transforms.""" def test_combines_mask_and_redact(self): """Should be able to mask some fields and redact others.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") def test_transform_order_matters(self): """Transforms should be applied in configured order.""" - self.fail("Not implemented - waiting for transform engine") + pytest.fail("Not implemented - waiting for transform engine") -# Minimal test to verify the test file loads 
correctly -class TestTransformTestsLoad(unittest.TestCase): +class TestTransformTestsLoad: """Verify the transform test module loads correctly.""" def test_module_loads(self): """Test module should load without errors.""" - self.assertTrue(True) + assert True def test_transforms_flag_is_false(self): """Transform implementation flag should be False until implemented.""" - self.assertFalse(TRANSFORMS_IMPLEMENTED) - - -if __name__ == "__main__": - unittest.main() + assert not TRANSFORMS_IMPLEMENTED diff --git a/tests/unit/test_adapters.py b/tests/unit/test_adapters.py index eaab92a..30b6951 100644 --- a/tests/unit/test_adapters.py +++ b/tests/unit/test_adapters.py @@ -2,10 +2,13 @@ import asyncio import json +import shutil import tempfile -import unittest +from datetime import datetime, timedelta from pathlib import Path +import pytest + from drift.core.tracing.adapters import ( ApiSpanAdapter, ApiSpanAdapterConfig, @@ -18,114 +21,126 @@ from tests.utils import create_test_span -class TestExportResult(unittest.TestCase): +class TestExportResult: """Tests for ExportResult dataclass.""" def test_success_result(self): + """Test success result creation.""" result = ExportResult.success() - self.assertEqual(result.code, ExportResultCode.SUCCESS) - self.assertIsNone(result.error) + assert result.code == ExportResultCode.SUCCESS + assert result.error is None def test_failed_result_with_exception(self): + """Test failed result with exception.""" error = ValueError("test error") result = ExportResult.failed(error) - self.assertEqual(result.code, ExportResultCode.FAILED) - self.assertEqual(result.error, error) + assert result.code == ExportResultCode.FAILED + assert result.error == error def test_failed_result_with_string(self): + """Test failed result with string error message.""" result = ExportResult.failed("test error message") - self.assertEqual(result.code, ExportResultCode.FAILED) - self.assertIsInstance(result.error, Exception) - self.assertEqual(str(result.error), 
"test error message") + assert result.code == ExportResultCode.FAILED + assert isinstance(result.error, Exception) + assert str(result.error) == "test error message" -class TestInMemorySpanAdapter(unittest.TestCase): +class TestInMemorySpanAdapter: """Tests for InMemorySpanAdapter.""" - def setUp(self): - self.adapter = InMemorySpanAdapter() + @pytest.fixture + def adapter(self): + """Create adapter for testing.""" + return InMemorySpanAdapter() - def test_name(self): - self.assertEqual(self.adapter.name, "in-memory") + def test_name(self, adapter): + """Test adapter name.""" + assert adapter.name == "in-memory" - def test_repr(self): - self.assertEqual(repr(self.adapter), "InMemorySpanAdapter(spans=0)") - self.adapter.collect_span(create_test_span()) - self.assertEqual(repr(self.adapter), "InMemorySpanAdapter(spans=1)") + def test_repr(self, adapter): + """Test adapter repr.""" + assert repr(adapter) == "InMemorySpanAdapter(spans=0)" + adapter.collect_span(create_test_span()) + assert repr(adapter) == "InMemorySpanAdapter(spans=1)" - def test_collect_and_get_spans(self): + def test_collect_and_get_spans(self, adapter): + """Test collecting and retrieving spans.""" span1 = create_test_span(span_id="1" * 16) span2 = create_test_span(span_id="2" * 16) - self.adapter.collect_span(span1) - self.adapter.collect_span(span2) + adapter.collect_span(span1) + adapter.collect_span(span2) - spans = self.adapter.get_all_spans() - self.assertEqual(len(spans), 2) - self.assertEqual(spans[0].span_id, "1" * 16) - self.assertEqual(spans[1].span_id, "2" * 16) + spans = adapter.get_all_spans() + assert len(spans) == 2 + assert spans[0].span_id == "1" * 16 + assert spans[1].span_id == "2" * 16 - def test_get_spans_by_instrumentation(self): + def test_get_spans_by_instrumentation(self, adapter): + """Test filtering spans by instrumentation.""" span1 = create_test_span() span1.instrumentation_name = "FlaskInstrumentation" span2 = create_test_span() span2.instrumentation_name = 
"FastAPIInstrumentation" - self.adapter.collect_span(span1) - self.adapter.collect_span(span2) + adapter.collect_span(span1) + adapter.collect_span(span2) - flask_spans = self.adapter.get_spans_by_instrumentation("Flask") - self.assertEqual(len(flask_spans), 1) - self.assertEqual(flask_spans[0].instrumentation_name, "FlaskInstrumentation") + flask_spans = adapter.get_spans_by_instrumentation("Flask") + assert len(flask_spans) == 1 + assert flask_spans[0].instrumentation_name == "FlaskInstrumentation" - def test_get_spans_by_kind(self): + def test_get_spans_by_kind(self, adapter): + """Test filtering spans by kind.""" span1 = create_test_span() span1.kind = SpanKind.SERVER span2 = create_test_span() span2.kind = SpanKind.CLIENT - self.adapter.collect_span(span1) - self.adapter.collect_span(span2) + adapter.collect_span(span1) + adapter.collect_span(span2) - server_spans = self.adapter.get_spans_by_kind(SpanKind.SERVER) - self.assertEqual(len(server_spans), 1) - self.assertEqual(server_spans[0].kind, SpanKind.SERVER) + server_spans = adapter.get_spans_by_kind(SpanKind.SERVER) + assert len(server_spans) == 1 + assert server_spans[0].kind == SpanKind.SERVER - def test_clear(self): - self.adapter.collect_span(create_test_span()) - self.assertEqual(len(self.adapter.get_all_spans()), 1) + def test_clear(self, adapter): + """Test clearing spans.""" + adapter.collect_span(create_test_span()) + assert len(adapter.get_all_spans()) == 1 - self.adapter.clear() - self.assertEqual(len(self.adapter.get_all_spans()), 0) + adapter.clear() + assert len(adapter.get_all_spans()) == 0 - def test_export_spans_async(self): + def test_export_spans_async(self, adapter): + """Test async span export.""" spans = [create_test_span(), create_test_span(span_id="c" * 16)] - result = asyncio.run(self.adapter.export_spans(spans)) + result = asyncio.run(adapter.export_spans(spans)) - self.assertEqual(result.code, ExportResultCode.SUCCESS) - self.assertEqual(len(self.adapter.get_all_spans()), 2) + 
assert result.code == ExportResultCode.SUCCESS + assert len(adapter.get_all_spans()) == 2 - def test_shutdown_clears_spans(self): - self.adapter.collect_span(create_test_span()) - asyncio.run(self.adapter.shutdown()) - self.assertEqual(len(self.adapter.get_all_spans()), 0) + def test_shutdown_clears_spans(self, adapter): + """Test that shutdown clears spans.""" + adapter.collect_span(create_test_span()) + asyncio.run(adapter.shutdown()) + assert len(adapter.get_all_spans()) == 0 - def test_get_spans_by_name(self): + def test_get_spans_by_name(self, adapter): """Test filtering spans by name.""" span1 = create_test_span(name="GET /api/users") span2 = create_test_span(name="POST /api/users") span3 = create_test_span(name="GET /api/orders") - self.adapter.collect_span(span1) - self.adapter.collect_span(span2) - self.adapter.collect_span(span3) + adapter.collect_span(span1) + adapter.collect_span(span2) + adapter.collect_span(span3) - # Get all spans with GET in name - all_spans = self.adapter.get_all_spans() + all_spans = adapter.get_all_spans() get_spans = [s for s in all_spans if "GET" in s.name] - self.assertEqual(len(get_spans), 2) + assert len(get_spans) == 2 - def test_get_spans_by_trace_id(self): + def test_get_spans_by_trace_id(self, adapter): """Test filtering spans by trace ID.""" trace1 = "trace1" + "0" * 26 trace2 = "trace2" + "0" * 26 @@ -134,81 +149,87 @@ def test_get_spans_by_trace_id(self): span2 = create_test_span(trace_id=trace1, name="span2") span3 = create_test_span(trace_id=trace2, name="span3") - self.adapter.collect_span(span1) - self.adapter.collect_span(span2) - self.adapter.collect_span(span3) + adapter.collect_span(span1) + adapter.collect_span(span2) + adapter.collect_span(span3) - all_spans = self.adapter.get_all_spans() + all_spans = adapter.get_all_spans() trace1_spans = [s for s in all_spans if s.trace_id == trace1] - self.assertEqual(len(trace1_spans), 2) + assert len(trace1_spans) == 2 - def test_get_spans_preserves_order(self): + def 
test_get_spans_preserves_order(self, adapter): """Test that spans are returned in insertion order.""" for i in range(5): span = create_test_span(span_id=str(i) * 16, name=f"span-{i}") - self.adapter.collect_span(span) + adapter.collect_span(span) - spans = self.adapter.get_all_spans() + spans = adapter.get_all_spans() for i, span in enumerate(spans): - self.assertEqual(span.name, f"span-{i}") + assert span.name == f"span-{i}" - def test_concurrent_exports(self): + def test_concurrent_exports(self, adapter): """Test concurrent exports don't cause issues.""" async def export_multiple(): tasks = [] for i in range(10): span = create_test_span(span_id=str(i) * 16) - tasks.append(self.adapter.export_spans([span])) + tasks.append(adapter.export_spans([span])) await asyncio.gather(*tasks) asyncio.run(export_multiple()) - self.assertEqual(len(self.adapter.get_all_spans()), 10) + assert len(adapter.get_all_spans()) == 10 -class TestFilesystemSpanAdapter(unittest.TestCase): +class TestFilesystemSpanAdapter: """Tests for FilesystemSpanAdapter.""" - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.adapter = FilesystemSpanAdapter(self.temp_dir) - - def tearDown(self): - import shutil - - shutil.rmtree(self.temp_dir, ignore_errors=True) - - def test_name(self): - self.assertEqual(self.adapter.name, "filesystem") - - def test_repr(self): - self.assertIn("FilesystemSpanAdapter", repr(self.adapter)) - self.assertIn(self.temp_dir, repr(self.adapter)) - - def test_creates_directory(self): - new_dir = Path(self.temp_dir) / "nested" / "spans" - FilesystemSpanAdapter(new_dir) # Creates directory on init - self.assertTrue(new_dir.exists()) - - def test_exports_span_to_jsonl(self): + @pytest.fixture + def temp_dir(self): + """Create temporary directory for tests.""" + tmpdir = tempfile.mkdtemp() + yield tmpdir + shutil.rmtree(tmpdir, ignore_errors=True) + + @pytest.fixture + def adapter(self, temp_dir): + """Create adapter for testing.""" + return 
FilesystemSpanAdapter(temp_dir) + + def test_name(self, adapter): + """Test adapter name.""" + assert adapter.name == "filesystem" + + def test_repr(self, adapter, temp_dir): + """Test adapter repr.""" + assert "FilesystemSpanAdapter" in repr(adapter) + assert temp_dir in repr(adapter) + + def test_creates_directory(self, temp_dir): + """Test directory creation.""" + new_dir = Path(temp_dir) / "nested" / "spans" + FilesystemSpanAdapter(new_dir) + assert new_dir.exists() + + def test_exports_span_to_jsonl(self, adapter, temp_dir): + """Test exporting span to JSONL file.""" span = create_test_span() - result = asyncio.run(self.adapter.export_spans([span])) + result = asyncio.run(adapter.export_spans([span])) - self.assertEqual(result.code, ExportResultCode.SUCCESS) + assert result.code == ExportResultCode.SUCCESS - # Find the created file - files = list(Path(self.temp_dir).glob("*.jsonl")) - self.assertEqual(len(files), 1) + files = list(Path(temp_dir).glob("*.jsonl")) + assert len(files) == 1 - # Verify file content with open(files[0]) as f: line = f.readline() data = json.loads(line) - self.assertEqual(data["traceId"], span.trace_id) - self.assertEqual(data["spanId"], span.span_id) - self.assertEqual(data["name"], span.name) + assert data["traceId"] == span.trace_id + assert data["spanId"] == span.span_id + assert data["name"] == span.name - def test_groups_spans_by_trace_id(self): + def test_groups_spans_by_trace_id(self, adapter, temp_dir): + """Test grouping spans by trace ID.""" trace1 = "t1" + "0" * 30 trace2 = "t2" + "0" * 30 @@ -216,34 +237,31 @@ def test_groups_spans_by_trace_id(self): span2 = create_test_span(trace_id=trace1, span_id="2" * 16) span3 = create_test_span(trace_id=trace2, span_id="3" * 16) - # Export all spans - asyncio.run(self.adapter.export_spans([span1, span2, span3])) + asyncio.run(adapter.export_spans([span1, span2, span3])) - files = list(Path(self.temp_dir).glob("*.jsonl")) - self.assertEqual(len(files), 2) + files = 
list(Path(temp_dir).glob("*.jsonl")) + assert len(files) == 2 - # Check t1 file has 2 lines t1_file = [f for f in files if "t1" in str(f)][0] with open(t1_file) as f: lines = f.readlines() - self.assertEqual(len(lines), 2) + assert len(lines) == 2 - def test_lru_eviction(self): - adapter = FilesystemSpanAdapter(self.temp_dir, max_cached_traces=2) + def test_lru_eviction(self, temp_dir): + """Test LRU eviction of cached traces.""" + adapter = FilesystemSpanAdapter(temp_dir, max_cached_traces=2) - # Add 3 different traces for i in range(3): trace_id = f"trace{i}" + "0" * 26 span = create_test_span(trace_id=trace_id) asyncio.run(adapter.export_spans([span])) - # Only 2 traces should be in cache - self.assertEqual(len(adapter._trace_file_map), 2) - # Oldest (trace0) should have been evicted - self.assertNotIn("trace0" + "0" * 26, adapter._trace_file_map) + assert len(adapter._trace_file_map) == 2 + assert "trace0" + "0" * 26 not in adapter._trace_file_map - def test_lru_updates_on_access(self): - adapter = FilesystemSpanAdapter(self.temp_dir, max_cached_traces=2) + def test_lru_updates_on_access(self, temp_dir): + """Test LRU updates on access.""" + adapter = FilesystemSpanAdapter(temp_dir, max_cached_traces=2) trace0 = "trace0" + "0" * 26 trace1 = "trace1" + "0" * 26 @@ -251,30 +269,29 @@ def test_lru_updates_on_access(self): asyncio.run(adapter.export_spans([create_test_span(trace_id=trace0)])) asyncio.run(adapter.export_spans([create_test_span(trace_id=trace1)])) - # Access trace0 again (moves to end of LRU) asyncio.run(adapter.export_spans([create_test_span(trace_id=trace0, span_id="d" * 16)])) - # Add trace2 - should evict trace1 asyncio.run(adapter.export_spans([create_test_span(trace_id=trace2)])) - self.assertIn(trace0, adapter._trace_file_map) - self.assertIn(trace2, adapter._trace_file_map) - self.assertNotIn(trace1, adapter._trace_file_map) + assert trace0 in adapter._trace_file_map + assert trace2 in adapter._trace_file_map + assert trace1 not in 
adapter._trace_file_map - def test_shutdown_clears_cache(self): - asyncio.run(self.adapter.export_spans([create_test_span()])) - self.assertEqual(len(self.adapter._trace_file_map), 1) + def test_shutdown_clears_cache(self, adapter): + """Test that shutdown clears cache.""" + asyncio.run(adapter.export_spans([create_test_span()])) + assert len(adapter._trace_file_map) == 1 - asyncio.run(self.adapter.shutdown()) - self.assertEqual(len(self.adapter._trace_file_map), 0) + asyncio.run(adapter.shutdown()) + assert len(adapter._trace_file_map) == 0 -class TestApiSpanAdapter(unittest.TestCase): +class TestApiSpanAdapter: """Tests for ApiSpanAdapter.""" - def setUp(self): - # ApiSpanAdapterConfig is used internally by the SDK - # These parameters are read from config file, not init params - self.config = ApiSpanAdapterConfig( + @pytest.fixture + def config(self): + """Create config for testing.""" + return ApiSpanAdapterConfig( api_key="test-api-key", tusk_backend_base_url="https://api.test.com", observable_service_id="test-service-id", @@ -282,19 +299,21 @@ def setUp(self): sdk_version="1.0.0", sdk_instance_id="test-instance", ) - self.adapter = ApiSpanAdapter(self.config) - def tearDown(self): - # No cleanup needed for betterproto client - pass + @pytest.fixture + def adapter(self, config): + """Create adapter for testing.""" + return ApiSpanAdapter(config) - def test_name(self): - self.assertEqual(self.adapter.name, "api") + def test_name(self, adapter): + """Test adapter name.""" + assert adapter.name == "api" - def test_repr(self): - self.assertIn("ApiSpanAdapter", repr(self.adapter)) - self.assertIn("api.test.com", repr(self.adapter)) - self.assertIn("test", repr(self.adapter)) # environment + def test_repr(self, adapter): + """Test adapter repr.""" + assert "ApiSpanAdapter" in repr(adapter) + assert "api.test.com" in repr(adapter) + assert "test" in repr(adapter) def test_config_defaults(self): """Test that config has sensible defaults.""" @@ -306,51 +325,38 @@ def 
test_config_defaults(self): sdk_version="1.0.0", sdk_instance_id="inst", ) - # Config is minimal now - no timeout/retries since using betterproto - self.assertEqual(config.api_key, "key") - self.assertEqual(config.environment, "prod") + assert config.api_key == "key" + assert config.environment == "prod" - def test_transform_span_to_protobuf(self): + def test_transform_span_to_protobuf(self, adapter): """Test span transformation to protobuf format.""" from tusk.drift.core.v1 import Span as ProtoSpan span = create_test_span() - result = self.adapter._transform_span_to_protobuf(span) - - # Result should be a protobuf Span object - self.assertIsInstance(result, ProtoSpan) - self.assertEqual(result.trace_id, span.trace_id) - self.assertEqual(result.span_id, span.span_id) - self.assertEqual(result.name, span.name) - self.assertEqual(result.package_name, span.package_name) - # Kind is converted to int value for protobuf - self.assertEqual(result.kind, span.kind.value) - # Check that input/output were converted to Struct - self.assertIsNotNone(result.input_value) - self.assertIsNotNone(result.output_value) - # Check timestamp and duration are datetime/timedelta - from datetime import datetime, timedelta - - self.assertIsInstance(result.timestamp, datetime) - self.assertIsInstance(result.duration, timedelta) - - def test_base_url_construction(self): + result = adapter._transform_span_to_protobuf(span) + + assert isinstance(result, ProtoSpan) + assert result.trace_id == span.trace_id + assert result.span_id == span.span_id + assert result.name == span.name + assert result.package_name == span.package_name + assert result.kind == span.kind.value + assert result.input_value is not None + assert result.output_value is not None + assert isinstance(result.timestamp, datetime) + assert isinstance(result.duration, timedelta) + + def test_base_url_construction(self, adapter): """Test that the API URL is constructed correctly.""" - self.assertEqual( - self.adapter._base_url, 
"https://api.test.com/api/drift/tusk.drift.backend.v1.SpanExportService/ExportSpans" - ) + assert adapter._base_url == "https://api.test.com/api/drift/tusk.drift.backend.v1.SpanExportService/ExportSpans" - def test_aiohttp_not_installed(self): + def test_aiohttp_not_installed(self, adapter): """Test graceful handling when aiohttp is not installed.""" - # The error is raised in the channel, which is called during export_spans - # We can't easily mock the import since the adapter is already initialized - # So we'll just verify the adapter was created successfully - # The actual ImportError handling is tested in integration tests - self.assertIsNotNone(self.adapter) - self.assertEqual(self.adapter.name, "api") + assert adapter is not None + assert adapter.name == "api" -class TestAdapterIntegration(unittest.TestCase): +class TestAdapterIntegration: """Integration tests for adapters working together.""" def test_multiple_adapters(self): @@ -362,18 +368,12 @@ def test_multiple_adapters(self): span = create_test_span() - # Export to both result1 = asyncio.run(memory_adapter.export_spans([span])) result2 = asyncio.run(fs_adapter.export_spans([span])) - self.assertEqual(result1.code, ExportResultCode.SUCCESS) - self.assertEqual(result2.code, ExportResultCode.SUCCESS) + assert result1.code == ExportResultCode.SUCCESS + assert result2.code == ExportResultCode.SUCCESS - # Verify both have the span - self.assertEqual(len(memory_adapter.get_all_spans()), 1) + assert len(memory_adapter.get_all_spans()) == 1 files = list(Path(temp_dir).glob("*.jsonl")) - self.assertEqual(len(files), 1) - - -if __name__ == "__main__": - unittest.main() + assert len(files) == 1 diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py index 343107f..9f1570f 100644 --- a/tests/unit/test_config_loading.py +++ b/tests/unit/test_config_loading.py @@ -2,7 +2,6 @@ import os import tempfile -import unittest from pathlib import Path from drift.core.config import ( @@ -12,17 +11,15 
@@ ) -class TestFindProjectRoot(unittest.TestCase): +class TestFindProjectRoot: """Test the find_project_root function.""" def test_finds_project_root_with_pyproject_toml(self): """Should find project root when pyproject.toml exists.""" with tempfile.TemporaryDirectory() as tmpdir: - # Create a pyproject.toml file project_root = Path(tmpdir).resolve() (project_root / "pyproject.toml").touch() - # Change to a subdirectory subdir = project_root / "src" / "myapp" subdir.mkdir(parents=True) @@ -31,18 +28,16 @@ def test_finds_project_root_with_pyproject_toml(self): os.chdir(subdir) found_root = find_project_root() assert found_root is not None - self.assertEqual(found_root.resolve(), project_root.resolve()) + assert found_root.resolve() == project_root.resolve() finally: os.chdir(original_cwd) def test_finds_project_root_with_setup_py(self): """Should find project root when setup.py exists.""" with tempfile.TemporaryDirectory() as tmpdir: - # Create a setup.py file project_root = Path(tmpdir).resolve() (project_root / "setup.py").touch() - # Change to a subdirectory subdir = project_root / "src" subdir.mkdir(parents=True) @@ -51,14 +46,13 @@ def test_finds_project_root_with_setup_py(self): os.chdir(subdir) found_root = find_project_root() assert found_root is not None - self.assertEqual(found_root.resolve(), project_root.resolve()) + assert found_root.resolve() == project_root.resolve() finally: os.chdir(original_cwd) def test_returns_none_when_no_markers_found(self): """Should return None when no project markers are found.""" with tempfile.TemporaryDirectory() as tmpdir: - # Create a directory with no project markers subdir = Path(tmpdir) / "some" / "deep" / "path" subdir.mkdir(parents=True) @@ -66,14 +60,12 @@ def test_returns_none_when_no_markers_found(self): try: os.chdir(subdir) found_root = find_project_root() - # This might return None or find a marker higher up the tree - # depending on the actual directory structure - self.assertIsInstance(found_root, (Path, 
type(None))) + assert isinstance(found_root, (Path, type(None))) finally: os.chdir(original_cwd) -class TestLoadTuskConfig(unittest.TestCase): +class TestLoadTuskConfig: """Test the load_tusk_config function.""" def test_loads_valid_config_file(self): @@ -82,7 +74,6 @@ def test_loads_valid_config_file(self): project_root = Path(tmpdir) (project_root / "pyproject.toml").touch() - # Create .tusk directory and config file tusk_dir = project_root / ".tusk" tusk_dir.mkdir() @@ -115,38 +106,32 @@ def test_loads_valid_config_file(self): os.chdir(project_root) config = load_tusk_config() - self.assertIsNotNone(config) assert config is not None - self.assertIsInstance(config, TuskFileConfig) + assert isinstance(config, TuskFileConfig) # Check service config - self.assertIsNotNone(config.service) assert config.service is not None - self.assertEqual(config.service.id, "test-service-123") - self.assertEqual(config.service.name, "test-service") - self.assertEqual(config.service.port, 3000) + assert config.service.id == "test-service-123" + assert config.service.name == "test-service" + assert config.service.port == 3000 # Check traces config - self.assertIsNotNone(config.traces) assert config.traces is not None - self.assertEqual(config.traces.dir, ".tusk/traces") + assert config.traces.dir == ".tusk/traces" # Check recording config - self.assertIsNotNone(config.recording) assert config.recording is not None - self.assertEqual(config.recording.sampling_rate, 0.5) - self.assertEqual(config.recording.export_spans, False) - self.assertEqual(config.recording.enable_env_var_recording, True) + assert config.recording.sampling_rate == 0.5 + assert config.recording.export_spans is False + assert config.recording.enable_env_var_recording is True # Check tusk_api config - self.assertIsNotNone(config.tusk_api) assert config.tusk_api is not None - self.assertEqual(config.tusk_api.url, "https://api.example.com") + assert config.tusk_api.url == "https://api.example.com" # Check transforms - 
self.assertIsNotNone(config.transforms) assert config.transforms is not None - self.assertIn("http", config.transforms) + assert "http" in config.transforms finally: os.chdir(original_cwd) @@ -161,7 +146,7 @@ def test_returns_none_when_config_file_missing(self): try: os.chdir(project_root) config = load_tusk_config() - self.assertIsNone(config) + assert config is None finally: os.chdir(original_cwd) @@ -171,7 +156,6 @@ def test_handles_empty_config_file(self): project_root = Path(tmpdir) (project_root / "pyproject.toml").touch() - # Create .tusk directory and empty config file tusk_dir = project_root / ".tusk" tusk_dir.mkdir() (tusk_dir / "config.yaml").write_text("") @@ -181,16 +165,15 @@ def test_handles_empty_config_file(self): os.chdir(project_root) config = load_tusk_config() - self.assertIsNotNone(config) assert config is not None - self.assertIsInstance(config, TuskFileConfig) + assert isinstance(config, TuskFileConfig) # All fields should be None - self.assertIsNone(config.service) - self.assertIsNone(config.traces) - self.assertIsNone(config.recording) - self.assertIsNone(config.tusk_api) - self.assertIsNone(config.transforms) + assert config.service is None + assert config.traces is None + assert config.recording is None + assert config.tusk_api is None + assert config.transforms is None finally: os.chdir(original_cwd) @@ -201,7 +184,6 @@ def test_handles_partial_config(self): project_root = Path(tmpdir) (project_root / "pyproject.toml").touch() - # Create .tusk directory and partial config file tusk_dir = project_root / ".tusk" tusk_dir.mkdir() @@ -219,18 +201,15 @@ def test_handles_partial_config(self): os.chdir(project_root) config = load_tusk_config() - self.assertIsNotNone(config) assert config is not None # Only specified sections should be present - self.assertIsNone(config.service) - self.assertIsNotNone(config.traces) + assert config.service is None assert config.traces is not None - self.assertEqual(config.traces.dir, "./my-traces") - 
self.assertIsNotNone(config.recording) + assert config.traces.dir == "./my-traces" assert config.recording is not None - self.assertEqual(config.recording.sampling_rate, 0.8) - self.assertIsNone(config.tusk_api) + assert config.recording.sampling_rate == 0.8 + assert config.tusk_api is None finally: os.chdir(original_cwd) @@ -241,7 +220,6 @@ def test_handles_invalid_yaml(self): project_root = Path(tmpdir) (project_root / "pyproject.toml").touch() - # Create .tusk directory and invalid YAML file tusk_dir = project_root / ".tusk" tusk_dir.mkdir() (tusk_dir / "config.yaml").write_text("invalid: yaml: content: [") @@ -250,10 +228,6 @@ def test_handles_invalid_yaml(self): try: os.chdir(project_root) config = load_tusk_config() - self.assertIsNone(config) + assert config is None finally: os.chdir(original_cwd) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/test_context_propagation.py b/tests/unit/test_context_propagation.py index d9628fb..9a1ff8d 100644 --- a/tests/unit/test_context_propagation.py +++ b/tests/unit/test_context_propagation.py @@ -1,17 +1,15 @@ """Test context propagation to ThreadPoolExecutor threads.""" -import unittest from concurrent.futures import ThreadPoolExecutor from drift.core.types import current_span_id_context, current_trace_id_context -class TestContextPropagation(unittest.TestCase): +class TestContextPropagation: """Test that context variables propagate correctly to thread pools.""" def test_context_propagates_to_thread_pool(self): """Test that context variables are accessible in ThreadPoolExecutor threads.""" - # Set context in main thread trace_id = "test-trace-123" span_id = "test-span-456" @@ -19,11 +17,9 @@ def test_context_propagates_to_thread_pool(self): span_token = current_span_id_context.set(span_id) try: - # Verify context is set in main thread - self.assertEqual(current_trace_id_context.get(), trace_id) - self.assertEqual(current_span_id_context.get(), span_id) + assert current_trace_id_context.get() == 
trace_id + assert current_span_id_context.get() == span_id - # Helper to set context and check it def run_with_context(trace_id, span_id): """Set context and return values.""" if trace_id: @@ -35,7 +31,6 @@ def run_with_context(trace_id, span_id): "span_id": current_span_id_context.get(), } - # Get context values to pass to threads ctx_trace_id = current_trace_id_context.get() ctx_span_id = current_span_id_context.get() @@ -46,15 +41,10 @@ def run_with_context(trace_id, span_id): result1 = future1.result() result2 = future2.result() - # Context should be accessible in threads - self.assertEqual(result1["trace_id"], trace_id, "Context trace_id not propagated to thread 1") - self.assertEqual(result1["span_id"], span_id, "Context span_id not propagated to thread 1") - self.assertEqual(result2["trace_id"], trace_id, "Context trace_id not propagated to thread 2") - self.assertEqual(result2["span_id"], span_id, "Context span_id not propagated to thread 2") + assert result1["trace_id"] == trace_id, "Context trace_id not propagated to thread 1" + assert result1["span_id"] == span_id, "Context span_id not propagated to thread 1" + assert result2["trace_id"] == trace_id, "Context trace_id not propagated to thread 2" + assert result2["span_id"] == span_id, "Context span_id not propagated to thread 2" finally: current_trace_id_context.reset(trace_token) current_span_id_context.reset(span_token) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/test_data_normalization.py b/tests/unit/test_data_normalization.py index 8ace948..a7c54da 100644 --- a/tests/unit/test_data_normalization.py +++ b/tests/unit/test_data_normalization.py @@ -5,7 +5,6 @@ import json import sys -import unittest from datetime import datetime from pathlib import Path @@ -18,7 +17,7 @@ ) -class TestRemoveNoneValues(unittest.TestCase): +class TestRemoveNoneValues: """Tests for remove_none_values function (normalizeInputData equivalent).""" def 
test_should_remove_none_values_from_objects(self): @@ -26,7 +25,7 @@ def test_should_remove_none_values_from_objects(self): input_data = { "a": "value", "b": None, - "c": None, # explicit None (like JS null, but we remove it like JS undefined) + "c": None, "d": 0, "e": False, "f": "", @@ -34,17 +33,14 @@ def test_should_remove_none_values_from_objects(self): result = remove_none_values(input_data) - self.assertEqual( - result, - { - "a": "value", - "d": 0, - "e": False, - "f": "", - }, - ) - self.assertNotIn("b", result) - self.assertNotIn("c", result) + assert result == { + "a": "value", + "d": 0, + "e": False, + "f": "", + } + assert "b" not in result + assert "c" not in result def test_should_handle_nested_objects_with_none_values(self): """Handle nested objects with None values.""" @@ -71,10 +67,10 @@ def test_should_handle_nested_objects_with_none_values(self): }, }, } - self.assertEqual(result, expected) - self.assertNotIn("age", result["user"]) - self.assertNotIn("zip", result["user"]["address"]) - self.assertNotIn("metadata", result) + assert result == expected + assert "age" not in result["user"] + assert "zip" not in result["user"]["address"] + assert "metadata" not in result def test_should_handle_arrays_with_none_values(self): """Arrays should preserve None (like JS null in arrays).""" @@ -85,14 +81,10 @@ def test_should_handle_arrays_with_none_values(self): result = remove_none_values(input_data) - # In arrays, None is preserved (like JS null) - self.assertEqual( - result, - { - "items": ["a", None, "b", None, "c"], - "numbers": [1, None, 2, 0], - }, - ) + assert result == { + "items": ["a", None, "b", None, "c"], + "numbers": [1, None, 2, 0], + } def test_should_handle_circular_references_safely(self): """Circular references should be replaced with '[Circular]'.""" @@ -104,20 +96,17 @@ def test_should_handle_circular_references_safely(self): result = remove_none_values(input_data) - self.assertEqual( - result, - { - "name": "test", - "value": 123, - 
"self": "[Circular]", - }, - ) + assert result == { + "name": "test", + "value": 123, + "self": "[Circular]", + } def test_should_handle_empty_objects(self): """Empty objects should remain empty.""" input_data = {} result = remove_none_values(input_data) - self.assertEqual(result, {}) + assert result == {} def test_should_handle_primitive_values_wrapped_in_objects(self): """Test various primitive values in objects.""" @@ -131,14 +120,11 @@ def test_should_handle_primitive_values_wrapped_in_objects(self): result = remove_none_values(input_data) - self.assertEqual( - result, - { - "string": "test", - "number": 42, - "boolean": True, - }, - ) + assert result == { + "string": "test", + "number": 42, + "boolean": True, + } def test_should_preserve_date_objects_as_iso_strings(self): """Date objects should be converted to ISO strings.""" @@ -150,12 +136,9 @@ def test_should_preserve_date_objects_as_iso_strings(self): result = remove_none_values(input_data) - self.assertEqual( - result, - { - "timestamp": date.isoformat(), - }, - ) + assert result == { + "timestamp": date.isoformat(), + } def test_should_handle_complex_nested_structures(self): """Test complex nested structures with various types.""" @@ -183,10 +166,10 @@ def test_should_handle_complex_nested_structures(self): }, }, } - self.assertEqual(result, expected) + assert result == expected -class TestCreateSpanInputValue(unittest.TestCase): +class TestCreateSpanInputValue: """Tests for create_span_input_value function.""" def test_should_return_json_string_of_normalized_data(self): @@ -199,14 +182,11 @@ def test_should_return_json_string_of_normalized_data(self): result = create_span_input_value(input_data) - self.assertIsInstance(result, str) - self.assertEqual( - json.loads(result), - { - "user": "john", - "active": True, - }, - ) + assert isinstance(result, str) + assert json.loads(result) == { + "user": "john", + "active": True, + } def test_should_handle_circular_references_in_span_values(self): """Circular 
references should be handled in JSON output.""" @@ -217,14 +197,11 @@ def test_should_handle_circular_references_in_span_values(self): result = create_span_input_value(input_data) - self.assertIsInstance(result, str) - self.assertEqual( - json.loads(result), - { - "name": "test", - "circular": "[Circular]", - }, - ) + assert isinstance(result, str) + assert json.loads(result) == { + "name": "test", + "circular": "[Circular]", + } def test_should_produce_consistent_output_for_identical_normalized_data(self): """Same normalized data should produce same JSON output.""" @@ -234,10 +211,10 @@ def test_should_produce_consistent_output_for_identical_normalized_data(self): result1 = create_span_input_value(input1) result2 = create_span_input_value(input2) - self.assertEqual(result1, result2) + assert result1 == result2 -class TestCreateMockInputValue(unittest.TestCase): +class TestCreateMockInputValue: """Tests for create_mock_input_value function.""" def test_should_return_normalized_object_data(self): @@ -250,14 +227,11 @@ def test_should_return_normalized_object_data(self): result = create_mock_input_value(input_data) - self.assertEqual( - result, - { - "user": "john", - "active": True, - }, - ) - self.assertNotIn("age", result) + assert result == { + "user": "john", + "active": True, + } + assert "age" not in result def test_should_handle_circular_references_in_mock_values(self): """Circular references should be replaced with marker.""" @@ -268,13 +242,10 @@ def test_should_handle_circular_references_in_mock_values(self): result = create_mock_input_value(input_data) - self.assertEqual( - result, - { - "name": "test", - "circular": "[Circular]", - }, - ) + assert result == { + "name": "test", + "circular": "[Circular]", + } def test_should_produce_consistent_output_for_identical_normalized_data(self): """Same normalized data should produce same output.""" @@ -284,7 +255,7 @@ def test_should_produce_consistent_output_for_identical_normalized_data(self): result1 = 
create_mock_input_value(input1) result2 = create_mock_input_value(input2) - self.assertEqual(result1, result2) + assert result1 == result2 def test_should_preserve_type_information(self): """Type information should be preserved in output.""" @@ -296,16 +267,13 @@ def test_should_preserve_type_information(self): result = create_mock_input_value(input_data) - self.assertEqual( - result, - { - "id": 1, - "name": "test", - }, - ) + assert result == { + "id": 1, + "name": "test", + } -class TestConsistencyBetweenFunctions(unittest.TestCase): +class TestConsistencyBetweenFunctions: """Tests ensuring consistency between span and mock value functions.""" def test_should_ensure_functions_produce_equivalent_data_structures(self): @@ -324,7 +292,7 @@ def test_should_ensure_functions_produce_equivalent_data_structures(self): span_value = create_span_input_value(input_data) mock_value = create_mock_input_value(input_data) - self.assertEqual(json.loads(span_value), mock_value) + assert json.loads(span_value) == mock_value def test_should_handle_edge_cases_consistently(self): """Edge cases should be handled consistently by both functions.""" @@ -339,14 +307,10 @@ def test_should_handle_edge_cases_consistently(self): span_value = create_span_input_value(test_case) mock_value = create_mock_input_value(test_case) - self.assertEqual( - json.loads(span_value), - mock_value, - f"Mismatch for test case: {test_case}", - ) + assert json.loads(span_value) == mock_value, f"Mismatch for test case: {test_case}" -class TestEdgeCases(unittest.TestCase): +class TestEdgeCases: """Additional edge case tests.""" def test_deeply_nested_circular_reference(self): @@ -356,7 +320,7 @@ def test_deeply_nested_circular_reference(self): result = remove_none_values(input_data) - self.assertEqual(result["level1"]["level2"]["level3"]["back_to_root"], "[Circular]") + assert result["level1"]["level2"]["level3"]["back_to_root"] == "[Circular]" def test_multiple_circular_references(self): """Test multiple circular 
references to same object.""" @@ -365,10 +329,9 @@ def test_multiple_circular_references(self): result = remove_none_values(input_data) - # All references should be preserved (not circular since we reset seen on exit) - self.assertEqual(result["ref1"]["name"], "shared") - self.assertEqual(result["ref2"]["name"], "shared") - self.assertEqual(result["nested"]["ref3"]["name"], "shared") + assert result["ref1"]["name"] == "shared" + assert result["ref2"]["name"] == "shared" + assert result["nested"]["ref3"]["name"] == "shared" def test_list_containing_dicts_with_none(self): """Test list containing dicts with None values.""" @@ -387,26 +350,22 @@ def test_list_containing_dicts_with_none(self): {"id": 2, "name": "second"}, ] } - self.assertEqual(result, expected) + assert result == expected def test_empty_string_is_preserved(self): """Empty strings should be preserved.""" input_data = {"empty": "", "none": None} result = remove_none_values(input_data) - self.assertEqual(result, {"empty": ""}) + assert result == {"empty": ""} def test_zero_is_preserved(self): """Zero values should be preserved.""" input_data = {"zero": 0, "none": None} result = remove_none_values(input_data) - self.assertEqual(result, {"zero": 0}) + assert result == {"zero": 0} def test_false_is_preserved(self): """False values should be preserved.""" input_data = {"false": False, "none": None} result = remove_none_values(input_data) - self.assertEqual(result, {"false": False}) - - -if __name__ == "__main__": - unittest.main() + assert result == {"false": False} diff --git a/tests/unit/test_error_resilience.py b/tests/unit/test_error_resilience.py index 4a4b2f9..54bba0d 100644 --- a/tests/unit/test_error_resilience.py +++ b/tests/unit/test_error_resilience.py @@ -7,81 +7,73 @@ import asyncio import os -import unittest os.environ["TUSK_DRIFT_MODE"] = "RECORD" + from drift.core.tracing.adapters import ExportResult, ExportResultCode, InMemorySpanAdapter from drift.core.types import CleanSpanData, Duration, 
PackageType, SpanKind, SpanStatus, StatusCode, Timestamp from tests.utils import create_test_span -class TestAdapterErrorResilience(unittest.TestCase): +class TestAdapterErrorResilience: """Test that adapters handle errors gracefully.""" def test_in_memory_adapter_continues_after_invalid_data(self): """InMemorySpanAdapter should continue working after receiving invalid data.""" adapter = InMemorySpanAdapter() - # Collect a valid span first span1 = create_test_span(name="valid-1") adapter.collect_span(span1) - # Adapter should be functional spans = adapter.get_all_spans() - self.assertEqual(len(spans), 1) - self.assertEqual(spans[0].name, "valid-1") + assert len(spans) == 1 + assert spans[0].name == "valid-1" def test_adapter_recovers_after_error(self): """Adapter should continue working after an error.""" adapter = InMemorySpanAdapter() - # Collect valid span span1 = create_test_span(name="span1") adapter.collect_span(span1) - # Try to collect invalid data try: adapter.collect_span("not a span") # type: ignore except (TypeError, AttributeError): pass - # Collect another valid span span2 = create_test_span(name="span2") adapter.collect_span(span2) - # Both valid spans should be present spans = adapter.get_all_spans() valid_spans = [s for s in spans if isinstance(s, CleanSpanData)] - self.assertGreaterEqual(len(valid_spans), 2) + assert len(valid_spans) >= 2 def test_export_result_captures_errors(self): """ExportResult should properly capture error information.""" error = ValueError("Test error") result = ExportResult.failed(error) - self.assertEqual(result.code, ExportResultCode.FAILED) - self.assertEqual(result.error, error) + assert result.code == ExportResultCode.FAILED + assert result.error == error def test_export_result_from_string_error(self): """ExportResult should handle string error messages.""" result = ExportResult.failed("Something went wrong") - self.assertEqual(result.code, ExportResultCode.FAILED) - self.assertIsInstance(result.error, Exception) - 
self.assertIn("Something went wrong", str(result.error)) + assert result.code == ExportResultCode.FAILED + assert isinstance(result.error, Exception) + assert "Something went wrong" in str(result.error) -class TestSpanCreationErrorResilience(unittest.TestCase): +class TestSpanCreationErrorResilience: """Test that span creation handles errors gracefully.""" def test_span_with_invalid_input_value(self): """Creating a span with invalid input should be handled.""" - # Circular reference in input value circular_dict: dict = {} circular_dict["self"] = circular_dict - # This should either handle the circular reference or raise a clear error try: _span = CleanSpanData( trace_id="a" * 32, @@ -99,15 +91,13 @@ def test_span_with_invalid_input_value(self): timestamp=Timestamp(seconds=1700000000, nanos=0), duration=Duration(seconds=0, nanos=1000000), ) - # If span creation succeeds, serialization might fail - # which is also acceptable - del _span # Silence unused variable warning + del _span except (ValueError, RecursionError): - pass # Expected - might reject circular references + pass def test_span_with_very_large_input(self): """Creating a span with very large input should be handled.""" - large_input = {"data": "x" * 1_000_000} # 1MB of data + large_input = {"data": "x" * 1_000_000} span = CleanSpanData( trace_id="a" * 32, @@ -126,18 +116,17 @@ def test_span_with_very_large_input(self): duration=Duration(seconds=0, nanos=1000000), ) - # Span should be created (truncation might happen during export) - self.assertIsNotNone(span) + assert span is not None -class TestAsyncErrorResilience(unittest.TestCase): +class TestAsyncErrorResilience: """Test error resilience in async operations.""" def test_async_export_handles_timeout(self): """Async export should handle timeouts gracefully.""" async def slow_export(spans): - await asyncio.sleep(10) # Very slow + await asyncio.sleep(10) return ExportResult.success() adapter = InMemorySpanAdapter() @@ -153,9 +142,8 @@ async def 
timeout_export(spans): span = create_test_span() result = asyncio.run(adapter.export_spans([span])) - # Should have failed with timeout - self.assertEqual(result.code, ExportResultCode.FAILED) - self.assertIn("timed out", str(result.error)) + assert result.code == ExportResultCode.FAILED + assert "timed out" in str(result.error) def test_async_export_handles_cancellation(self): """Async export should handle cancellation gracefully.""" @@ -170,27 +158,10 @@ async def run_test(): try: await task except asyncio.CancelledError: - pass # Expected + pass - # Adapter should still be functional after cancellation result = await adapter.export_spans([create_test_span()]) return result result = asyncio.run(run_test()) - self.assertEqual(result.code, ExportResultCode.SUCCESS) - - -# NOTE: The following test categories were removed because they tested -# internal APIs that have significantly changed: -# -# - TestBatchProcessorErrorResilience: BatchSpanProcessor now requires -# a TdSpanExporter with complex configuration. The internal API changed -# significantly. Batch processing behavior is tested via E2E tests. -# -# - TestSDKErrorResilience: The SDK initialization and span collection -# flow has changed. Error resilience at the SDK level is better tested -# via integration/E2E tests that exercise the full SDK lifecycle. 
- - -if __name__ == "__main__": - unittest.main() + assert result.code == ExportResultCode.SUCCESS diff --git a/tests/unit/test_json_schema_helper.py b/tests/unit/test_json_schema_helper.py index f050c5d..36e95a8 100644 --- a/tests/unit/test_json_schema_helper.py +++ b/tests/unit/test_json_schema_helper.py @@ -6,7 +6,6 @@ import base64 import json import sys -import unittest from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent.parent)) @@ -21,89 +20,97 @@ ) -class TestGetDetailedType(unittest.TestCase): +class TestGetDetailedType: """Tests for JsonSchemaHelper._determine_type (getDetailedType equivalent).""" def test_should_correctly_identify_primitive_types(self): - self.assertEqual(JsonSchemaHelper._determine_type(None), JsonSchemaType.NULL) - self.assertEqual(JsonSchemaHelper._determine_type("hello"), JsonSchemaType.STRING) - self.assertEqual(JsonSchemaHelper._determine_type(42), JsonSchemaType.NUMBER) - self.assertEqual(JsonSchemaHelper._determine_type(3.14), JsonSchemaType.NUMBER) - self.assertEqual(JsonSchemaHelper._determine_type(True), JsonSchemaType.BOOLEAN) - self.assertEqual(JsonSchemaHelper._determine_type(False), JsonSchemaType.BOOLEAN) + """Test correct identification of primitive types.""" + assert JsonSchemaHelper._determine_type(None) == JsonSchemaType.NULL + assert JsonSchemaHelper._determine_type("hello") == JsonSchemaType.STRING + assert JsonSchemaHelper._determine_type(42) == JsonSchemaType.NUMBER + assert JsonSchemaHelper._determine_type(3.14) == JsonSchemaType.NUMBER + assert JsonSchemaHelper._determine_type(True) == JsonSchemaType.BOOLEAN + assert JsonSchemaHelper._determine_type(False) == JsonSchemaType.BOOLEAN def test_should_correctly_identify_object_types(self): - self.assertEqual(JsonSchemaHelper._determine_type({}), JsonSchemaType.OBJECT) - self.assertEqual(JsonSchemaHelper._determine_type([]), JsonSchemaType.ORDERED_LIST) - self.assertEqual(JsonSchemaHelper._determine_type(set()), JsonSchemaType.UNORDERED_LIST) 
+ """Test correct identification of object types.""" + assert JsonSchemaHelper._determine_type({}) == JsonSchemaType.OBJECT + assert JsonSchemaHelper._determine_type([]) == JsonSchemaType.ORDERED_LIST + assert JsonSchemaHelper._determine_type(set()) == JsonSchemaType.UNORDERED_LIST def test_should_identify_callable_as_function(self): - self.assertEqual(JsonSchemaHelper._determine_type(lambda: None), JsonSchemaType.FUNCTION) + """Test identification of callable as function.""" + assert JsonSchemaHelper._determine_type(lambda: None) == JsonSchemaType.FUNCTION def test_should_handle_tuples_as_ordered_lists(self): - self.assertEqual(JsonSchemaHelper._determine_type((1, 2, 3)), JsonSchemaType.ORDERED_LIST) + """Test handling of tuples as ordered lists.""" + assert JsonSchemaHelper._determine_type((1, 2, 3)) == JsonSchemaType.ORDERED_LIST -class TestGenerateSchema(unittest.TestCase): +class TestGenerateSchema: """Tests for JsonSchemaHelper.generate_schema.""" def test_should_generate_schema_for_primitive_types(self): + """Test schema generation for primitive types.""" null_schema = JsonSchemaHelper.generate_schema(None) - self.assertEqual(null_schema.type, JsonSchemaType.NULL) - self.assertEqual(null_schema.properties, {}) + assert null_schema.type == JsonSchemaType.NULL + assert null_schema.properties == {} string_schema = JsonSchemaHelper.generate_schema("test") - self.assertEqual(string_schema.type, JsonSchemaType.STRING) - self.assertEqual(string_schema.properties, {}) + assert string_schema.type == JsonSchemaType.STRING + assert string_schema.properties == {} number_schema = JsonSchemaHelper.generate_schema(42) - self.assertEqual(number_schema.type, JsonSchemaType.NUMBER) - self.assertEqual(number_schema.properties, {}) + assert number_schema.type == JsonSchemaType.NUMBER + assert number_schema.properties == {} bool_schema = JsonSchemaHelper.generate_schema(True) - self.assertEqual(bool_schema.type, JsonSchemaType.BOOLEAN) - self.assertEqual(bool_schema.properties, 
{}) + assert bool_schema.type == JsonSchemaType.BOOLEAN + assert bool_schema.properties == {} def test_should_generate_schema_for_empty_arrays(self): + """Test schema generation for empty arrays.""" schema = JsonSchemaHelper.generate_schema([]) - self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) - self.assertEqual(schema.properties, {}) - self.assertIsNone(schema.items) + assert schema.type == JsonSchemaType.ORDERED_LIST + assert schema.properties == {} + assert schema.items is None def test_should_generate_schema_for_number_arrays(self): + """Test schema generation for number arrays.""" schema = JsonSchemaHelper.generate_schema([1, 2, 3]) - self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) - self.assertIsNotNone(schema.items) + assert schema.type == JsonSchemaType.ORDERED_LIST assert schema.items is not None - self.assertEqual(schema.items.type, JsonSchemaType.NUMBER) + assert schema.items.type == JsonSchemaType.NUMBER def test_should_generate_schema_for_string_arrays(self): + """Test schema generation for string arrays.""" schema = JsonSchemaHelper.generate_schema(["a", "b"]) - self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) - self.assertIsNotNone(schema.items) + assert schema.type == JsonSchemaType.ORDERED_LIST assert schema.items is not None - self.assertEqual(schema.items.type, JsonSchemaType.STRING) + assert schema.items.type == JsonSchemaType.STRING def test_should_generate_schema_for_object_arrays(self): + """Test schema generation for object arrays.""" schema = JsonSchemaHelper.generate_schema([{"id": 1}]) - self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) - self.assertIsNotNone(schema.items) + assert schema.type == JsonSchemaType.ORDERED_LIST assert schema.items is not None - self.assertEqual(schema.items.type, JsonSchemaType.OBJECT) - self.assertIn("id", schema.items.properties) - self.assertEqual(schema.items.properties["id"].type, JsonSchemaType.NUMBER) + assert schema.items.type == JsonSchemaType.OBJECT + assert 
"id" in schema.items.properties + assert schema.items.properties["id"].type == JsonSchemaType.NUMBER def test_should_generate_schema_for_simple_objects(self): + """Test schema generation for simple objects.""" simple_obj = {"name": "John", "age": 30} schema = JsonSchemaHelper.generate_schema(simple_obj) - self.assertEqual(schema.type, JsonSchemaType.OBJECT) - self.assertIn("name", schema.properties) - self.assertIn("age", schema.properties) - self.assertEqual(schema.properties["name"].type, JsonSchemaType.STRING) - self.assertEqual(schema.properties["age"].type, JsonSchemaType.NUMBER) + assert schema.type == JsonSchemaType.OBJECT + assert "name" in schema.properties + assert "age" in schema.properties + assert schema.properties["name"].type == JsonSchemaType.STRING + assert schema.properties["age"].type == JsonSchemaType.NUMBER def test_should_generate_schema_for_nested_objects(self): + """Test schema generation for nested objects.""" nested_obj = { "user": { "profile": { @@ -113,38 +120,40 @@ def test_should_generate_schema_for_nested_objects(self): } schema = JsonSchemaHelper.generate_schema(nested_obj) - self.assertEqual(schema.type, JsonSchemaType.OBJECT) - self.assertIn("user", schema.properties) + assert schema.type == JsonSchemaType.OBJECT + assert "user" in schema.properties user_schema = schema.properties["user"] - self.assertEqual(user_schema.type, JsonSchemaType.OBJECT) - self.assertIn("profile", user_schema.properties) + assert user_schema.type == JsonSchemaType.OBJECT + assert "profile" in user_schema.properties profile_schema = user_schema.properties["profile"] - self.assertEqual(profile_schema.type, JsonSchemaType.OBJECT) - self.assertIn("name", profile_schema.properties) - self.assertEqual(profile_schema.properties["name"].type, JsonSchemaType.STRING) + assert profile_schema.type == JsonSchemaType.OBJECT + assert "name" in profile_schema.properties + assert profile_schema.properties["name"].type == JsonSchemaType.STRING def 
test_should_generate_schema_for_empty_set(self): + """Test schema generation for empty set.""" schema = JsonSchemaHelper.generate_schema(set()) - self.assertEqual(schema.type, JsonSchemaType.UNORDERED_LIST) - self.assertEqual(schema.properties, {}) + assert schema.type == JsonSchemaType.UNORDERED_LIST + assert schema.properties == {} def test_should_generate_schema_for_number_set(self): + """Test schema generation for number set.""" schema = JsonSchemaHelper.generate_schema({1, 2, 3}) - self.assertEqual(schema.type, JsonSchemaType.UNORDERED_LIST) - self.assertIsNotNone(schema.items) + assert schema.type == JsonSchemaType.UNORDERED_LIST assert schema.items is not None - self.assertEqual(schema.items.type, JsonSchemaType.NUMBER) + assert schema.items.type == JsonSchemaType.NUMBER def test_should_generate_schema_for_string_set(self): + """Test schema generation for string set.""" schema = JsonSchemaHelper.generate_schema({"a", "b"}) - self.assertEqual(schema.type, JsonSchemaType.UNORDERED_LIST) - self.assertIsNotNone(schema.items) + assert schema.type == JsonSchemaType.UNORDERED_LIST assert schema.items is not None - self.assertEqual(schema.items.type, JsonSchemaType.STRING) + assert schema.items.type == JsonSchemaType.STRING def test_should_apply_schema_merges(self): + """Test applying schema merges.""" data = { "body": "eyJuYW1lIjoiSm9obiJ9", # base64 encoded JSON "header": "regular string", @@ -158,21 +167,22 @@ def test_should_apply_schema_merges(self): schema = JsonSchemaHelper.generate_schema(data, merges) - self.assertEqual(schema.type, JsonSchemaType.OBJECT) - self.assertIn("body", schema.properties) - self.assertIn("header", schema.properties) + assert schema.type == JsonSchemaType.OBJECT + assert "body" in schema.properties + assert "header" in schema.properties body_schema = schema.properties["body"] - self.assertEqual(body_schema.type, JsonSchemaType.STRING) - self.assertEqual(body_schema.encoding, EncodingType.BASE64) - 
self.assertEqual(body_schema.decoded_type, DecodedType.JSON) + assert body_schema.type == JsonSchemaType.STRING + assert body_schema.encoding == EncodingType.BASE64 + assert body_schema.decoded_type == DecodedType.JSON header_schema = schema.properties["header"] - self.assertEqual(header_schema.type, JsonSchemaType.STRING) - self.assertIsNone(header_schema.encoding) - self.assertIsNone(header_schema.decoded_type) + assert header_schema.type == JsonSchemaType.STRING + assert header_schema.encoding is None + assert header_schema.decoded_type is None def test_should_not_apply_schema_merges_to_nested_properties_with_same_name(self): + """Test that schema merges don't apply to nested properties with same name.""" data = { "body": { "title": "Example Post", @@ -188,23 +198,22 @@ def test_should_not_apply_schema_merges_to_nested_properties_with_same_name(self schema = JsonSchemaHelper.generate_schema(data, merges) - # Top-level body should have the merge applied body_schema = schema.properties["body"] - self.assertEqual(body_schema.type, JsonSchemaType.OBJECT) - self.assertEqual(body_schema.encoding, EncodingType.BASE64) - self.assertEqual(body_schema.decoded_type, DecodedType.JSON) + assert body_schema.type == JsonSchemaType.OBJECT + assert body_schema.encoding == EncodingType.BASE64 + assert body_schema.decoded_type == DecodedType.JSON - # Nested body should NOT have merge applied nested_body_schema = body_schema.properties["body"] - self.assertEqual(nested_body_schema.type, JsonSchemaType.STRING) - self.assertIsNone(nested_body_schema.encoding) - self.assertIsNone(nested_body_schema.decoded_type) + assert nested_body_schema.type == JsonSchemaType.STRING + assert nested_body_schema.encoding is None + assert nested_body_schema.decoded_type is None -class TestSortObjectKeysRecursively(unittest.TestCase): +class TestSortObjectKeysRecursively: """Tests for JsonSchemaHelper._sort_object_keys.""" def test_should_sort_object_keys_recursively(self): + """Test recursive sorting 
of object keys.""" input_data = { "z": 1, "a": { @@ -224,15 +233,17 @@ def test_should_sort_object_keys_recursively(self): } result = JsonSchemaHelper._sort_object_keys(input_data) - self.assertEqual(result, expected) + assert result == expected def test_should_handle_primitive_values(self): - self.assertIsNone(JsonSchemaHelper._sort_object_keys(None)) - self.assertEqual(JsonSchemaHelper._sort_object_keys("string"), "string") - self.assertEqual(JsonSchemaHelper._sort_object_keys(42), 42) - self.assertEqual(JsonSchemaHelper._sort_object_keys(True), True) + """Test handling of primitive values.""" + assert JsonSchemaHelper._sort_object_keys(None) is None + assert JsonSchemaHelper._sort_object_keys("string") == "string" + assert JsonSchemaHelper._sort_object_keys(42) == 42 + assert JsonSchemaHelper._sort_object_keys(True) is True def test_should_handle_arrays_with_objects(self): + """Test handling of arrays with objects.""" input_data = [ {"c": 1, "a": 2}, {"z": 3, "b": 4}, @@ -244,33 +255,36 @@ def test_should_handle_arrays_with_objects(self): ] result = JsonSchemaHelper._sort_object_keys(input_data) - self.assertEqual(result, expected) + assert result == expected -class TestGenerateDeterministicHash(unittest.TestCase): +class TestGenerateDeterministicHash: """Tests for JsonSchemaHelper.generate_deterministic_hash.""" def test_should_generate_consistent_hashes_for_same_data(self): + """Test consistent hash generation for same data.""" data1 = {"b": 2, "a": 1} data2 = {"a": 1, "b": 2} hash1 = JsonSchemaHelper.generate_deterministic_hash(data1) hash2 = JsonSchemaHelper.generate_deterministic_hash(data2) - self.assertEqual(hash1, hash2) - self.assertIsInstance(hash1, str) - self.assertEqual(len(hash1), 64) # SHA256 hex length + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 # SHA256 hex length def test_should_generate_different_hashes_for_different_data(self): + """Test different hash generation for different data.""" data1 = {"a": 1, 
"b": 2} data2 = {"a": 1, "b": 3} hash1 = JsonSchemaHelper.generate_deterministic_hash(data1) hash2 = JsonSchemaHelper.generate_deterministic_hash(data2) - self.assertNotEqual(hash1, hash2) + assert hash1 != hash2 def test_should_handle_complex_nested_structures(self): + """Test handling of complex nested structures.""" data = { "users": [ {"name": "John", "age": 30}, @@ -283,33 +297,34 @@ def test_should_handle_complex_nested_structures(self): } hash_result = JsonSchemaHelper.generate_deterministic_hash(data) - self.assertIsInstance(hash_result, str) - self.assertEqual(len(hash_result), 64) + assert isinstance(hash_result, str) + assert len(hash_result) == 64 - # Same data should produce same hash hash2 = JsonSchemaHelper.generate_deterministic_hash(data) - self.assertEqual(hash_result, hash2) + assert hash_result == hash2 -class TestGenerateSchemaAndHash(unittest.TestCase): +class TestGenerateSchemaAndHash: """Tests for JsonSchemaHelper.generate_schema_and_hash.""" def test_should_generate_schema_and_hashes_for_simple_data(self): + """Test schema and hash generation for simple data.""" data = {"name": "John", "age": 30} result = JsonSchemaHelper.generate_schema_and_hash(data) - self.assertEqual(result.schema.type, JsonSchemaType.OBJECT) - self.assertIn("name", result.schema.properties) - self.assertIn("age", result.schema.properties) - self.assertEqual(result.schema.properties["name"].type, JsonSchemaType.STRING) - self.assertEqual(result.schema.properties["age"].type, JsonSchemaType.NUMBER) + assert result.schema.type == JsonSchemaType.OBJECT + assert "name" in result.schema.properties + assert "age" in result.schema.properties + assert result.schema.properties["name"].type == JsonSchemaType.STRING + assert result.schema.properties["age"].type == JsonSchemaType.NUMBER - self.assertIsInstance(result.decoded_value_hash, str) - self.assertIsInstance(result.decoded_schema_hash, str) - self.assertEqual(len(result.decoded_value_hash), 64) - 
self.assertEqual(len(result.decoded_schema_hash), 64) + assert isinstance(result.decoded_value_hash, str) + assert isinstance(result.decoded_schema_hash, str) + assert len(result.decoded_value_hash) == 64 + assert len(result.decoded_schema_hash) == 64 def test_should_handle_schema_merges_with_base64_encoding(self): + """Test schema merges with base64 encoding.""" json_data = {"message": "Hello World"} base64_data = base64.b64encode(json.dumps(json_data).encode()).decode() @@ -327,18 +342,18 @@ def test_should_handle_schema_merges_with_base64_encoding(self): result = JsonSchemaHelper.generate_schema_and_hash(data, schema_merges) - # The decoded body should now be an object schema body_schema = result.schema.properties["body"] - self.assertEqual(body_schema.type, JsonSchemaType.OBJECT) - self.assertIn("message", body_schema.properties) - self.assertEqual(body_schema.properties["message"].type, JsonSchemaType.STRING) - self.assertEqual(body_schema.encoding, EncodingType.BASE64) - self.assertEqual(body_schema.decoded_type, DecodedType.JSON) + assert body_schema.type == JsonSchemaType.OBJECT + assert "message" in body_schema.properties + assert body_schema.properties["message"].type == JsonSchemaType.STRING + assert body_schema.encoding == EncodingType.BASE64 + assert body_schema.decoded_type == DecodedType.JSON - self.assertIsInstance(result.decoded_value_hash, str) - self.assertIsInstance(result.decoded_schema_hash, str) + assert isinstance(result.decoded_value_hash, str) + assert isinstance(result.decoded_schema_hash, str) def test_should_handle_empty_objects_and_arrays(self): + """Test handling of empty objects and arrays.""" data = { "emptyObj": {}, "emptyArr": [], @@ -347,22 +362,22 @@ def test_should_handle_empty_objects_and_arrays(self): result = JsonSchemaHelper.generate_schema_and_hash(data) - self.assertEqual(result.schema.type, JsonSchemaType.OBJECT) + assert result.schema.type == JsonSchemaType.OBJECT empty_obj_schema = result.schema.properties["emptyObj"] 
- self.assertEqual(empty_obj_schema.type, JsonSchemaType.OBJECT) - self.assertEqual(empty_obj_schema.properties, {}) + assert empty_obj_schema.type == JsonSchemaType.OBJECT + assert empty_obj_schema.properties == {} empty_arr_schema = result.schema.properties["emptyArr"] - self.assertEqual(empty_arr_schema.type, JsonSchemaType.ORDERED_LIST) + assert empty_arr_schema.type == JsonSchemaType.ORDERED_LIST items_schema = result.schema.properties["items"] - self.assertEqual(items_schema.type, JsonSchemaType.ORDERED_LIST) - self.assertIsNotNone(items_schema.items) + assert items_schema.type == JsonSchemaType.ORDERED_LIST assert items_schema.items is not None - self.assertEqual(items_schema.items.type, JsonSchemaType.NUMBER) + assert items_schema.items.type == JsonSchemaType.NUMBER def test_should_handle_decoding_errors_gracefully(self): + """Test graceful handling of decoding errors.""" data = { "body": "invalid-base64!!!", "other": "valid", @@ -375,65 +390,67 @@ def test_should_handle_decoding_errors_gracefully(self): ), } - # Should not raise, should handle gracefully result = JsonSchemaHelper.generate_schema_and_hash(data, schema_merges) - # Body should remain a string since decode failed body_schema = result.schema.properties["body"] - self.assertEqual(body_schema.type, JsonSchemaType.STRING) - self.assertEqual(body_schema.encoding, EncodingType.BASE64) - self.assertEqual(body_schema.decoded_type, DecodedType.JSON) + assert body_schema.type == JsonSchemaType.STRING + assert body_schema.encoding == EncodingType.BASE64 + assert body_schema.decoded_type == DecodedType.JSON -class TestEncodingAndDecodedTypeEnums(unittest.TestCase): +class TestEncodingAndDecodedTypeEnums: """Tests for EncodingType and DecodedType enums.""" def test_should_have_correct_encoding_type_values(self): - self.assertEqual(EncodingType.UNSPECIFIED.value, 0) - self.assertEqual(EncodingType.BASE64.value, 1) + """Test correct EncodingType values.""" + assert EncodingType.UNSPECIFIED.value == 0 + 
assert EncodingType.BASE64.value == 1 def test_should_have_correct_decoded_type_values(self): - self.assertEqual(DecodedType.UNSPECIFIED.value, 0) - self.assertEqual(DecodedType.JSON.value, 1) - self.assertEqual(DecodedType.HTML.value, 2) - self.assertEqual(DecodedType.CSS.value, 3) - self.assertEqual(DecodedType.JAVASCRIPT.value, 4) - self.assertEqual(DecodedType.XML.value, 5) - self.assertEqual(DecodedType.YAML.value, 6) - self.assertEqual(DecodedType.MARKDOWN.value, 7) - self.assertEqual(DecodedType.CSV.value, 8) - self.assertEqual(DecodedType.SQL.value, 9) - self.assertEqual(DecodedType.GRAPHQL.value, 10) - self.assertEqual(DecodedType.PLAIN_TEXT.value, 11) - self.assertEqual(DecodedType.FORM_DATA.value, 12) - self.assertEqual(DecodedType.MULTIPART_FORM.value, 13) - self.assertEqual(DecodedType.PDF.value, 14) - self.assertEqual(DecodedType.AUDIO.value, 15) - self.assertEqual(DecodedType.VIDEO.value, 16) - self.assertEqual(DecodedType.GZIP.value, 17) - self.assertEqual(DecodedType.BINARY.value, 18) - self.assertEqual(DecodedType.JPEG.value, 19) - self.assertEqual(DecodedType.PNG.value, 20) - self.assertEqual(DecodedType.GIF.value, 21) - self.assertEqual(DecodedType.WEBP.value, 22) - self.assertEqual(DecodedType.SVG.value, 23) - self.assertEqual(DecodedType.ZIP.value, 24) - - -class TestJsonSchemaToPrimitive(unittest.TestCase): + """Test correct DecodedType values.""" + assert DecodedType.UNSPECIFIED.value == 0 + assert DecodedType.JSON.value == 1 + assert DecodedType.HTML.value == 2 + assert DecodedType.CSS.value == 3 + assert DecodedType.JAVASCRIPT.value == 4 + assert DecodedType.XML.value == 5 + assert DecodedType.YAML.value == 6 + assert DecodedType.MARKDOWN.value == 7 + assert DecodedType.CSV.value == 8 + assert DecodedType.SQL.value == 9 + assert DecodedType.GRAPHQL.value == 10 + assert DecodedType.PLAIN_TEXT.value == 11 + assert DecodedType.FORM_DATA.value == 12 + assert DecodedType.MULTIPART_FORM.value == 13 + assert DecodedType.PDF.value == 14 + assert 
DecodedType.AUDIO.value == 15 + assert DecodedType.VIDEO.value == 16 + assert DecodedType.GZIP.value == 17 + assert DecodedType.BINARY.value == 18 + assert DecodedType.JPEG.value == 19 + assert DecodedType.PNG.value == 20 + assert DecodedType.GIF.value == 21 + assert DecodedType.WEBP.value == 22 + assert DecodedType.SVG.value == 23 + assert DecodedType.ZIP.value == 24 + + +class TestJsonSchemaToPrimitive: """Tests for JsonSchema.to_primitive conversion.""" def test_should_convert_simple_schema_to_primitive(self): + """Test simple schema conversion to primitive.""" schema = JsonSchema(type=JsonSchemaType.STRING) primitive = schema.to_primitive() - self.assertEqual(primitive["type"], JsonSchemaType.STRING.value) - self.assertEqual(primitive["properties"], {}) - self.assertNotIn("items", primitive) - self.assertNotIn("encoding", primitive) - self.assertNotIn("decoded_type", primitive) + assert primitive["type"] == JsonSchemaType.STRING.value + assert primitive["properties"] == {} + assert "items" not in primitive + assert "encoding" not in primitive + assert "decoded_type" not in primitive def test_should_convert_complex_schema_to_primitive(self): + """Test complex schema conversion to primitive.""" schema = JsonSchema( type=JsonSchemaType.OBJECT, properties={ @@ -443,13 +460,14 @@ def test_should_convert_complex_schema_to_primitive(self): ) primitive = schema.to_primitive() - self.assertEqual(primitive["type"], JsonSchemaType.OBJECT.value) - self.assertIn("name", primitive["properties"]) - self.assertIn("age", primitive["properties"]) - self.assertEqual(primitive["properties"]["name"]["type"], JsonSchemaType.STRING.value) - self.assertEqual(primitive["properties"]["age"]["type"], JsonSchemaType.NUMBER.value) + assert primitive["type"] == JsonSchemaType.OBJECT.value + assert "name" in primitive["properties"] + assert "age" in primitive["properties"] + assert primitive["properties"]["name"]["type"] == JsonSchemaType.STRING.value + assert 
primitive["properties"]["age"]["type"] == JsonSchemaType.NUMBER.value def test_should_include_encoding_and_decoded_type_when_set(self): + """Test inclusion of encoding and decoded_type when set.""" schema = JsonSchema( type=JsonSchemaType.STRING, encoding=EncodingType.BASE64, @@ -457,20 +475,17 @@ def test_should_include_encoding_and_decoded_type_when_set(self): ) primitive = schema.to_primitive() - self.assertEqual(primitive["encoding"], EncodingType.BASE64.value) - self.assertEqual(primitive["decoded_type"], DecodedType.JSON.value) + assert primitive["encoding"] == EncodingType.BASE64.value + assert primitive["decoded_type"] == DecodedType.JSON.value def test_should_include_items_for_arrays(self): + """Test inclusion of items for arrays.""" schema = JsonSchema( type=JsonSchemaType.ORDERED_LIST, items=JsonSchema(type=JsonSchemaType.NUMBER), ) primitive = schema.to_primitive() - self.assertEqual(primitive["type"], JsonSchemaType.ORDERED_LIST.value) - self.assertIn("items", primitive) - self.assertEqual(primitive["items"]["type"], JsonSchemaType.NUMBER.value) - - -if __name__ == "__main__": - unittest.main() + assert primitive["type"] == JsonSchemaType.ORDERED_LIST.value + assert "items" in primitive + assert primitive["items"]["type"] == JsonSchemaType.NUMBER.value diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index db203d0..3192f54 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -2,8 +2,8 @@ import threading import time -import unittest -from unittest.mock import patch + +import pytest from drift.core.metrics import ( ExportMetrics, @@ -15,24 +15,24 @@ ) -class TestExportMetrics(unittest.TestCase): +class TestExportMetrics: """Tests for ExportMetrics dataclass.""" def test_default_values(self): """Test default values.""" metrics = ExportMetrics() - self.assertEqual(metrics.spans_exported, 0) - self.assertEqual(metrics.spans_dropped, 0) - self.assertEqual(metrics.spans_failed, 0) - 
self.assertEqual(metrics.batches_exported, 0) - self.assertEqual(metrics.batches_failed, 0) - self.assertEqual(metrics.bytes_sent, 0) - self.assertEqual(metrics.bytes_compressed_saved, 0) + assert metrics.spans_exported == 0 + assert metrics.spans_dropped == 0 + assert metrics.spans_failed == 0 + assert metrics.batches_exported == 0 + assert metrics.batches_failed == 0 + assert metrics.bytes_sent == 0 + assert metrics.bytes_compressed_saved == 0 def test_average_latency_when_empty(self): """Test average latency is 0 when no exports.""" metrics = ExportMetrics() - self.assertEqual(metrics.average_export_latency_ms, 0.0) + assert metrics.average_export_latency_ms == 0.0 def test_average_latency_calculation(self): """Test average latency calculation.""" @@ -40,152 +40,149 @@ def test_average_latency_calculation(self): export_latency_sum_ms=300.0, export_count=3, ) - self.assertEqual(metrics.average_export_latency_ms, 100.0) + assert metrics.average_export_latency_ms == 100.0 -class TestQueueMetrics(unittest.TestCase): +class TestQueueMetrics: """Tests for QueueMetrics dataclass.""" def test_default_values(self): """Test default values.""" metrics = QueueMetrics() - self.assertEqual(metrics.current_size, 0) - self.assertEqual(metrics.max_size, 0) - self.assertEqual(metrics.peak_size, 0) + assert metrics.current_size == 0 + assert metrics.max_size == 0 + assert metrics.peak_size == 0 -class TestMetricsCollector(unittest.TestCase): +class TestMetricsCollector: """Tests for MetricsCollector.""" - def setUp(self): + @pytest.fixture + def collector(self): """Create fresh collector for each test.""" - self.collector = MetricsCollector() + return MetricsCollector() - def test_record_spans_exported(self): + def test_record_spans_exported(self, collector): """Test recording exported spans.""" - self.collector.record_spans_exported(10) - self.collector.record_spans_exported(5) + collector.record_spans_exported(10) + collector.record_spans_exported(5) - metrics = 
self.collector.get_metrics() - self.assertEqual(metrics.export.spans_exported, 15) - self.assertEqual(metrics.export.batches_exported, 2) + metrics = collector.get_metrics() + assert metrics.export.spans_exported == 15 + assert metrics.export.batches_exported == 2 - def test_record_spans_dropped(self): + def test_record_spans_dropped(self, collector): """Test recording dropped spans.""" - self.collector.record_spans_dropped() - self.collector.record_spans_dropped(5) + collector.record_spans_dropped() + collector.record_spans_dropped(5) - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.spans_dropped, 6) + metrics = collector.get_metrics() + assert metrics.export.spans_dropped == 6 - def test_record_spans_failed(self): + def test_record_spans_failed(self, collector): """Test recording failed spans.""" - self.collector.record_spans_failed(10) + collector.record_spans_failed(10) - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.spans_failed, 10) - self.assertEqual(metrics.export.batches_failed, 1) + metrics = collector.get_metrics() + assert metrics.export.spans_failed == 10 + assert metrics.export.batches_failed == 1 - def test_record_export_latency(self): + def test_record_export_latency(self, collector): """Test recording export latency.""" - self.collector.record_export_latency(100.0) - self.collector.record_export_latency(200.0) + collector.record_export_latency(100.0) + collector.record_export_latency(200.0) - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.export_latency_sum_ms, 300.0) - self.assertEqual(metrics.export.export_count, 2) - self.assertEqual(metrics.export.average_export_latency_ms, 150.0) + metrics = collector.get_metrics() + assert metrics.export.export_latency_sum_ms == 300.0 + assert metrics.export.export_count == 2 + assert metrics.export.average_export_latency_ms == 150.0 - def test_record_bytes_sent(self): + def test_record_bytes_sent(self, collector): """Test 
recording bytes sent.""" - self.collector.record_bytes_sent(1000, 200) - self.collector.record_bytes_sent(500, 100) + collector.record_bytes_sent(1000, 200) + collector.record_bytes_sent(500, 100) - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.bytes_sent, 1500) - self.assertEqual(metrics.export.bytes_compressed_saved, 300) + metrics = collector.get_metrics() + assert metrics.export.bytes_sent == 1500 + assert metrics.export.bytes_compressed_saved == 300 - def test_update_queue_size(self): + def test_update_queue_size(self, collector): """Test updating queue size.""" - self.collector.set_queue_max_size(1000) - self.collector.update_queue_size(50) - self.collector.update_queue_size(100) - self.collector.update_queue_size(75) + collector.set_queue_max_size(1000) + collector.update_queue_size(50) + collector.update_queue_size(100) + collector.update_queue_size(75) - metrics = self.collector.get_metrics() - self.assertEqual(metrics.queue.current_size, 75) - self.assertEqual(metrics.queue.peak_size, 100) - self.assertEqual(metrics.queue.max_size, 1000) + metrics = collector.get_metrics() + assert metrics.queue.current_size == 75 + assert metrics.queue.peak_size == 100 + assert metrics.queue.max_size == 1000 - def test_instrumentation_tracking(self): + def test_instrumentation_tracking(self, collector): """Test instrumentation activation tracking.""" - self.collector.record_instrumentation_activated() - self.collector.record_instrumentation_activated() - self.collector.record_instrumentation_deactivated() + collector.record_instrumentation_activated() + collector.record_instrumentation_activated() + collector.record_instrumentation_deactivated() - metrics = self.collector.get_metrics() - self.assertEqual(metrics.instrumentations_active, 1) + metrics = collector.get_metrics() + assert metrics.instrumentations_active == 1 - def test_instrumentation_deactivate_floor(self): + def test_instrumentation_deactivate_floor(self, collector): """Test that 
instrumentation count doesn't go below 0.""" - self.collector.record_instrumentation_deactivated() + collector.record_instrumentation_deactivated() - metrics = self.collector.get_metrics() - self.assertEqual(metrics.instrumentations_active, 0) + metrics = collector.get_metrics() + assert metrics.instrumentations_active == 0 - def test_uptime_tracking(self): + def test_uptime_tracking(self, collector): """Test uptime tracking.""" time.sleep(0.1) - metrics = self.collector.get_metrics() - self.assertGreaterEqual(metrics.uptime_seconds, 0.1) + metrics = collector.get_metrics() + assert metrics.uptime_seconds >= 0.1 - def test_reset(self): + def test_reset(self, collector): """Test resetting metrics.""" - self.collector.record_spans_exported(100) - self.collector.record_spans_dropped(10) - self.collector.update_queue_size(50) + collector.record_spans_exported(100) + collector.record_spans_dropped(10) + collector.update_queue_size(50) - self.collector.reset() + collector.reset() - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.spans_exported, 0) - self.assertEqual(metrics.export.spans_dropped, 0) - self.assertEqual(metrics.queue.peak_size, 0) - # Current queue size isn't reset (it's external state) + metrics = collector.get_metrics() + assert metrics.export.spans_exported == 0 + assert metrics.export.spans_dropped == 0 + assert metrics.queue.peak_size == 0 - def test_reset_clears_warning_flags(self): + def test_reset_clears_warning_flags(self, collector, mocker): """Test that reset clears warning flags so warnings can fire again.""" - self.collector.set_queue_max_size(100) - - with patch("drift.core.metrics.logger") as mock_logger: - # Trigger warning (85% capacity) - self.collector.update_queue_size(85) - self.assertTrue(self.collector._warned_queue_capacity) - mock_logger.warning.assert_called_once() - - # Reset should clear the flag - self.collector.reset() - self.assertFalse(self.collector._warned_queue_capacity) - 
self.assertFalse(self.collector._warned_high_drop_rate) - self.assertFalse(self.collector._warned_high_failure_rate) - self.assertFalse(self.collector._warned_circuit_open) - - # Warning should fire again after reset - mock_logger.reset_mock() - self.collector.set_queue_max_size(100) - self.collector.update_queue_size(85) - mock_logger.warning.assert_called_once() - - def test_thread_safety(self): + collector.set_queue_max_size(100) + + mock_logger = mocker.patch("drift.core.metrics.logger") + collector.update_queue_size(85) + assert collector._warned_queue_capacity is True + mock_logger.warning.assert_called_once() + + collector.reset() + assert collector._warned_queue_capacity is False + assert collector._warned_high_drop_rate is False + assert collector._warned_high_failure_rate is False + assert collector._warned_circuit_open is False + + mock_logger.reset_mock() + collector.set_queue_max_size(100) + collector.update_queue_size(85) + mock_logger.warning.assert_called_once() + + def test_thread_safety(self, collector): """Test that metrics collection is thread-safe.""" iterations = 1000 threads = [] def record_metrics(): for _ in range(iterations): - self.collector.record_spans_exported(1) - self.collector.record_spans_dropped() - self.collector.update_queue_size(10) + collector.record_spans_exported(1) + collector.record_spans_dropped() + collector.update_queue_size(10) for _ in range(5): t = threading.Thread(target=record_metrics) @@ -196,155 +193,133 @@ def record_metrics(): for t in threads: t.join() - metrics = self.collector.get_metrics() - self.assertEqual(metrics.export.spans_exported, 5 * iterations) - self.assertEqual(metrics.export.spans_dropped, 5 * iterations) + metrics = collector.get_metrics() + assert metrics.export.spans_exported == 5 * iterations + assert metrics.export.spans_dropped == 5 * iterations -class TestGlobalMetricsCollector(unittest.TestCase): +class TestGlobalMetricsCollector: """Tests for global metrics functions.""" def 
test_get_metrics_collector_returns_same_instance(self): """Test that get_metrics_collector returns the same instance.""" collector1 = get_metrics_collector() collector2 = get_metrics_collector() - self.assertIs(collector1, collector2) + assert collector1 is collector2 def test_get_sdk_metrics_returns_snapshot(self): """Test that get_sdk_metrics returns a metrics snapshot.""" metrics = get_sdk_metrics() - self.assertIsInstance(metrics, SDKMetrics) + assert isinstance(metrics, SDKMetrics) -class TestEventDrivenWarnings(unittest.TestCase): +class TestEventDrivenWarnings: """Tests for event-driven warning logs.""" - def setUp(self): + @pytest.fixture + def collector(self): """Create fresh collector for each test.""" - self.collector = MetricsCollector() + return MetricsCollector() - def test_high_drop_rate_warning(self): + def test_high_drop_rate_warning(self, collector, mocker): """Test that high drop rate triggers a warning.""" - # Need at least 100 samples for warning to trigger - with patch("drift.core.metrics.logger") as mock_logger: - # 90 exported, 10 dropped = 10% drop rate (above 5% threshold) - for _ in range(90): - self.collector.record_spans_exported(1) - for _ in range(10): - self.collector.record_spans_dropped(1) - - # Should have warned about high drop rate - mock_logger.warning.assert_called() - call_args = mock_logger.warning.call_args[0][0] - self.assertIn("drop rate", call_args.lower()) - - def test_drop_rate_warning_not_spam(self): + mock_logger = mocker.patch("drift.core.metrics.logger") + for _ in range(90): + collector.record_spans_exported(1) + for _ in range(10): + collector.record_spans_dropped(1) + + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args[0][0] + assert "drop rate" in call_args.lower() + + def test_drop_rate_warning_not_spam(self, collector, mocker): """Test that drop rate warning is not repeated.""" - with patch("drift.core.metrics.logger") as mock_logger: - # First batch: 90 exported, 10 dropped = 
10% drop rate - for _ in range(90): - self.collector.record_spans_exported(1) - for _ in range(10): - self.collector.record_spans_dropped(1) - - # More drops should not spam warnings - for _ in range(10): - self.collector.record_spans_dropped(1) - - # Should only have warned once - warning_calls = [c for c in mock_logger.warning.call_args_list if "drop rate" in str(c).lower()] - self.assertEqual(len(warning_calls), 1) - - def test_high_failure_rate_warning(self): + mock_logger = mocker.patch("drift.core.metrics.logger") + for _ in range(90): + collector.record_spans_exported(1) + for _ in range(10): + collector.record_spans_dropped(1) + + for _ in range(10): + collector.record_spans_dropped(1) + + warning_calls = [c for c in mock_logger.warning.call_args_list if "drop rate" in str(c).lower()] + assert len(warning_calls) == 1 + + def test_high_failure_rate_warning(self, collector, mocker): """Test that high failure rate triggers a warning.""" - with patch("drift.core.metrics.logger") as mock_logger: - # 85 exported, 15 failed = ~15% failure rate (above 10% threshold) - for _ in range(85): - self.collector.record_spans_exported(1) - self.collector.record_spans_failed(15) - - # Should have warned about high failure rate - mock_logger.warning.assert_called() - call_args = mock_logger.warning.call_args[0][0] - self.assertIn("failure rate", call_args.lower()) - - def test_queue_capacity_warning(self): + mock_logger = mocker.patch("drift.core.metrics.logger") + for _ in range(85): + collector.record_spans_exported(1) + collector.record_spans_failed(15) + + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args[0][0] + assert "failure rate" in call_args.lower() + + def test_queue_capacity_warning(self, collector, mocker): """Test that high queue capacity triggers a warning.""" - self.collector.set_queue_max_size(100) + collector.set_queue_max_size(100) - with patch("drift.core.metrics.logger") as mock_logger: - # 85% capacity (above 80% threshold) 
- self.collector.update_queue_size(85) + mock_logger = mocker.patch("drift.core.metrics.logger") + collector.update_queue_size(85) - mock_logger.warning.assert_called() - call_args = mock_logger.warning.call_args[0][0] - self.assertIn("queue", call_args.lower()) - self.assertIn("capacity", call_args.lower()) + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args[0][0] + assert "queue" in call_args.lower() + assert "capacity" in call_args.lower() - def test_queue_capacity_warning_clears(self): + def test_queue_capacity_warning_clears(self, collector, mocker): """Test that queue capacity warning clears when capacity decreases.""" - self.collector.set_queue_max_size(100) + collector.set_queue_max_size(100) - with patch("drift.core.metrics.logger"): - # Trigger warning at 85% - self.collector.update_queue_size(85) - self.assertTrue(self.collector._warned_queue_capacity) + mocker.patch("drift.core.metrics.logger") + collector.update_queue_size(85) + assert collector._warned_queue_capacity is True - # Drop below threshold - self.collector.update_queue_size(50) - self.assertFalse(self.collector._warned_queue_capacity) + collector.update_queue_size(50) + assert collector._warned_queue_capacity is False - def test_circuit_breaker_warning(self): + def test_circuit_breaker_warning(self, collector, mocker): """Test circuit breaker open warning.""" - with patch("drift.core.metrics.logger") as mock_logger: - self.collector.warn_circuit_breaker_open() + mock_logger = mocker.patch("drift.core.metrics.logger") + collector.warn_circuit_breaker_open() - mock_logger.warning.assert_called() - call_args = mock_logger.warning.call_args[0][0] - self.assertIn("circuit breaker", call_args.lower()) - self.assertIn("open", call_args.lower()) + mock_logger.warning.assert_called() + call_args = mock_logger.warning.call_args[0][0] + assert "circuit breaker" in call_args.lower() + assert "open" in call_args.lower() - def test_circuit_breaker_closed_notification(self): + 
def test_circuit_breaker_closed_notification(self, collector, mocker): """Test circuit breaker closed notification.""" - with patch("drift.core.metrics.logger") as mock_logger: - # First open - self.collector.warn_circuit_breaker_open() - # Then close - self.collector.notify_circuit_breaker_closed() + mock_logger = mocker.patch("drift.core.metrics.logger") + collector.warn_circuit_breaker_open() + collector.notify_circuit_breaker_closed() - # Should have logged info about closing - info_calls = mock_logger.info.call_args_list - self.assertTrue(any("closed" in str(c).lower() for c in info_calls)) + info_calls = mock_logger.info.call_args_list + assert any("closed" in str(c).lower() for c in info_calls) - def test_no_warning_below_threshold(self): + def test_no_warning_below_threshold(self, collector, mocker): """Test that no warnings are logged below thresholds.""" - self.collector.set_queue_max_size(100) + collector.set_queue_max_size(100) - with patch("drift.core.metrics.logger") as mock_logger: - # 95 exported, 5 dropped = 5% drop rate (at threshold, not above) - for _ in range(95): - self.collector.record_spans_exported(1) - for _ in range(5): - self.collector.record_spans_dropped(1) + mock_logger = mocker.patch("drift.core.metrics.logger") + for _ in range(95): + collector.record_spans_exported(1) + for _ in range(5): + collector.record_spans_dropped(1) - # Queue at 50% (below 80% threshold) - self.collector.update_queue_size(50) + collector.update_queue_size(50) - # No warnings should be logged - mock_logger.warning.assert_not_called() + mock_logger.warning.assert_not_called() - def test_no_warning_without_enough_samples(self): + def test_no_warning_without_enough_samples(self, collector, mocker): """Test that no drop rate warning without enough samples.""" - with patch("drift.core.metrics.logger") as mock_logger: - # Only 50 samples (below 100 minimum) - for _ in range(45): - self.collector.record_spans_exported(1) - for _ in range(5): - 
self.collector.record_spans_dropped(1) - - # No warning due to insufficient samples - mock_logger.warning.assert_not_called() - + mock_logger = mocker.patch("drift.core.metrics.logger") + for _ in range(45): + collector.record_spans_exported(1) + for _ in range(5): + collector.record_spans_dropped(1) -if __name__ == "__main__": - unittest.main() + mock_logger.warning.assert_not_called() diff --git a/tests/unit/test_requests_instrumentation.py b/tests/unit/test_requests_instrumentation.py index fa62a06..862593c 100644 --- a/tests/unit/test_requests_instrumentation.py +++ b/tests/unit/test_requests_instrumentation.py @@ -2,7 +2,8 @@ import base64 import json -import unittest + +import pytest from drift.core.json_schema_helper import DecodedType from drift.instrumentation.requests.instrumentation import ( @@ -10,113 +11,110 @@ ) -class TestRequestsInstrumentationHelpers(unittest.TestCase): +class TestRequestsInstrumentationHelpers: """Test body encoding helper methods.""" - def setUp(self): - self.instrumentation = RequestsInstrumentation() + @pytest.fixture + def instrumentation(self): + """Create instrumentation instance for testing.""" + return RequestsInstrumentation() - def test_encode_body_to_base64_with_string(self): + def test_encode_body_to_base64_with_string(self, instrumentation): """Test encoding string body to base64.""" body = "test body" - encoded, size = self.instrumentation._encode_body_to_base64(body) + encoded, size = instrumentation._encode_body_to_base64(body) - self.assertIsNotNone(encoded) assert encoded is not None - self.assertEqual(size, len(body.encode("utf-8"))) - # Verify it's valid base64 + assert size == len(body.encode("utf-8")) decoded = base64.b64decode(encoded.encode("ascii")) - self.assertEqual(decoded.decode("utf-8"), body) + assert decoded.decode("utf-8") == body - def test_encode_body_to_base64_with_bytes(self): + def test_encode_body_to_base64_with_bytes(self, instrumentation): """Test encoding bytes body to base64.""" body = 
b"test bytes" - encoded, size = self.instrumentation._encode_body_to_base64(body) + encoded, size = instrumentation._encode_body_to_base64(body) - self.assertIsNotNone(encoded) assert encoded is not None - self.assertEqual(size, len(body)) + assert size == len(body) decoded = base64.b64decode(encoded.encode("ascii")) - self.assertEqual(decoded, body) + assert decoded == body - def test_encode_body_to_base64_with_json(self): + def test_encode_body_to_base64_with_json(self, instrumentation): """Test encoding JSON dict to base64.""" body = {"key": "value", "number": 123} - encoded, size = self.instrumentation._encode_body_to_base64(body) + encoded, size = instrumentation._encode_body_to_base64(body) - self.assertIsNotNone(encoded) assert encoded is not None json_str = json.dumps(body) - self.assertEqual(size, len(json_str.encode("utf-8"))) + assert size == len(json_str.encode("utf-8")) decoded = base64.b64decode(encoded.encode("ascii")) - self.assertEqual(json.loads(decoded.decode("utf-8")), body) + assert json.loads(decoded.decode("utf-8")) == body - def test_encode_body_to_base64_with_none(self): + def test_encode_body_to_base64_with_none(self, instrumentation): """Test encoding None returns None and size 0.""" - encoded, size = self.instrumentation._encode_body_to_base64(None) + encoded, size = instrumentation._encode_body_to_base64(None) - self.assertIsNone(encoded) - self.assertEqual(size, 0) + assert encoded is None + assert size == 0 - def test_get_decoded_type_from_content_type_json(self): + def test_get_decoded_type_from_content_type_json(self, instrumentation): """Test JSON content type detection.""" - decoded_type = self.instrumentation._get_decoded_type_from_content_type("application/json") - self.assertEqual(decoded_type, DecodedType.JSON) + decoded_type = instrumentation._get_decoded_type_from_content_type("application/json") + assert decoded_type == DecodedType.JSON - def test_get_decoded_type_from_content_type_with_charset(self): + def 
test_get_decoded_type_from_content_type_with_charset(self, instrumentation): """Test content type with charset parameter.""" - decoded_type = self.instrumentation._get_decoded_type_from_content_type("application/json; charset=utf-8") - self.assertEqual(decoded_type, DecodedType.JSON) + decoded_type = instrumentation._get_decoded_type_from_content_type("application/json; charset=utf-8") + assert decoded_type == DecodedType.JSON - def test_get_decoded_type_from_content_type_plain_text(self): + def test_get_decoded_type_from_content_type_plain_text(self, instrumentation): """Test plain text content type detection.""" - decoded_type = self.instrumentation._get_decoded_type_from_content_type("text/plain") - self.assertEqual(decoded_type, DecodedType.PLAIN_TEXT) + decoded_type = instrumentation._get_decoded_type_from_content_type("text/plain") + assert decoded_type == DecodedType.PLAIN_TEXT - def test_get_decoded_type_from_content_type_html(self): + def test_get_decoded_type_from_content_type_html(self, instrumentation): """Test HTML content type detection.""" - decoded_type = self.instrumentation._get_decoded_type_from_content_type("text/html") - self.assertEqual(decoded_type, DecodedType.HTML) + decoded_type = instrumentation._get_decoded_type_from_content_type("text/html") + assert decoded_type == DecodedType.HTML - def test_get_decoded_type_from_content_type_unknown(self): + def test_get_decoded_type_from_content_type_unknown(self, instrumentation): """Test unknown content type returns None.""" - decoded_type = self.instrumentation._get_decoded_type_from_content_type("application/unknown") - self.assertIsNone(decoded_type) + decoded_type = instrumentation._get_decoded_type_from_content_type("application/unknown") + assert decoded_type is None - def test_get_decoded_type_from_content_type_none(self): + def test_get_decoded_type_from_content_type_none(self, instrumentation): """Test None content type returns None.""" - decoded_type = 
self.instrumentation._get_decoded_type_from_content_type(None) - self.assertIsNone(decoded_type) + decoded_type = instrumentation._get_decoded_type_from_content_type(None) + assert decoded_type is None - def test_get_content_type_header_case_insensitive(self): + def test_get_content_type_header_case_insensitive(self, instrumentation): """Test case-insensitive content-type header lookup.""" headers = {"Content-Type": "application/json"} - content_type = self.instrumentation._get_content_type_header(headers) - self.assertEqual(content_type, "application/json") + content_type = instrumentation._get_content_type_header(headers) + assert content_type == "application/json" headers = {"content-type": "text/plain"} - content_type = self.instrumentation._get_content_type_header(headers) - self.assertEqual(content_type, "text/plain") + content_type = instrumentation._get_content_type_header(headers) + assert content_type == "text/plain" headers = {"CONTENT-TYPE": "text/html"} - content_type = self.instrumentation._get_content_type_header(headers) - self.assertEqual(content_type, "text/html") + content_type = instrumentation._get_content_type_header(headers) + assert content_type == "text/html" - def test_get_content_type_header_not_found(self): + def test_get_content_type_header_not_found(self, instrumentation): """Test missing content-type header returns None.""" headers = {"Accept": "application/json"} - content_type = self.instrumentation._get_content_type_header(headers) - self.assertIsNone(content_type) + content_type = instrumentation._get_content_type_header(headers) + assert content_type is None -class TestMockResponseDecoding(unittest.TestCase): +class TestMockResponseDecoding: """Test mock response body decoding.""" def test_create_mock_response_decodes_base64(self): """Test that base64-encoded body is properly decoded.""" instrumentation = RequestsInstrumentation() - # Create mock data with base64-encoded body original_body = '{"result": "success"}' encoded_body = 
base64.b64encode(original_body.encode("utf-8")).decode("ascii") @@ -129,46 +127,17 @@ def test_create_mock_response_decodes_base64(self): response = instrumentation._create_mock_response(mock_data, "http://example.com") - # Verify body is properly decoded - self.assertEqual(response.content, original_body.encode("utf-8")) - self.assertEqual(response.text, original_body) + assert response.content == original_body.encode("utf-8") + assert response.text == original_body def test_create_mock_response_fallback_to_plain_text(self): """Test fallback to plain text when base64 decode fails.""" instrumentation = RequestsInstrumentation() - # Create mock data with non-base64 plain text plain_text = "This is plain text, not base64" mock_data = {"statusCode": 200, "statusMessage": "OK", "headers": {}, "body": plain_text} response = instrumentation._create_mock_response(mock_data, "http://example.com") - # Verify body is treated as plain text - self.assertEqual(response.text, plain_text) - - -# NOTE: The following test categories were removed because they were testing -# internal implementation details with incorrect mocking patterns: -# -# - TestReplayTraceIDUsage: Tests that _try_get_mock uses replay trace ID. -# The implementation now uses find_mock_response_sync from mock_utils which -# handles this internally. E2E tests cover this functionality. -# -# - TestBodySizeInSpans: Tests bodySize in spans. This is done in _try_get_mock -# and is covered by E2E tests. -# -# - TestTransformEngineIntegration: Tests transform engine. Covered by E2E tests. -# -# - TestSchemaMergeHints: Tests schema merges. Implementation exists in -# _try_get_mock, covered by E2E tests. -# -# - TestMockRequestMetadata: Tests metadata in mock requests. Covered by E2E tests. -# -# - TestDropTransforms: Tests drop transforms. Covered by E2E tests. -# -# - TestTraceContextPropagation: Tests context propagation. Covered by E2E tests. 
- - -if __name__ == "__main__": - unittest.main() + assert response.text == plain_text diff --git a/tests/unit/test_resilience.py b/tests/unit/test_resilience.py index b634fbb..166a97b 100644 --- a/tests/unit/test_resilience.py +++ b/tests/unit/test_resilience.py @@ -2,7 +2,8 @@ import asyncio import time -import unittest + +import pytest from drift.core.resilience import ( CircuitBreaker, @@ -15,17 +16,17 @@ ) -class TestRetryConfig(unittest.TestCase): +class TestRetryConfig: """Tests for RetryConfig.""" def test_default_values(self): """Test default configuration values.""" config = RetryConfig() - self.assertEqual(config.max_attempts, 3) - self.assertEqual(config.initial_delay_seconds, 0.1) - self.assertEqual(config.max_delay_seconds, 10.0) - self.assertEqual(config.exponential_base, 2.0) - self.assertTrue(config.jitter) + assert config.max_attempts == 3 + assert config.initial_delay_seconds == 0.1 + assert config.max_delay_seconds == 10.0 + assert config.exponential_base == 2.0 + assert config.jitter is True def test_custom_values(self): """Test custom configuration values.""" @@ -36,21 +37,21 @@ def test_custom_values(self): exponential_base=3.0, jitter=False, ) - self.assertEqual(config.max_attempts, 5) - self.assertEqual(config.initial_delay_seconds, 0.5) - self.assertEqual(config.max_delay_seconds, 30.0) - self.assertEqual(config.exponential_base, 3.0) - self.assertFalse(config.jitter) + assert config.max_attempts == 5 + assert config.initial_delay_seconds == 0.5 + assert config.max_delay_seconds == 30.0 + assert config.exponential_base == 3.0 + assert config.jitter is False -class TestCalculateBackoffDelay(unittest.TestCase): +class TestCalculateBackoffDelay: """Tests for backoff delay calculation.""" def test_first_attempt_delay(self): """Test delay for first attempt.""" config = RetryConfig(initial_delay_seconds=0.1, jitter=False) delay = calculate_backoff_delay(1, config) - self.assertEqual(delay, 0.1) + assert delay == 0.1 def 
test_exponential_increase(self): """Test exponential backoff increase.""" @@ -59,12 +60,9 @@ def test_exponential_increase(self): exponential_base=2.0, jitter=False, ) - # Attempt 1: 0.1 - # Attempt 2: 0.1 * 2^1 = 0.2 - # Attempt 3: 0.1 * 2^2 = 0.4 - self.assertEqual(calculate_backoff_delay(1, config), 0.1) - self.assertEqual(calculate_backoff_delay(2, config), 0.2) - self.assertEqual(calculate_backoff_delay(3, config), 0.4) + assert calculate_backoff_delay(1, config) == 0.1 + assert calculate_backoff_delay(2, config) == 0.2 + assert calculate_backoff_delay(3, config) == 0.4 def test_respects_max_delay(self): """Test that delay is capped at max_delay_seconds.""" @@ -74,10 +72,8 @@ def test_respects_max_delay(self): exponential_base=10.0, jitter=False, ) - # Attempt 1: 1.0 - # Attempt 2: 1.0 * 10^1 = 10.0 -> capped to 5.0 - self.assertEqual(calculate_backoff_delay(1, config), 1.0) - self.assertEqual(calculate_backoff_delay(2, config), 5.0) + assert calculate_backoff_delay(1, config) == 1.0 + assert calculate_backoff_delay(2, config) == 5.0 def test_jitter_adds_randomness(self): """Test that jitter adds randomness to delay.""" @@ -86,11 +82,10 @@ def test_jitter_adds_randomness(self): jitter=True, ) delays = [calculate_backoff_delay(1, config) for _ in range(10)] - # With jitter, delays should vary (±25%) - self.assertTrue(0.75 <= min(delays) < max(delays) <= 1.25) + assert 0.75 <= min(delays) < max(delays) <= 1.25 -class TestRetryAsync(unittest.TestCase): +class TestRetryAsync: """Tests for async retry logic.""" def test_successful_operation(self): @@ -103,8 +98,8 @@ async def operation(): return "success" result = asyncio.run(retry_async(operation)) - self.assertEqual(result, "success") - self.assertEqual(call_count, 1) + assert result == "success" + assert call_count == 1 def test_retries_on_failure(self): """Test that operation is retried on failure.""" @@ -122,8 +117,8 @@ async def operation(): initial_delay_seconds=0.01, ) result = 
asyncio.run(retry_async(operation, config=config, operation_name="test_op")) - self.assertEqual(result, "success") - self.assertEqual(call_count, 3) + assert result == "success" + assert call_count == 3 def test_raises_after_max_attempts(self): """Test that exception is raised after max attempts.""" @@ -135,9 +130,9 @@ async def operation(): raise ValueError("Always fails") config = RetryConfig(max_attempts=3, initial_delay_seconds=0.01) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): asyncio.run(retry_async(operation, config=config)) - self.assertEqual(call_count, 3) + assert call_count == 3 def test_respects_retryable_exceptions(self): """Test that only specified exceptions trigger retry.""" @@ -149,8 +144,7 @@ async def operation(): raise TypeError("Not retryable") config = RetryConfig(max_attempts=3, initial_delay_seconds=0.01) - # Only retry ValueError, not TypeError - with self.assertRaises(TypeError): + with pytest.raises(TypeError): asyncio.run( retry_async( operation, @@ -158,35 +152,34 @@ async def operation(): retryable_exceptions=(ValueError,), ) ) - # Should only be called once since TypeError is not retryable - self.assertEqual(call_count, 1) + assert call_count == 1 -class TestCircuitBreakerConfig(unittest.TestCase): +class TestCircuitBreakerConfig: """Tests for CircuitBreakerConfig.""" def test_default_values(self): """Test default configuration values.""" config = CircuitBreakerConfig() - self.assertEqual(config.failure_threshold, 5) - self.assertEqual(config.success_threshold, 2) - self.assertEqual(config.timeout_seconds, 30.0) - self.assertEqual(config.failure_window_seconds, 60.0) + assert config.failure_threshold == 5 + assert config.success_threshold == 2 + assert config.timeout_seconds == 30.0 + assert config.failure_window_seconds == 60.0 -class TestCircuitBreaker(unittest.TestCase): +class TestCircuitBreaker: """Tests for CircuitBreaker.""" def test_initial_state_is_closed(self): """Test that circuit starts in closed 
state.""" cb = CircuitBreaker("test") - self.assertEqual(cb.state, CircuitState.CLOSED) - self.assertTrue(cb.is_closed) + assert cb.state == CircuitState.CLOSED + assert cb.is_closed is True def test_allows_requests_when_closed(self): """Test that requests are allowed when circuit is closed.""" cb = CircuitBreaker("test") - self.assertTrue(cb.allow_request()) + assert cb.allow_request() is True def test_opens_after_failure_threshold(self): """Test that circuit opens after reaching failure threshold.""" @@ -197,8 +190,8 @@ def test_opens_after_failure_threshold(self): cb.allow_request() cb.record_failure() - self.assertEqual(cb.state, CircuitState.OPEN) - self.assertFalse(cb.is_closed) + assert cb.state == CircuitState.OPEN + assert cb.is_closed is False def test_rejects_requests_when_open(self): """Test that requests are rejected when circuit is open.""" @@ -208,8 +201,8 @@ def test_rejects_requests_when_open(self): cb.allow_request() cb.record_failure() - self.assertFalse(cb.allow_request()) - self.assertEqual(cb.stats.rejected_requests, 1) + assert cb.allow_request() is False + assert cb.stats.rejected_requests == 1 def test_transitions_to_half_open_after_timeout(self): """Test that circuit transitions to half-open after timeout.""" @@ -221,11 +214,10 @@ def test_transitions_to_half_open_after_timeout(self): cb.allow_request() cb.record_failure() - self.assertEqual(cb.state, CircuitState.OPEN) + assert cb.state == CircuitState.OPEN time.sleep(0.15) - # Accessing state should trigger transition - self.assertEqual(cb.state, CircuitState.HALF_OPEN) + assert cb.state == CircuitState.HALF_OPEN def test_closes_after_success_threshold_in_half_open(self): """Test that circuit closes after success threshold in half-open.""" @@ -236,19 +228,16 @@ def test_closes_after_success_threshold_in_half_open(self): ) cb = CircuitBreaker("test", config) - # Open the circuit cb.allow_request() cb.record_failure() - # Wait for timeout to transition to half-open time.sleep(0.02) - 
self.assertEqual(cb.state, CircuitState.HALF_OPEN) + assert cb.state == CircuitState.HALF_OPEN - # Record successes cb.record_success() - self.assertEqual(cb.state, CircuitState.HALF_OPEN) + assert cb.state == CircuitState.HALF_OPEN cb.record_success() - self.assertEqual(cb.state, CircuitState.CLOSED) + assert cb.state == CircuitState.CLOSED def test_reopens_on_failure_in_half_open(self): """Test that circuit reopens on failure in half-open state.""" @@ -258,17 +247,14 @@ def test_reopens_on_failure_in_half_open(self): ) cb = CircuitBreaker("test", config) - # Open the circuit cb.allow_request() cb.record_failure() - # Wait for timeout time.sleep(0.02) - self.assertEqual(cb.state, CircuitState.HALF_OPEN) + assert cb.state == CircuitState.HALF_OPEN - # Record failure in half-open cb.record_failure() - self.assertEqual(cb.state, CircuitState.OPEN) + assert cb.state == CircuitState.OPEN def test_success_clears_failures_in_closed_state(self): """Test that success in closed state prunes old failures.""" @@ -278,26 +264,21 @@ def test_success_clears_failures_in_closed_state(self): ) cb = CircuitBreaker("test", config) - # Record some failures cb.allow_request() cb.record_failure() cb.allow_request() cb.record_failure() - # Wait for failures to age out of window time.sleep(0.15) - # Record success (should prune old failures) cb.record_success() - # Now record more failures - shouldn't open immediately cb.allow_request() cb.record_failure() cb.allow_request() cb.record_failure() - # Circuit should still be closed (only 2 recent failures) - self.assertEqual(cb.state, CircuitState.CLOSED) + assert cb.state == CircuitState.CLOSED def test_reset_closes_circuit(self): """Test that reset returns circuit to closed state.""" @@ -306,10 +287,10 @@ def test_reset_closes_circuit(self): cb.allow_request() cb.record_failure() - self.assertEqual(cb.state, CircuitState.OPEN) + assert cb.state == CircuitState.OPEN cb.reset() - self.assertEqual(cb.state, CircuitState.CLOSED) + assert 
cb.state == CircuitState.CLOSED def test_stats_tracking(self): """Test that statistics are tracked correctly.""" @@ -324,21 +305,17 @@ def test_stats_tracking(self): cb.record_failure() stats = cb.stats - self.assertEqual(stats.total_requests, 3) - self.assertEqual(stats.successful_requests, 1) - self.assertEqual(stats.failed_requests, 2) + assert stats.total_requests == 3 + assert stats.successful_requests == 1 + assert stats.failed_requests == 2 -class TestCircuitOpenError(unittest.TestCase): +class TestCircuitOpenError: """Tests for CircuitOpenError.""" def test_error_message(self): """Test error message format.""" error = CircuitOpenError("api_export") - self.assertEqual(error.circuit_name, "api_export") - self.assertIn("api_export", str(error)) - self.assertIn("open", str(error)) - - -if __name__ == "__main__": - unittest.main() + assert error.circuit_name == "api_export" + assert "api_export" in str(error) + assert "open" in str(error) diff --git a/tests/unit/test_span_serialization.py b/tests/unit/test_span_serialization.py index 5ecbba5..2efa3c2 100644 --- a/tests/unit/test_span_serialization.py +++ b/tests/unit/test_span_serialization.py @@ -1,4 +1,4 @@ -import unittest +"""Tests for span serialization to protobuf.""" from tusk.drift.core.v1 import JsonSchemaType as ProtoJsonSchemaType @@ -14,8 +14,11 @@ ) -class SpanSerializationTests(unittest.TestCase): +class TestSpanSerialization: + """Tests for span serialization to protobuf.""" + def test_basic_span_serializes_to_proto(self): + """Test that a basic span serializes to protobuf correctly.""" input_schema_info = JsonSchemaHelper.generate_schema_and_hash({"method": "GET", "path": "/health"}) output_schema_info = JsonSchemaHelper.generate_schema_and_hash({"status_code": 200}) @@ -45,19 +48,18 @@ def test_basic_span_serializes_to_proto(self): proto = span.to_proto() - self.assertEqual(proto.trace_id, span.trace_id) - # proto.package_type and proto.kind are ints in protobuf + assert proto.trace_id == 
span.trace_id assert span.package_type is not None - self.assertEqual(proto.package_type, span.package_type.value) - self.assertEqual(proto.kind, span.kind.value) - self.assertEqual(proto.status.code, StatusCode.OK.value) - # input_value and output_value are protobuf Struct objects - self.assertEqual(proto.input_value.fields["method"].string_value, "GET") - self.assertEqual(proto.output_value.fields["status_code"].number_value, 200) - self.assertEqual(proto.timestamp.year, 2023) - self.assertEqual(proto.duration.total_seconds(), 0.000001) + assert proto.package_type == span.package_type.value + assert proto.kind == span.kind.value + assert proto.status.code == StatusCode.OK.value + assert proto.input_value.fields["method"].string_value == "GET" + assert proto.output_value.fields["status_code"].number_value == 200 + assert proto.timestamp.year == 2023 + assert proto.duration.total_seconds() == 0.000001 def test_schema_serialization_matches_proto_enums(self): + """Test that schema serialization matches protobuf enums.""" schema = JsonSchema( type=JsonSchemaType.OBJECT, properties={ @@ -90,10 +92,6 @@ def test_schema_serialization_matches_proto_enums(self): proto = span.to_proto() - self.assertEqual(proto.input_schema.type, ProtoJsonSchemaType.OBJECT) - self.assertEqual(proto.input_schema.properties["name"].type, ProtoJsonSchemaType.STRING) - self.assertEqual(proto.input_schema.properties["count"].type, ProtoJsonSchemaType.NUMBER) - - -if __name__ == "__main__": - unittest.main() + assert proto.input_schema.type == ProtoJsonSchemaType.OBJECT + assert proto.input_schema.properties["name"].type == ProtoJsonSchemaType.STRING + assert proto.input_schema.properties["count"].type == ProtoJsonSchemaType.NUMBER diff --git a/tests/unit/test_transform_engine.py b/tests/unit/test_transform_engine.py index 50423ff..3463c97 100644 --- a/tests/unit/test_transform_engine.py +++ b/tests/unit/test_transform_engine.py @@ -1,7 +1,8 @@ +"""Tests for HTTP transform engine.""" + import 
base64 import json import sys -import unittest from pathlib import Path from typing import Any @@ -12,8 +13,11 @@ from drift.instrumentation.http import transform_engine as te -class HttpTransformEngineTests(unittest.TestCase): +class TestHttpTransformEngine: + """Tests for HttpTransformEngine.""" + def test_should_drop_inbound_request_and_sanitize_span(self) -> None: + """Test dropping inbound requests and sanitizing the span.""" engine = HttpTransformEngine( [ { @@ -27,7 +31,7 @@ def test_should_drop_inbound_request_and_sanitize_span(self) -> None: ] ) - self.assertTrue(engine.should_drop_inbound_request("GET", "/private/123", {"Host": "example.com"})) + assert engine.should_drop_inbound_request("GET", "/private/123", {"Host": "example.com"}) span = HttpSpanData( kind=SpanKind.SERVER, @@ -42,15 +46,15 @@ def test_should_drop_inbound_request_and_sanitize_span(self) -> None: ) metadata = engine.apply_transforms(span) - self.assertIsNotNone(metadata) assert metadata is not None assert span.input_value is not None assert span.output_value is not None - self.assertEqual(metadata.actions[0].type, "drop") - self.assertEqual(span.input_value["bodySize"], 0) - self.assertEqual(span.output_value["bodySize"], 0) + assert metadata.actions[0].type == "drop" + assert span.input_value["bodySize"] == 0 + assert span.output_value["bodySize"] == 0 def test_jsonpath_mask_transform_updates_body_and_metadata(self) -> None: + """Test JSONPath mask transform updates body and metadata.""" engine = HttpTransformEngine( [ { @@ -77,18 +81,19 @@ def test_jsonpath_mask_transform_updates_body_and_metadata(self) -> None: ) metadata = engine.apply_transforms(span) - self.assertIsNotNone(metadata) assert metadata is not None assert span.input_value is not None - self.assertTrue(metadata.transformed) - self.assertTrue(metadata.actions[0].field.startswith("jsonPath")) + assert metadata.transformed + assert metadata.actions[0].field.startswith("jsonPath") masked_body = 
json.loads(base64.b64decode(span.input_value["body"].encode("ascii"))) - self.assertEqual(masked_body["password"], "#" * len("hunter2")) + assert masked_body["password"] == "#" * len("hunter2") expected_size = len(json.dumps(masked_body, separators=(",", ":")).encode("utf-8")) - self.assertEqual(span.input_value["bodySize"], expected_size) + assert span.input_value["bodySize"] == expected_size def test_python_jsonpath_stub_is_used_when_available(self) -> None: + """Test that Python JSONPath stub is used when available.""" + class FakeJSONPath: def __init__(self, expression: str) -> None: self.expression = expression @@ -132,11 +137,7 @@ def find(self, data: Any) -> list[dict[str, Any]]: ) metadata = engine.apply_transforms(span) - self.assertIsNotNone(metadata) + assert metadata is not None assert span.input_value is not None masked_body = json.loads(base64.b64decode(span.input_value["body"].encode("ascii"))) - self.assertEqual(masked_body["password"], "redacted") - - -if __name__ == "__main__": - unittest.main() + assert masked_body["password"] == "redacted" diff --git a/tests/unit/test_wsgi_utilities.py b/tests/unit/test_wsgi_utilities.py index 294525e..fc3c872 100644 --- a/tests/unit/test_wsgi_utilities.py +++ b/tests/unit/test_wsgi_utilities.py @@ -1,7 +1,7 @@ """Unit tests for WSGI utilities.""" import base64 -import unittest +from io import BytesIO from drift.instrumentation.wsgi import ( build_input_schema_merges, @@ -15,7 +15,7 @@ ) -class TestBuildUrl(unittest.TestCase): +class TestBuildUrl: """Test build_url function.""" def test_builds_basic_url(self): @@ -27,7 +27,7 @@ def test_builds_basic_url(self): "QUERY_STRING": "", } url = build_url(environ) - self.assertEqual(url, "http://example.com/users") + assert url == "http://example.com/users" def test_builds_url_with_query_string(self): """Test URL building with query string.""" @@ -38,7 +38,7 @@ def test_builds_url_with_query_string(self): "QUERY_STRING": "q=test&limit=10", } url = build_url(environ) - 
self.assertEqual(url, "https://example.com/search?q=test&limit=10") + assert url == "https://example.com/search?q=test&limit=10" def test_uses_server_name_fallback(self): """Test fallback to SERVER_NAME when HTTP_HOST is missing.""" @@ -49,7 +49,7 @@ def test_uses_server_name_fallback(self): "QUERY_STRING": "", } url = build_url(environ) - self.assertEqual(url, "http://localhost/") + assert url == "http://localhost/" def test_defaults_to_http(self): """Test default scheme is http.""" @@ -58,10 +58,10 @@ def test_defaults_to_http(self): "PATH_INFO": "/test", } url = build_url(environ) - self.assertTrue(url.startswith("http://")) + assert url.startswith("http://") -class TestExtractHeaders(unittest.TestCase): +class TestExtractHeaders: """Test extract_headers function.""" def test_extracts_http_headers(self): @@ -70,27 +70,25 @@ def test_extracts_http_headers(self): "HTTP_CONTENT_TYPE": "application/json", "HTTP_AUTHORIZATION": "Bearer token123", "HTTP_X_CUSTOM_HEADER": "custom-value", - "REQUEST_METHOD": "GET", # Should be ignored + "REQUEST_METHOD": "GET", } headers = extract_headers(environ) - self.assertEqual(headers["Content-Type"], "application/json") - self.assertEqual(headers["Authorization"], "Bearer token123") - self.assertEqual(headers["X-Custom-Header"], "custom-value") - self.assertNotIn("Request-Method", headers) + assert headers["Content-Type"] == "application/json" + assert headers["Authorization"] == "Bearer token123" + assert headers["X-Custom-Header"] == "custom-value" + assert "Request-Method" not in headers def test_handles_empty_environ(self): """Test with empty environ.""" headers = extract_headers({}) - self.assertEqual(headers, {}) + assert headers == {} -class TestCaptureRequestBody(unittest.TestCase): +class TestCaptureRequestBody: """Test capture_request_body function.""" def test_captures_post_body(self): """Test capturing POST request body.""" - from io import BytesIO - body_content = b'{"key": "value"}' environ = { "REQUEST_METHOD": 
"POST", @@ -98,20 +96,15 @@ def test_captures_post_body(self): "wsgi.input": BytesIO(body_content), } body = capture_request_body(environ) - self.assertEqual(body, body_content) - - # Verify input was reset - from io import BytesIO + assert body == body_content wsgi_input = environ["wsgi.input"] assert isinstance(wsgi_input, BytesIO) new_body = wsgi_input.read() - self.assertEqual(new_body, body_content) + assert new_body == body_content def test_captures_large_body(self): """Test capturing large body (no truncation at capture time).""" - from io import BytesIO - body_content = b"x" * 15000 environ = { "REQUEST_METHOD": "POST", @@ -119,9 +112,8 @@ def test_captures_large_body(self): "wsgi.input": BytesIO(body_content), } body = capture_request_body(environ) - # No truncation - span-level blocking handles oversized spans assert body is not None - self.assertEqual(len(body), 15000) + assert len(body) == 15000 def test_ignores_get_requests(self): """Test that GET requests are ignored.""" @@ -129,44 +121,42 @@ def test_ignores_get_requests(self): "REQUEST_METHOD": "GET", } body = capture_request_body(environ) - self.assertIsNone(body) + assert body is None def test_handles_empty_body(self): """Test handling of empty body.""" - from io import BytesIO - environ = { "REQUEST_METHOD": "POST", "CONTENT_LENGTH": "0", "wsgi.input": BytesIO(b""), } body = capture_request_body(environ) - self.assertIsNone(body) + assert body is None -class TestParseStatusLine(unittest.TestCase): +class TestParseStatusLine: """Test parse_status_line function.""" def test_parses_standard_status(self): """Test parsing standard status line.""" code, message = parse_status_line("200 OK") - self.assertEqual(code, 200) - self.assertEqual(message, "OK") + assert code == 200 + assert message == "OK" def test_parses_status_with_long_message(self): """Test parsing status with multi-word message.""" code, message = parse_status_line("404 Not Found") - self.assertEqual(code, 404) - self.assertEqual(message, 
"Not Found") + assert code == 404 + assert message == "Not Found" def test_handles_status_without_message(self): """Test parsing status without message.""" code, message = parse_status_line("500") - self.assertEqual(code, 500) - self.assertEqual(message, "") + assert code == 500 + assert message == "" -class TestBuildInputValue(unittest.TestCase): +class TestBuildInputValue: """Test build_input_value function.""" def test_builds_basic_input_value(self): @@ -181,11 +171,11 @@ def test_builds_basic_input_value(self): "REMOTE_ADDR": "192.168.1.1", } input_value = build_input_value(environ) - self.assertEqual(input_value["method"], "GET") - self.assertEqual(input_value["url"], "http://example.com/users") - self.assertEqual(input_value["target"], "/users") - self.assertEqual(input_value["httpVersion"], "1.1") - self.assertEqual(input_value["remoteAddress"], "192.168.1.1") + assert input_value["method"] == "GET" + assert input_value["url"] == "http://example.com/users" + assert input_value["target"] == "/users" + assert input_value["httpVersion"] == "1.1" + assert input_value["remoteAddress"] == "192.168.1.1" def test_includes_body_when_present(self): """Test including body in input value.""" @@ -199,36 +189,36 @@ def test_includes_body_when_present(self): } body = b'{"key": "value"}' input_value = build_input_value(environ, body=body) - self.assertIn("body", input_value) - self.assertEqual(input_value["body"], base64.b64encode(body).decode("ascii")) - self.assertEqual(input_value["bodySize"], len(body)) + assert "body" in input_value + assert input_value["body"] == base64.b64encode(body).decode("ascii") + assert input_value["bodySize"] == len(body) -class TestBuildOutputValue(unittest.TestCase): +class TestBuildOutputValue: """Test build_output_value function.""" def test_builds_basic_output_value(self): """Test building basic output value.""" output_value = build_output_value(200, "OK", {"Content-Type": "application/json"}) - 
self.assertEqual(output_value["statusCode"], 200) - self.assertEqual(output_value["statusMessage"], "OK") - self.assertEqual(output_value["headers"]["Content-Type"], "application/json") + assert output_value["statusCode"] == 200 + assert output_value["statusMessage"] == "OK" + assert output_value["headers"]["Content-Type"] == "application/json" def test_includes_body_when_present(self): """Test including body in output value.""" body = b'{"result": "success"}' output_value = build_output_value(200, "OK", {}, body=body) - self.assertIn("body", output_value) - self.assertEqual(output_value["body"], base64.b64encode(body).decode("ascii")) - self.assertEqual(output_value["bodySize"], len(body)) + assert "body" in output_value + assert output_value["body"] == base64.b64encode(body).decode("ascii") + assert output_value["bodySize"] == len(body) def test_includes_error_when_present(self): """Test including error in output value.""" output_value = build_output_value(500, "Internal Server Error", {}, error="Database connection failed") - self.assertEqual(output_value["errorMessage"], "Database connection failed") + assert output_value["errorMessage"] == "Database connection failed" -class TestBuildSchemaMerges(unittest.TestCase): +class TestBuildSchemaMerges: """Test schema merge builder functions.""" def test_builds_input_schema_merges(self): @@ -240,12 +230,9 @@ def test_builds_input_schema_merges(self): } schema_merges = build_input_schema_merges(input_value) - # Should have headers merge - self.assertIn("headers", schema_merges) - self.assertEqual(schema_merges["headers"]["match_importance"], 0.0) - - # Should not have body merge (no body present) - self.assertNotIn("body", schema_merges) + assert "headers" in schema_merges + assert schema_merges["headers"]["match_importance"] == 0.0 + assert "body" not in schema_merges def test_builds_input_schema_merges_with_body(self): """Test input schema merge building with body.""" @@ -256,9 +243,8 @@ def 
test_builds_input_schema_merges_with_body(self): } schema_merges = build_input_schema_merges(input_value) - # Should have body merge with BASE64 encoding - self.assertIn("body", schema_merges) - self.assertEqual(schema_merges["body"]["encoding"], 1) # BASE64 = 1 + assert "body" in schema_merges + assert schema_merges["body"]["encoding"] == 1 # BASE64 = 1 def test_builds_output_schema_merges(self): """Test output schema merge building.""" @@ -269,9 +255,8 @@ def test_builds_output_schema_merges(self): } schema_merges = build_output_schema_merges(output_value) - # Should have headers merge - self.assertIn("headers", schema_merges) - self.assertEqual(schema_merges["headers"]["match_importance"], 0.0) + assert "headers" in schema_merges + assert schema_merges["headers"]["match_importance"] == 0.0 def test_builds_output_schema_merges_with_body(self): """Test output schema merge building with body.""" @@ -282,10 +267,5 @@ def test_builds_output_schema_merges_with_body(self): } schema_merges = build_output_schema_merges(output_value) - # Should have body merge with BASE64 encoding - self.assertIn("body", schema_merges) - self.assertEqual(schema_merges["body"]["encoding"], 1) # BASE64 = 1 - - -if __name__ == "__main__": - unittest.main() + assert "body" in schema_merges + assert schema_merges["body"]["encoding"] == 1 # BASE64 = 1 diff --git a/uv.lock b/uv.lock index 6a38613..1bf391b 100644 --- a/uv.lock +++ b/uv.lock @@ -1552,17 +1552,15 @@ wheels = [ name = "pytest" version = "8.4.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.10'", -] dependencies = [ - { name = "colorama", marker = "python_full_version < '3.10' and sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version < '3.10'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig", version = "2.1.0", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "packaging", marker = "python_full_version < '3.10'" }, - { name = "pluggy", marker = "python_full_version < '3.10'" }, - { name = "pygments", marker = "python_full_version < '3.10'" }, - { name = "tomli", marker = "python_full_version < '3.10'" }, + { name = "iniconfig", version = "2.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618 } wheels = [ @@ -1570,25 +1568,15 @@ wheels = [ ] [[package]] -name = "pytest" -version = "9.0.2" +name = "pytest-mock" +version = "3.15.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version >= '3.10' and python_full_version < '3.12'", -] dependencies = [ - { name = "colorama", marker = "python_full_version >= '3.10' and sys_platform == 'win32'" }, - { name = "exceptiongroup", marker = "python_full_version == '3.10.*'" }, - { name = "iniconfig", version = "2.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "packaging", marker = "python_full_version >= '3.10'" }, - { name = "pluggy", marker = "python_full_version >= '3.10'" }, - { name = "pygments", marker = "python_full_version >= '3.10'" }, - { name = "tomli", marker = "python_full_version == '3.10.*'" }, + { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = 
"sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901 } +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801 }, + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095 }, ] [[package]] @@ -2013,7 +2001,7 @@ wheels = [ [[package]] name = "tusk-drift-python-sdk" -version = "0.1.10" +version = "0.1.11" source = { editable = "." 
} dependencies = [ { name = "aiofiles" }, @@ -2033,8 +2021,8 @@ dependencies = [ dev = [ { name = "fastapi" }, { name = "flask" }, - { name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, - { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "pytest" }, + { name = "pytest-mock" }, { name = "python-jsonpath" }, { name = "ruff" }, { name = "ty" }, @@ -2068,7 +2056,8 @@ requires-dist = [ { name = "opentelemetry-api", specifier = ">=1.20.0" }, { name = "opentelemetry-sdk", specifier = ">=1.20.0" }, { name = "protobuf", specifier = ">=6.0" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0,<9.0.0" }, + { name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.15.0" }, { name = "python-jsonpath", marker = "extra == 'dev'", specifier = ">=0.10" }, { name = "pyyaml", specifier = ">=6.0" }, { name = "requests", specifier = ">=2.32.5" },