Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 140 additions & 64 deletions sentry_sdk/integrations/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
)
from sentry_sdk.integrations.celery.utils import _now_seconds_since_epoch
from sentry_sdk.integrations.logging import ignore_logger
from sentry_sdk.traces import StreamedSpan
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME, Span, TransactionSource
from sentry_sdk.tracing_utils import Baggage
from sentry_sdk.tracing_utils import Baggage, has_span_streaming_enabled
from sentry_sdk.utils import (
capture_internal_exceptions,
ensure_integration_enabled,
event_from_exception,
reraise,
)
Expand Down Expand Up @@ -162,7 +162,9 @@


def _update_celery_task_headers(
original_headers: "dict[str, Any]", span: "Optional[Span]", monitor_beat_tasks: bool
original_headers: "dict[str, Any]",
span: "Optional[Union[StreamedSpan, Span]]",
monitor_beat_tasks: bool,
) -> "dict[str, Any]":
"""
Updates the headers of the Celery task with the tracing information
Expand Down Expand Up @@ -255,7 +257,8 @@
def apply_async(*args: "Any", **kwargs: "Any") -> "Any":
# Note: kwargs can contain headers=None, so no setdefault!
# Unsure which backend though.
integration = sentry_sdk.get_client().get_integration(CeleryIntegration)
client = sentry_sdk.get_client()
integration = client.get_integration(CeleryIntegration)
if integration is None:
return f(*args, **kwargs)

Expand All @@ -274,17 +277,28 @@
else:
task_name = "<unknown Celery task>"

span_streaming = has_span_streaming_enabled(client.options)

task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat"

span_mgr: "Union[Span, NoOpMgr]" = (
sentry_sdk.start_span(
op=OP.QUEUE_SUBMIT_CELERY,
name=task_name,
origin=CeleryIntegration.origin,
)
if not task_started_from_beat
else NoOpMgr()
)
span_mgr: "Union[StreamedSpan, Span, NoOpMgr]" = NoOpMgr()
if span_streaming:
if not task_started_from_beat:
span_mgr = sentry_sdk.traces.start_span(
name=task_name,
attributes={
"sentry.op": OP.QUEUE_SUBMIT_CELERY,
"sentry.origin": CeleryIntegration.origin,
},
)

else:
if not task_started_from_beat:
span_mgr = sentry_sdk.start_span(
op=OP.QUEUE_SUBMIT_CELERY,
name=task_name,
origin=CeleryIntegration.origin,
)

with span_mgr as span:
kwargs["headers"] = _update_celery_task_headers(
Expand All @@ -303,50 +317,73 @@
# Also because in Celery 3, signal dispatch returns early if one handler
# crashes.
@wraps(f)
@ensure_integration_enabled(CeleryIntegration, f)
def _inner(*args: "Any", **kwargs: "Any") -> "Any":
client = sentry_sdk.get_client()
if client.get_integration(CeleryIntegration) is None:
return f(*args, **kwargs)

span_streaming = has_span_streaming_enabled(client.options)

with isolation_scope() as scope:
scope._name = "celery"
scope.clear_breadcrumbs()
scope.add_event_processor(_make_event_processor(task, *args, **kwargs))

transaction = None
transaction: "Optional[Union[Span, StreamedSpan]]" = None
span_ctx: "Optional[Union[Span, StreamedSpan]]" = None

# Celery task objects are not a thing to be trusted. Even
# something such as attribute access can fail.
with capture_internal_exceptions():
headers = args[3].get("headers") or {}
transaction = continue_trace(
headers,
op=OP.QUEUE_TASK_CELERY,
name="unknown celery task",
source=TransactionSource.TASK,
origin=CeleryIntegration.origin,
)
transaction.name = task.name
transaction.set_status(SPANSTATUS.OK)
if span_streaming:
sentry_sdk.traces.continue_trace(headers)
transaction = sentry_sdk.traces.start_span(
name=task.name,
attributes={
"sentry.origin": CeleryIntegration.origin,
"sentry.span.source": TransactionSource.TASK.value,
"sentry.op": OP.QUEUE_TASK_CELERY,
},
)

if transaction is None:
span_ctx = transaction

else:
transaction = continue_trace(
headers,
op=OP.QUEUE_TASK_CELERY,
name=task.name,
source=TransactionSource.TASK,
origin=CeleryIntegration.origin,
)
transaction.set_status(SPANSTATUS.OK)

span_ctx = sentry_sdk.start_transaction(
transaction,
custom_sampling_context={
"celery_job": {
"task": task.name,
# for some reason, args[1] is a list if non-empty but a
# tuple if empty
"args": list(args[1]),
"kwargs": args[2],
}
},
)

if transaction is None or span_ctx is None:
return f(*args, **kwargs)

with sentry_sdk.start_transaction(
transaction,
custom_sampling_context={
"celery_job": {
"task": task.name,
# for some reason, args[1] is a list if non-empty but a
# tuple if empty
"args": list(args[1]),
"kwargs": args[2],
}
},
):
with span_ctx:
return f(*args, **kwargs)

return _inner # type: ignore


def _set_messaging_destination_name(task: "Any", span: "Span") -> None:
def _set_messaging_destination_name(
task: "Any", span: "Union[StreamedSpan, Span]"
) -> None:
"""Set "messaging.destination.name" tag for span"""
with capture_internal_exceptions():
delivery_info = task.request.delivery_info
Expand All @@ -355,26 +392,43 @@
if delivery_info.get("exchange") == "" and routing_key is not None:
# Empty exchange indicates the default exchange, meaning the tasks
# are sent to the queue with the same name as the routing key.
span.set_data(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)
if isinstance(span, StreamedSpan):
span.set_attribute(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)
else:
span.set_data(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)


def _wrap_task_call(task: "Any", f: "F") -> "F":
# Need to wrap task call because the exception is caught before we get to
# see it. Also celery's reported stacktrace is untrustworthy.

# functools.wraps is important here because celery-once looks at this
# method's name. @ensure_integration_enabled internally calls functools.wraps,
# but if we ever remove the @ensure_integration_enabled decorator, we need
# to add @functools.wraps(f) here.
# https://github.com/getsentry/sentry-python/issues/421
@ensure_integration_enabled(CeleryIntegration, f)
@wraps(f)
def _inner(*args: "Any", **kwargs: "Any") -> "Any":
client = sentry_sdk.get_client()
if client.get_integration(CeleryIntegration) is None:
return f(*args, **kwargs)

span_streaming = has_span_streaming_enabled(client.options)

try:
with sentry_sdk.start_span(
op=OP.QUEUE_PROCESS,
name=task.name,
origin=CeleryIntegration.origin,
) as span:
span: "Union[Span, StreamedSpan]"
if span_streaming:
span = sentry_sdk.traces.start_span(name=task.name)
span.set_attribute("sentry.op", OP.QUEUE_PROCESS)
span.set_attribute("sentry.origin", CeleryIntegration.origin)
else:
span = sentry_sdk.start_span(
op=OP.QUEUE_PROCESS,
name=task.name,
origin=CeleryIntegration.origin,
)

with span:
if isinstance(span, StreamedSpan):
set_on_span = span.set_attribute
else:
set_on_span = span.set_data

_set_messaging_destination_name(task, span)

latency = None
Expand All @@ -389,19 +443,19 @@

if latency is not None:
latency *= 1000 # milliseconds
span.set_data(SPANDATA.MESSAGING_MESSAGE_RECEIVE_LATENCY, latency)
set_on_span(SPANDATA.MESSAGING_MESSAGE_RECEIVE_LATENCY, latency)

with capture_internal_exceptions():
span.set_data(SPANDATA.MESSAGING_MESSAGE_ID, task.request.id)
set_on_span(SPANDATA.MESSAGING_MESSAGE_ID, task.request.id)

with capture_internal_exceptions():
span.set_data(
set_on_span(
SPANDATA.MESSAGING_MESSAGE_RETRY_COUNT, task.request.retries
)

with capture_internal_exceptions():
with task.app.connection() as conn:
span.set_data(
set_on_span(
SPANDATA.MESSAGING_SYSTEM,
conn.transport.driver_type,
)
Expand Down Expand Up @@ -476,8 +530,13 @@
def _patch_producer_publish() -> None:
original_publish = Producer.publish

@ensure_integration_enabled(CeleryIntegration, original_publish)
def sentry_publish(self: "Producer", *args: "Any", **kwargs: "Any") -> "Any":
client = sentry_sdk.get_client()
if client.get_integration(CeleryIntegration) is None:
return original_publish(self, *args, **kwargs)

span_streaming = has_span_streaming_enabled(client.options)

kwargs_headers = kwargs.get("headers", {})
if not isinstance(kwargs_headers, Mapping):
# Ensure kwargs_headers is a Mapping, so we can safely call get().
Expand All @@ -487,31 +546,48 @@
# method will still work.
kwargs_headers = {}

if "task" not in kwargs_headers:
# filter out heartbeat and other internal Celery events
return original_publish(self, *args, **kwargs)

task_name = kwargs_headers.get("task")
task_id = kwargs_headers.get("id")
retries = kwargs_headers.get("retries")

routing_key = kwargs.get("routing_key")
exchange = kwargs.get("exchange")

with sentry_sdk.start_span(
op=OP.QUEUE_PUBLISH,
name=task_name,
origin=CeleryIntegration.origin,
) as span:
span: "Union[StreamedSpan, Span]"
if span_streaming:
span = sentry_sdk.traces.start_span(name=task_name)
span.set_attribute("sentry.op", OP.QUEUE_PUBLISH)
span.set_attribute("sentry.origin", CeleryIntegration.origin)

Check warning on line 564 in sentry_sdk/integrations/celery/__init__.py

View check run for this annotation

@sentry/warden / warden: code-review

Missing test coverage for span streaming in producer publish

The new span streaming functionality in `_patch_producer_publish` lacks test coverage. The ASGI integration has comprehensive parameterized tests that validate both streaming and non-streaming modes, but no equivalent tests exist for the Celery producer publish flow. Without tests, regressions in span creation or attribute setting for message publishing could go undetected.
Comment thread
sentry-warden[bot] marked this conversation as resolved.
Outdated
else:
span = sentry_sdk.start_span(
op=OP.QUEUE_PUBLISH,
name=task_name,
origin=CeleryIntegration.origin,
)

with span:
if isinstance(span, StreamedSpan):
set_on_span = span.set_attribute
else:
set_on_span = span.set_data

if task_id is not None:
span.set_data(SPANDATA.MESSAGING_MESSAGE_ID, task_id)
set_on_span(SPANDATA.MESSAGING_MESSAGE_ID, task_id)

if exchange == "" and routing_key is not None:
# Empty exchange indicates the default exchange, meaning messages are
# routed to the queue with the same name as the routing key.
span.set_data(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)
set_on_span(SPANDATA.MESSAGING_DESTINATION_NAME, routing_key)

if retries is not None:
span.set_data(SPANDATA.MESSAGING_MESSAGE_RETRY_COUNT, retries)
set_on_span(SPANDATA.MESSAGING_MESSAGE_RETRY_COUNT, retries)

with capture_internal_exceptions():
span.set_data(
set_on_span(
SPANDATA.MESSAGING_SYSTEM, self.connection.transport.driver_type
)

Expand Down
Loading