From b73282c5692c1daff74f86fe87e3b3b04b59fcd2 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Wed, 14 Jan 2026 15:03:05 -0800 Subject: [PATCH 1/2] Commit --- drift/__init__.py | 9 + drift/core/batch_processor.py | 34 +++- drift/core/metrics.py | 305 +++++++++++++++++++++++++++++ drift/core/resilience.py | 275 ++++++++++++++++++++++++++ drift/core/tracing/adapters/api.py | 223 +++++++++++++++++---- drift/instrumentation/__init__.py | 3 +- drift/instrumentation/registry.py | 14 +- 7 files changed, 806 insertions(+), 57 deletions(-) create mode 100644 drift/core/metrics.py create mode 100644 drift/core/resilience.py diff --git a/drift/__init__.py b/drift/__init__.py index 834baae..31c7ca6 100644 --- a/drift/__init__.py +++ b/drift/__init__.py @@ -19,6 +19,8 @@ load_tusk_config, ) from .core.logger import LogLevel, get_log_level, set_log_level +from .core.metrics import SDKMetrics, get_sdk_metrics +from .core.resilience import CircuitBreaker, CircuitBreakerConfig, RetryConfig from .core.tracing.adapters import ( ApiSpanAdapter, ApiSpanAdapterConfig, @@ -59,6 +61,13 @@ "LogLevel", "set_log_level", "get_log_level", + # Metrics + "SDKMetrics", + "get_sdk_metrics", + # Resilience + "RetryConfig", + "CircuitBreaker", + "CircuitBreakerConfig", # Instrumentations "FlaskInstrumentation", "FastAPIInstrumentation", diff --git a/drift/core/batch_processor.py b/drift/core/batch_processor.py index e82be75..3ebd01b 100644 --- a/drift/core/batch_processor.py +++ b/drift/core/batch_processor.py @@ -5,10 +5,13 @@ import asyncio import logging import threading +import time from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING +from .metrics import get_metrics_collector + if TYPE_CHECKING: from .tracing.span_exporter import TdSpanExporter from .types import CleanSpanData @@ -63,6 +66,10 @@ def __init__( self._started = False self._dropped_spans = 0 + # Set up metrics + self._metrics = get_metrics_collector() + self._metrics.set_queue_max_size(self._config.max_queue_size) + def start(self) -> None: """Start the background export thread.""" if self._started: @@ -110,11 +117,28 @@ def add_span(self, span: CleanSpanData) -> bool: span: The span to add Returns: - True if span was added, False if queue is full and span was dropped + True if span was added, False if queue is full or trace is blocked """ + # Check if span should be blocked (size limit or server error) + # Blocks entire trace + from .trace_blocking_manager import TraceBlockingManager, should_block_span + + if should_block_span(span): + self._dropped_spans += 1 + self._metrics.record_spans_dropped() + return False + + # Check if trace is already blocked + if TraceBlockingManager.get_instance().is_trace_blocked(span.trace_id): + logger.debug(f"Skipping span '{span.name}' - trace {span.trace_id} is blocked") + self._dropped_spans += 1 + self._metrics.record_spans_dropped() + return False + with self._condition: if len(self._queue) >= self._config.max_queue_size: self._dropped_spans += 1 + self._metrics.record_spans_dropped() logger.warning( f"Span queue full ({self._config.max_queue_size}), dropping span. " f"Total dropped: {self._dropped_spans}" @@ -122,6 +146,7 @@ def add_span(self, span: CleanSpanData) -> bool: return False self._queue.append(span) + self._metrics.update_queue_size(len(self._queue)) # Trigger immediate export if batch size reached if len(self._queue) >= self._config.max_export_batch_size: @@ -149,6 +174,7 @@ def _export_batch(self) -> None: with self._condition: while self._queue and len(batch) < self._config.max_export_batch_size: batch.append(self._queue.popleft()) + self._metrics.update_queue_size(len(self._queue)) if not batch: return @@ -158,6 +184,7 @@ def _export_batch(self) -> None: # Export to all adapters for adapter in adapters: + start_time = time.monotonic() try: # Handle async adapters (create new event loop for this thread) if asyncio.iscoroutinefunction(adapter.export_spans): @@ -170,8 +197,13 @@ def _export_batch(self) -> None: else: adapter.export_spans(batch) # type: ignore + latency_ms = (time.monotonic() - start_time) * 1000 + self._metrics.record_spans_exported(len(batch)) + self._metrics.record_export_latency(latency_ms) + except Exception as e: logger.error(f"Failed to export batch via {adapter.name}: {e}") + self._metrics.record_spans_failed(len(batch)) def _force_flush(self) -> None: """Force export all remaining spans in the queue.""" diff --git a/drift/core/metrics.py b/drift/core/metrics.py new file mode 100644 index 0000000..5581a04 --- /dev/null +++ b/drift/core/metrics.py @@ -0,0 +1,305 @@ +"""SDK self-metrics for observability. + +Provides metrics about the SDK's internal state and performance. +Uses event-driven logging at WARN level for anomalies. +""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .resilience import CircuitBreaker + +logger = logging.getLogger(__name__) + +DROP_RATE_WARN_THRESHOLD = 0.05 +QUEUE_CAPACITY_WARN_THRESHOLD = 0.80 +FAILURE_RATE_WARN_THRESHOLD = 0.10 +MIN_SAMPLES_FOR_RATE_WARNING = 100 + + +@dataclass +class ExportMetrics: + """Metrics for span export operations.""" + + spans_exported: int = 0 + spans_dropped: int = 0 + spans_failed: int = 0 + batches_exported: int = 0 + batches_failed: int = 0 + bytes_sent: int = 0 + bytes_compressed_saved: int = 0 + export_latency_sum_ms: float = 0.0 + export_count: int = 0 + + @property + def average_export_latency_ms(self) -> float: + """Average export latency in milliseconds.""" + if self.export_count == 0: + return 0.0 + return self.export_latency_sum_ms / self.export_count + + +@dataclass +class QueueMetrics: + """Metrics for the span queue.""" + + current_size: int = 0 + max_size: int = 0 + peak_size: int = 0 # Highest size observed + + +@dataclass +class CircuitBreakerMetrics: + """Metrics for circuit breaker state.""" + + state: str = "closed" + total_requests: int = 0 + rejected_requests: int = 0 + state_transitions: int = 0 + + +@dataclass +class SDKMetrics: + """Aggregated SDK metrics.""" + + export: ExportMetrics = field(default_factory=ExportMetrics) + queue: QueueMetrics = field(default_factory=QueueMetrics) + circuit_breaker: CircuitBreakerMetrics = field(default_factory=CircuitBreakerMetrics) + uptime_seconds: float = 0.0 + instrumentations_active: int = 0 + + +class MetricsCollector: + """Collects and aggregates SDK metrics. + + Thread-safe metrics collection for SDK observability. + Uses event-driven WARN logging when anomalies are detected. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._start_time = time.monotonic() + + # Export metrics + self._spans_exported = 0 + self._spans_dropped = 0 + self._spans_failed = 0 + self._batches_exported = 0 + self._batches_failed = 0 + self._bytes_sent = 0 + self._bytes_compressed_saved = 0 + self._export_latency_sum_ms = 0.0 + self._export_count = 0 + + # Queue metrics + self._queue_size = 0 + self._queue_max_size = 0 + self._queue_peak_size = 0 + + # Circuit breaker reference (set externally) + self._circuit_breaker: CircuitBreaker | None = None + + # Instrumentation count + self._instrumentations_active = 0 + + # Warning state tracking (to avoid log spam) + self._warned_high_drop_rate = False + self._warned_high_failure_rate = False + self._warned_queue_capacity = False + self._warned_circuit_open = False + + def set_circuit_breaker(self, cb: CircuitBreaker) -> None: + """Set the circuit breaker reference for metrics.""" + self._circuit_breaker = cb + + def _check_and_warn_drop_rate(self) -> None: + """Check drop rate and warn if threshold exceeded.""" + total = self._spans_exported + self._spans_dropped + if total < MIN_SAMPLES_FOR_RATE_WARNING: + return + + drop_rate = self._spans_dropped / total + if drop_rate > DROP_RATE_WARN_THRESHOLD and not self._warned_high_drop_rate: + logger.warning(f"Span export high drop rate: {drop_rate:.1%} ({self._spans_dropped}/{total} spans dropped)") + self._warned_high_drop_rate = True + elif drop_rate <= DROP_RATE_WARN_THRESHOLD and self._warned_high_drop_rate: + # Reset warning flag when rate recovers + self._warned_high_drop_rate = False + + def _check_and_warn_failure_rate(self) -> None: + """Check failure rate and warn if threshold exceeded.""" + total = self._spans_exported + self._spans_failed + if total < MIN_SAMPLES_FOR_RATE_WARNING: + return + + failure_rate = self._spans_failed / total + if failure_rate > FAILURE_RATE_WARN_THRESHOLD and not self._warned_high_failure_rate: + logger.warning( + f"Span export high failure rate: {failure_rate:.1%} ({self._spans_failed}/{total} spans failed)" + ) + self._warned_high_failure_rate = True + elif failure_rate <= FAILURE_RATE_WARN_THRESHOLD and self._warned_high_failure_rate: + self._warned_high_failure_rate = False + + def _check_and_warn_queue_capacity(self) -> None: + """Check queue capacity and warn if threshold exceeded.""" + if self._queue_max_size == 0: + return + + capacity = self._queue_size / self._queue_max_size + if capacity > QUEUE_CAPACITY_WARN_THRESHOLD and not self._warned_queue_capacity: + logger.warning( + f"Span export queue nearing capacity: {capacity:.0%} full ({self._queue_size}/{self._queue_max_size})" + ) + self._warned_queue_capacity = True + elif capacity <= QUEUE_CAPACITY_WARN_THRESHOLD and self._warned_queue_capacity: + self._warned_queue_capacity = False + + def warn_circuit_breaker_open(self) -> None: + """Log warning when circuit breaker opens (called externally).""" + if not self._warned_circuit_open: + logger.warning("Span export circuit breaker open: requests temporarily disabled") + self._warned_circuit_open = True + + def notify_circuit_breaker_closed(self) -> None: + """Reset circuit breaker warning state when it closes.""" + if self._warned_circuit_open: + logger.info("Span export circuit breaker closed: requests resumed") + self._warned_circuit_open = False + + def set_queue_max_size(self, max_size: int) -> None: + """Set the maximum queue size.""" + self._queue_max_size = max_size + + def record_spans_exported(self, count: int) -> None: + """Record successfully exported spans.""" + with self._lock: + self._spans_exported += count + self._batches_exported += 1 + # Check if drop/failure rates have recovered + self._check_and_warn_drop_rate() + self._check_and_warn_failure_rate() + + def record_spans_dropped(self, count: int = 1) -> None: + """Record dropped spans (queue full or blocked).""" + with self._lock: + self._spans_dropped += count + self._check_and_warn_drop_rate() + + def record_spans_failed(self, count: int) -> None: + """Record failed span exports.""" + with self._lock: + self._spans_failed += count + self._batches_failed += 1 + self._check_and_warn_failure_rate() + + def record_export_latency(self, latency_ms: float) -> None: + """Record export operation latency.""" + with self._lock: + self._export_latency_sum_ms += latency_ms + self._export_count += 1 + + def record_bytes_sent(self, bytes_sent: int, bytes_saved: int = 0) -> None: + """Record bytes sent and compression savings.""" + with self._lock: + self._bytes_sent += bytes_sent + self._bytes_compressed_saved += bytes_saved + + def update_queue_size(self, size: int) -> None: + """Update current queue size.""" + with self._lock: + self._queue_size = size + if size > self._queue_peak_size: + self._queue_peak_size = size + self._check_and_warn_queue_capacity() + + def record_instrumentation_activated(self) -> None: + """Record an instrumentation being activated.""" + with self._lock: + self._instrumentations_active += 1 + + def record_instrumentation_deactivated(self) -> None: + """Record an instrumentation being deactivated.""" + with self._lock: + self._instrumentations_active = max(0, self._instrumentations_active - 1) + + def get_metrics(self) -> SDKMetrics: + """Get current SDK metrics snapshot.""" + with self._lock: + export_metrics = ExportMetrics( + spans_exported=self._spans_exported, + spans_dropped=self._spans_dropped, + spans_failed=self._spans_failed, + batches_exported=self._batches_exported, + batches_failed=self._batches_failed, + bytes_sent=self._bytes_sent, + bytes_compressed_saved=self._bytes_compressed_saved, + export_latency_sum_ms=self._export_latency_sum_ms, + export_count=self._export_count, + ) + + queue_metrics = QueueMetrics( + current_size=self._queue_size, + max_size=self._queue_max_size, + peak_size=self._queue_peak_size, + ) + + cb_metrics = CircuitBreakerMetrics() + if self._circuit_breaker: + cb_metrics = CircuitBreakerMetrics( + state=self._circuit_breaker.state.value, + total_requests=self._circuit_breaker.stats.total_requests, + rejected_requests=self._circuit_breaker.stats.rejected_requests, + state_transitions=self._circuit_breaker.stats.state_transitions, + ) + + return SDKMetrics( + export=export_metrics, + queue=queue_metrics, + circuit_breaker=cb_metrics, + uptime_seconds=time.monotonic() - self._start_time, + instrumentations_active=self._instrumentations_active, + ) + + def reset(self) -> None: + """Reset all metrics to initial state.""" + with self._lock: + self._spans_exported = 0 + self._spans_dropped = 0 + self._spans_failed = 0 + self._batches_exported = 0 + self._batches_failed = 0 + self._bytes_sent = 0 + self._bytes_compressed_saved = 0 + self._export_latency_sum_ms = 0.0 + self._export_count = 0 + self._queue_peak_size = 0 + self._start_time = time.monotonic() + + +# Global metrics collector instance +_metrics_collector: MetricsCollector | None = None +_metrics_lock = threading.Lock() + + +def get_metrics_collector() -> MetricsCollector: + """Get or create the global metrics collector.""" + global _metrics_collector + with _metrics_lock: + if _metrics_collector is None: + _metrics_collector = MetricsCollector() + return _metrics_collector + + +def get_sdk_metrics() -> SDKMetrics: + """Get current SDK metrics snapshot. + + Convenience function to get metrics without accessing the collector directly. + """ + return get_metrics_collector().get_metrics() diff --git a/drift/core/resilience.py b/drift/core/resilience.py new file mode 100644 index 0000000..428e753 --- /dev/null +++ b/drift/core/resilience.py @@ -0,0 +1,275 @@ +"""Resilience patterns for reliable span export. + +Provides retry with exponential backoff and circuit breaker patterns. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import threading +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from enum import Enum + +logger = logging.getLogger(__name__) + + +@dataclass +class RetryConfig: + """Configuration for retry behavior.""" + + max_attempts: int = 3 + initial_delay_seconds: float = 0.1 + max_delay_seconds: float = 10.0 + exponential_base: float = 2.0 + jitter: bool = True # Prevent thundering herd + + +def calculate_backoff_delay( + attempt: int, + config: RetryConfig, +) -> float: + """Calculate delay for a given retry attempt using exponential backoff. + + Args: + attempt: Current attempt number (1-based) + config: Retry configuration + + Returns: + Delay in seconds before next retry + """ + delay = config.initial_delay_seconds * (config.exponential_base ** (attempt - 1)) + delay = min(delay, config.max_delay_seconds) + + if config.jitter: + jitter_factor = 0.75 + random.random() * 0.5 + delay *= jitter_factor + + return delay + + +async def retry_async[T]( + operation: Callable[[], Awaitable[T]], + config: RetryConfig | None = None, + retryable_exceptions: tuple[type[Exception], ...] = (Exception,), + operation_name: str = "operation", +) -> T: + """Execute an async operation with retry and exponential backoff. + + Args: + operation: Async callable to execute + config: Retry configuration (uses defaults if None) + retryable_exceptions: Tuple of exception types that trigger retry + operation_name: Name for logging purposes + + Returns: + Result of the operation + + Raises: + The last exception if all retries are exhausted + """ + config = config or RetryConfig() + last_exception: Exception | None = None + + for attempt in range(1, config.max_attempts + 1): + try: + return await operation() + except retryable_exceptions as e: + last_exception = e + + if attempt == config.max_attempts: + logger.warning(f"{operation_name} failed after {config.max_attempts} attempts: {e}") + raise + + delay = calculate_backoff_delay(attempt, config) + logger.debug(f"{operation_name} attempt {attempt} failed: {e}. Retrying in {delay:.2f}s...") + await asyncio.sleep(delay) + + # Should never reach here, but satisfy type checker + if last_exception: + raise last_exception + raise RuntimeError("Unexpected retry loop exit") + + +class CircuitState(Enum): + """Circuit breaker states.""" + + CLOSED = "closed" # Normal operation, requests pass through + OPEN = "open" # Circuit tripped, requests fail fast + HALF_OPEN = "half_open" # Testing if service recovered + + +@dataclass +class CircuitBreakerConfig: + """Configuration for circuit breaker behavior.""" + + failure_threshold: int = 5 # Failures before opening circuit + success_threshold: int = 2 # Successes in half-open to close circuit + timeout_seconds: float = 30.0 # Time before transitioning open -> half-open + # Count failures in this time window (0 = no window, count all) + failure_window_seconds: float = 60.0 + + +@dataclass +class CircuitBreakerStats: + """Statistics for circuit breaker.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + rejected_requests: int = 0 # Requests rejected due to open circuit + state_transitions: int = 0 + + +class CircuitBreaker: + """Circuit breaker for protecting against cascading failures. + + States: + - CLOSED: Normal operation. Failures are counted. + - OPEN: Circuit tripped after too many failures. Requests fail fast. + - HALF_OPEN: After timeout, allow limited requests to test recovery. + """ + + def __init__( + self, + name: str, + config: CircuitBreakerConfig | None = None, + ) -> None: + self._name = name + self._config = config or CircuitBreakerConfig() + self._state = CircuitState.CLOSED + self._lock = threading.Lock() + + # Failure tracking + self._failure_timestamps: list[float] = [] + self._consecutive_successes = 0 + + # State tracking + self._last_failure_time: float = 0 + self._last_state_change_time: float = time.monotonic() + + # Statistics + self._stats = CircuitBreakerStats() + + @property + def state(self) -> CircuitState: + """Get current circuit state (may trigger state transition).""" + with self._lock: + self._check_state_transition() + return self._state + + @property + def stats(self) -> CircuitBreakerStats: + """Get circuit breaker statistics.""" + return self._stats + + @property + def is_closed(self) -> bool: + """Check if circuit is closed (allowing requests).""" + return self.state == CircuitState.CLOSED + + def allow_request(self) -> bool: + """Check if a request should be allowed through. + + Returns: + True if request should proceed, False if it should fail fast + """ + with self._lock: + self._check_state_transition() + self._stats.total_requests += 1 + + if self._state == CircuitState.CLOSED: + return True + elif self._state == CircuitState.HALF_OPEN: + # Allow limited requests in half-open state + return True + else: # OPEN + self._stats.rejected_requests += 1 + return False + + def record_success(self) -> None: + """Record a successful request.""" + with self._lock: + self._stats.successful_requests += 1 + + if self._state == CircuitState.HALF_OPEN: + self._consecutive_successes += 1 + if self._consecutive_successes >= self._config.success_threshold: + self._transition_to(CircuitState.CLOSED) + elif self._state == CircuitState.CLOSED: + self._prune_old_failures() + + def record_failure(self) -> None: + """Record a failed request.""" + now = time.monotonic() + + with self._lock: + self._stats.failed_requests += 1 + self._last_failure_time = now + self._failure_timestamps.append(now) + self._consecutive_successes = 0 + + if self._state == CircuitState.HALF_OPEN: + # Any failure in half-open trips the circuit again + self._transition_to(CircuitState.OPEN) + elif self._state == CircuitState.CLOSED: + self._prune_old_failures() + if len(self._failure_timestamps) >= self._config.failure_threshold: + self._transition_to(CircuitState.OPEN) + + def _prune_old_failures(self) -> None: + """Remove failures outside the time window.""" + if self._config.failure_window_seconds <= 0: + return + + cutoff = time.monotonic() - self._config.failure_window_seconds + self._failure_timestamps = [ts for ts in self._failure_timestamps if ts > cutoff] + + def _check_state_transition(self) -> None: + """Check if state should transition based on time.""" + if self._state == CircuitState.OPEN: + time_since_open = time.monotonic() - self._last_state_change_time + if time_since_open >= self._config.timeout_seconds: + self._transition_to(CircuitState.HALF_OPEN) + + def _transition_to(self, new_state: CircuitState) -> None: + """Transition to a new state.""" + from .metrics import get_metrics_collector + + old_state = self._state + self._state = new_state + self._last_state_change_time = time.monotonic() + self._stats.state_transitions += 1 + + if new_state == CircuitState.CLOSED: + self._failure_timestamps.clear() + self._consecutive_successes = 0 + # Notify metrics collector that circuit is healthy again + get_metrics_collector().notify_circuit_breaker_closed() + elif new_state == CircuitState.HALF_OPEN: + self._consecutive_successes = 0 + elif new_state == CircuitState.OPEN: + # Warn about circuit breaker opening + get_metrics_collector().warn_circuit_breaker_open() + + logger.info(f"Circuit breaker '{self._name}' transitioned: {old_state.value} -> {new_state.value}") + + def reset(self) -> None: + """Reset circuit breaker to initial closed state.""" + with self._lock: + self._state = CircuitState.CLOSED + self._failure_timestamps.clear() + self._consecutive_successes = 0 + self._last_state_change_time = time.monotonic() + logger.debug(f"Circuit breaker '{self._name}' reset to CLOSED") + + +class CircuitOpenError(Exception): + """Raised when circuit breaker is open and request is rejected.""" + + def __init__(self, circuit_name: str) -> None: + self.circuit_name = circuit_name + super().__init__(f"Circuit breaker '{circuit_name}' is open") diff --git a/drift/core/tracing/adapters/api.py b/drift/core/tracing/adapters/api.py index d8c0329..f08331b 100644 --- a/drift/core/tracing/adapters/api.py +++ b/drift/core/tracing/adapters/api.py @@ -2,15 +2,28 @@ This adapter uses betterproto to serialize protobuf messages to binary format and sends them directly to the Tusk backend over HTTP. + +Features: +- Retry with exponential backoff for transient failures +- Circuit breaker to prevent cascading failures +- Optional gzip compression for large payloads (currently disabled by default) """ from __future__ import annotations +import gzip import logging from dataclasses import dataclass from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING, Any, override +from ...resilience import ( + CircuitBreaker, + CircuitBreakerConfig, + CircuitOpenError, + RetryConfig, + retry_async, +) from .base import ExportResult, SpanExportAdapter if TYPE_CHECKING: @@ -22,6 +35,9 @@ DRIFT_API_PATH = "/api/drift/tusk.drift.backend.v1.SpanExportService/ExportSpans" +# Compression threshold (bytes) - only compress if payload exceeds this +COMPRESSION_THRESHOLD = 1024 # 1KB + @dataclass class ApiSpanAdapterConfig: @@ -34,6 +50,19 @@ class ApiSpanAdapterConfig: sdk_version: str sdk_instance_id: str + # Retry configuration + max_retry_attempts: int = 3 + initial_retry_delay_seconds: float = 0.1 + max_retry_delay_seconds: float = 10.0 + + # Circuit breaker configuration + circuit_failure_threshold: int = 5 + circuit_timeout_seconds: float = 30.0 + + # Compression (disabled by default, matches Node SDK behavior) + enable_compression: bool = False + compression_threshold: int = COMPRESSION_THRESHOLD + class ApiSpanAdapter(SpanExportAdapter): """ @@ -41,6 +70,11 @@ class ApiSpanAdapter(SpanExportAdapter): Uses betterproto to serialize protobuf messages to binary format and sends them directly to the backend over HTTP. + + Features: + - Automatic retry with exponential backoff for transient failures + - Circuit breaker to fail fast when backend is unavailable + - Optional gzip compression for large payloads """ def __init__(self, config: ApiSpanAdapterConfig) -> None: @@ -53,7 +87,32 @@ def __init__(self, config: ApiSpanAdapterConfig) -> None: self._config = config self._base_url = f"{config.tusk_backend_base_url}{DRIFT_API_PATH}" - logger.debug("ApiSpanAdapter initialized with native protobuf serialization") + # Initialize retry configuration + self._retry_config = RetryConfig( + max_attempts=config.max_retry_attempts, + initial_delay_seconds=config.initial_retry_delay_seconds, + max_delay_seconds=config.max_retry_delay_seconds, + ) + + # Initialize circuit breaker + self._circuit_breaker = CircuitBreaker( + name="api_export", + config=CircuitBreakerConfig( + failure_threshold=config.circuit_failure_threshold, + timeout_seconds=config.circuit_timeout_seconds, + ), + ) + + # Statistics + self._spans_exported = 0 + self._spans_failed = 0 + self._bytes_sent = 0 + self._bytes_compressed = 0 + + logger.debug( + f"ApiSpanAdapter initialized with native protobuf serialization " + f"(retry={config.max_retry_attempts}, compression={config.enable_compression})" + ) def __repr__(self) -> str: return f"ApiSpanAdapter(url={self._base_url}, env={self._config.environment})" @@ -63,61 +122,136 @@ def __repr__(self) -> str: def name(self) -> str: return "api" + @property + def spans_exported(self) -> int: + """Total number of spans successfully exported.""" + return self._spans_exported + + @property + def spans_failed(self) -> int: + """Total number of spans that failed to export.""" + return self._spans_failed + + @property + def bytes_sent(self) -> int: + """Total bytes sent (after compression if enabled).""" + return self._bytes_sent + + @property + def circuit_state(self) -> str: + """Current circuit breaker state.""" + return self._circuit_breaker.state.value + @override async def export_spans(self, spans: list[CleanSpanData]) -> ExportResult: - """Export spans to the Tusk backend API using native binary protobuf.""" + """Export spans to the Tusk backend API using native binary protobuf. + + Features: + - Circuit breaker to fail fast if backend is unavailable + - Retry with exponential backoff for transient failures + - Optional gzip compression for large payloads + """ + # Check circuit breaker first + if not self._circuit_breaker.allow_request(): + logger.warning(f"Circuit breaker is open, dropping {len(spans)} spans") + self._spans_failed += len(spans) + return ExportResult.failed(CircuitOpenError("api_export")) + try: - import aiohttp - from tusk.drift.backend.v1 import ExportSpansRequest, ExportSpansResponse - - proto_spans = [self._transform_span_to_protobuf(span) for span in spans] - - # Build the protobuf request - request = ExportSpansRequest( - observable_service_id=self._config.observable_service_id, - environment=self._config.environment, - sdk_version=self._config.sdk_version, - sdk_instance_id=self._config.sdk_instance_id, - spans=proto_spans, + # Define the export operation for retry + async def do_export() -> ExportResult: + return await self._do_export(spans) + + # Execute with retry + result = await retry_async( + do_export, + config=self._retry_config, + retryable_exceptions=(Exception,), # Retry all exceptions + operation_name="span_export", ) - request_bytes = bytes(request) + # Record success + self._circuit_breaker.record_success() + self._spans_exported += len(spans) + return result - headers = { - "Content-Type": "application/protobuf", - "Accept": "application/protobuf", - "x-api-key": self._config.api_key, - "x-td-skip-instrumentation": "true", - } + except Exception as error: + # Record failure for circuit breaker + self._circuit_breaker.record_failure() + self._spans_failed += len(spans) + logger.error(f"Failed to export spans after retries: {error}") + return ExportResult.failed(error if isinstance(error, Exception) else Exception(str(error))) + + async def _do_export(self, spans: list[CleanSpanData]) -> ExportResult: + """Perform the actual export to the backend.""" + import aiohttp + from tusk.drift.backend.v1 import ExportSpansRequest, ExportSpansResponse + + proto_spans = [self._transform_span_to_protobuf(span) for span in spans] + + # Build the protobuf request + request = ExportSpansRequest( + observable_service_id=self._config.observable_service_id, + environment=self._config.environment, + sdk_version=self._config.sdk_version, + sdk_instance_id=self._config.sdk_instance_id, + spans=proto_spans, + ) - async with ( - aiohttp.ClientSession() as session, - session.post(self._base_url, data=request_bytes, headers=headers) as http_response, - ): - if http_response.status != 200: - error_text = await http_response.text() - raise Exception(f"API request failed (status {http_response.status}): {error_text}") + request_bytes = bytes(request) + original_size = len(request_bytes) + + # Apply compression if enabled and payload is large enough + headers = { + "Accept": "application/protobuf", + "x-api-key": self._config.api_key, + "x-td-skip-instrumentation": "true", + } + + if self._config.enable_compression and original_size >= self._config.compression_threshold: + request_bytes = gzip.compress(request_bytes, compresslevel=6) + headers["Content-Type"] = "application/protobuf" + headers["Content-Encoding"] = "gzip" + self._bytes_compressed += original_size - len(request_bytes) + logger.debug( + f"Compressed {original_size} -> {len(request_bytes)} bytes " + f"({100 * len(request_bytes) / original_size:.1f}%)" + ) + else: + headers["Content-Type"] = "application/protobuf" - response_bytes = await http_response.read() - response = ExportSpansResponse().parse(response_bytes) + self._bytes_sent += len(request_bytes) - if not response.success: - raise Exception(f'API export reported failure: "{response.message}"') + async with ( + aiohttp.ClientSession() as session, + session.post(self._base_url, data=request_bytes, headers=headers) as http_response, + ): + if http_response.status >= 500: + # Server errors are retryable + error_text = await http_response.text() + raise Exception(f"Server error (status {http_response.status}): {error_text}") + elif http_response.status != 200: + # Client errors (4xx) are not retryable + error_text = await http_response.text() + raise Exception(f"API request failed (status {http_response.status}): {error_text}") - logger.debug(f"Successfully exported {len(spans)} spans to remote endpoint") - return ExportResult.success() + response_bytes = await http_response.read() + response = ExportSpansResponse().parse(response_bytes) - except ImportError as error: - logger.error("aiohttp is required for API adapter. Install it with: pip install aiohttp") - return ExportResult.failed(error) - except Exception as error: - logger.error("Failed to export spans to remote:", exc_info=error) - return ExportResult.failed(error if isinstance(error, Exception) else Exception("API export failed")) + if not response.success: + raise Exception(f'API export reported failure: "{response.message}"') + + logger.debug(f"Successfully exported {len(spans)} spans to remote endpoint") + return ExportResult.success() @override async def shutdown(self) -> None: """Shutdown and cleanup.""" - pass + logger.debug( + f"ApiSpanAdapter shutting down. " + f"Exported: {self._spans_exported}, Failed: {self._spans_failed}, " + f"Bytes sent: {self._bytes_sent}, Compressed: {self._bytes_compressed}" + ) def _transform_span_to_protobuf(self, clean_span: CleanSpanData) -> Any: """Transform CleanSpanData to protobuf Span format.""" @@ -266,6 +400,9 @@ def create_api_adapter( sdk_version: str = "0.1.0", sdk_instance_id: str | None = None, tusk_backend_base_url: str = "https://api.usetusk.ai", + *, + enable_compression: bool = False, + max_retry_attempts: int = 3, ) -> ApiSpanAdapter: """ Create an API span adapter with the given configuration. @@ -277,6 +414,8 @@ def create_api_adapter( sdk_version: Version of the SDK sdk_instance_id: Unique ID for this SDK instance (auto-generated if not provided) tusk_backend_base_url: Base URL for the Tusk backend + enable_compression: Whether to enable gzip compression for large payloads (disabled by default) + max_retry_attempts: Maximum number of retry attempts for failed exports Returns: Configured ApiSpanAdapter instance @@ -293,6 +432,8 @@ def create_api_adapter( environment=environment, sdk_version=sdk_version, sdk_instance_id=sdk_instance_id, + enable_compression=enable_compression, + max_retry_attempts=max_retry_attempts, ) return ApiSpanAdapter(config) diff --git a/drift/instrumentation/__init__.py b/drift/instrumentation/__init__.py index 0fbf732..3d6e09e 100644 --- a/drift/instrumentation/__init__.py +++ b/drift/instrumentation/__init__.py @@ -2,7 +2,7 @@ from .base import InstrumentationBase from .django import DjangoInstrumentation -from .registry import install_hooks, patch_instances_via_gc, register_patch +from .registry import install_hooks, register_patch from .wsgi import WsgiInstrumentation __all__ = [ @@ -11,5 +11,4 @@ "WsgiInstrumentation", "register_patch", "install_hooks", - "patch_instances_via_gc", ] diff --git a/drift/instrumentation/registry.py b/drift/instrumentation/registry.py index 8d5561f..ab9f74c 100644 --- a/drift/instrumentation/registry.py +++ b/drift/instrumentation/registry.py @@ -1,13 +1,11 @@ -import gc import importlib.abc import importlib.machinery import sys from collections.abc import Callable, Sequence from types import ModuleType -from typing import TypeVar, override +from typing import override PatchFn = Callable[[ModuleType], None] -T = TypeVar("T") _registry: dict[str, PatchFn] = {} _installed = False @@ -90,13 +88,3 @@ def _apply_patch(module: ModuleType, patch_fn: PatchFn) -> None: patch_fn(module) module.__drift_patched__ = True # type: ignore[attr-defined] - - -def patch_instances_via_gc[T](class_type: type, patch_instance_fn: Callable[[T], None]) -> None: - """Use gc to patch instances created before SDK initialization""" - for obj in gc.get_objects(): # pyright: ignore[reportAny] - if isinstance(obj, class_type): - obj: T = obj - if not getattr(obj, "__drift_instance_patched__", False): - patch_instance_fn(obj) - obj.__drift_instance_patched__ = True From 7cd433778e8f0aa1dc8d641e95709814ff6a2f59 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Wed, 14 Jan 2026 15:26:16 -0800 Subject: [PATCH 2/2] Fix --- drift/core/batch_processor.py | 28 +++++++++++++++------------- drift/core/metrics.py | 5 +++++ drift/core/resilience.py | 7 ++++++- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/drift/core/batch_processor.py b/drift/core/batch_processor.py index 3ebd01b..22dcc4e 100644 --- a/drift/core/batch_processor.py +++ b/drift/core/batch_processor.py @@ -119,23 +119,25 @@ def add_span(self, span: CleanSpanData) -> bool: Returns: True if span was added, False if queue is full or trace is blocked """ - # Check if span should be blocked (size limit or server error) - # Blocks entire trace from .trace_blocking_manager import TraceBlockingManager, should_block_span - if should_block_span(span): - self._dropped_spans += 1 - self._metrics.record_spans_dropped() - return False - - # Check if trace is already blocked - if TraceBlockingManager.get_instance().is_trace_blocked(span.trace_id): - logger.debug(f"Skipping span '{span.name}' - trace {span.trace_id} is blocked") - self._dropped_spans += 1 - self._metrics.record_spans_dropped() - return False + # Check blocking conditions outside lock (read-only checks) + is_blocked = should_block_span(span) + is_trace_blocked = TraceBlockingManager.get_instance().is_trace_blocked(span.trace_id) with self._condition: + # Handle blocked spans (increment counter under lock) + if is_blocked: + self._dropped_spans += 1 + self._metrics.record_spans_dropped() + return False + + if is_trace_blocked: + logger.debug(f"Skipping span '{span.name}' - trace {span.trace_id} is blocked") + self._dropped_spans += 1 + self._metrics.record_spans_dropped() + return False + if len(self._queue) >= self._config.max_queue_size: self._dropped_spans += 1 self._metrics.record_spans_dropped() diff --git a/drift/core/metrics.py b/drift/core/metrics.py index 5581a04..dd0f3d5 100644 --- a/drift/core/metrics.py +++ b/drift/core/metrics.py @@ -282,6 +282,11 @@ def reset(self) -> None: self._queue_peak_size = 0 self._start_time = time.monotonic() + self._warned_high_drop_rate = False + self._warned_high_failure_rate = False + self._warned_queue_capacity = False + self._warned_circuit_open = False + # Global metrics collector instance _metrics_collector: MetricsCollector | None = None diff --git a/drift/core/resilience.py b/drift/core/resilience.py index 428e753..2613db2 100644 --- a/drift/core/resilience.py +++ b/drift/core/resilience.py @@ -55,6 +55,7 @@ async def retry_async[T]( operation: Callable[[], Awaitable[T]], config: RetryConfig | None = None, retryable_exceptions: tuple[type[Exception], ...] = (Exception,), + non_retryable_exceptions: tuple[type[Exception], ...] = (), operation_name: str = "operation", ) -> T: """Execute an async operation with retry and exponential backoff. @@ -63,13 +64,14 @@ async def retry_async[T]( operation: Async callable to execute config: Retry configuration (uses defaults if None) retryable_exceptions: Tuple of exception types that trigger retry + non_retryable_exceptions: Tuple of exception types that should fail immediately operation_name: Name for logging purposes Returns: Result of the operation Raises: - The last exception if all retries are exhausted + The last exception if all retries are exhausted, or immediately for non-retryable """ config = config or RetryConfig() last_exception: Exception | None = None @@ -77,6 +79,9 @@ async def retry_async[T]( for attempt in range(1, config.max_attempts + 1): try: return await operation() + except non_retryable_exceptions: + # Don't retry these - re-raise immediately + raise except retryable_exceptions as e: last_exception = e