Skip to content

Commit 5afe3c0

Browse files
pghistory: pass on context to celery tasks
1 parent d8c842a commit 5afe3c0

3 files changed

Lines changed: 132 additions & 2 deletions

File tree

dojo/celery.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from logging.config import dictConfig
44

5-
from celery import Celery
5+
from celery import Celery, Task
66
from celery.signals import setup_logging
77
from django.conf import settings
88

@@ -11,7 +11,31 @@
1111
# set the default Django settings module for the 'celery' program.
1212
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings")
1313

14-
app = Celery("dojo")
14+
15+
class PgHistoryTask(Task):

    """
    Celery base task that restores the dispatcher's pghistory context.

    dojo_async_task captures the active pghistory context at dispatch time
    and ships it in the task kwargs under "_pgh_context". This class removes
    that entry before the task body receives its kwargs and runs the body
    inside the recreated context, so database events written by the worker
    share the context of the originating request.
    """

    def __call__(self, *args, **kwargs):
        # Deferred import: a module-level import would be circular while
        # Celery is starting up.
        from dojo.pghistory_utils import get_pghistory_context_manager  # noqa: PLC0415

        # Strip the transport-only key so the task function never sees it.
        captured_context = kwargs.pop("_pgh_context", None)
        context_manager = get_pghistory_context_manager(captured_context)

        with context_manager:
            return super().__call__(*args, **kwargs)
36+
37+
38+
# Pass PgHistoryTask as task_cls so the app's tasks use it as their base class.
app = Celery("dojo", task_cls=PgHistoryTask)
1539

1640
# Using a string here means the worker will not have to
1741
# pickle the object when using Windows.

dojo/decorators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,17 @@ def dojo_async_task(func=None, *, signature=False):
8383
def decorator(func):
8484
@wraps(func)
8585
def __wrapper__(*args, **kwargs):
86+
from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import
8687
from dojo.utils import get_current_user # noqa: PLC0415 circular import
88+
8789
user = get_current_user()
8890
kwargs["async_user"] = user
8991

92+
# Capture pghistory context to pass to Celery worker
93+
# The PgHistoryTask base class will apply this context in the worker
94+
if pgh_context := get_serializable_pghistory_context():
95+
kwargs["_pgh_context"] = pgh_context
96+
9097
dojo_async_task_counter.incr(
9198
func.__name__,
9299
args=args,

dojo/pghistory_utils.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
Utilities for passing pghistory context to Celery tasks.
3+
4+
pghistory uses thread-local storage, so context is lost when tasks run
5+
in Celery workers. These utilities allow capturing context in the sender
6+
process and recreating it in the worker.
7+
"""
8+
import uuid
9+
from contextlib import nullcontext
10+
11+
from pghistory import runtime as pghistory_runtime
12+
13+
14+
def get_serializable_pghistory_context():
    """
    Snapshot the active pghistory context so it can travel with a Celery task.

    Returns a dict holding the context id as a string under "id" and a
    shallow copy of the context metadata under "metadata". Returns None
    when no pghistory context is active.
    """
    tracker = pghistory_runtime._tracker
    try:
        active = tracker.value
    except AttributeError:
        # The tracker only carries a "value" attribute while a context is active.
        return None

    snapshot = {}
    snapshot["id"] = str(active.id)
    snapshot["metadata"] = active.metadata.copy()
    return snapshot
28+
29+
30+
class PgHistoryContextFromTask:

    """
    Context manager that re-applies a pghistory context received with a Celery task.

    When no pghistory context is active, a new one is installed with the
    UUID that was captured at dispatch time, so events written by the task
    share the dispatcher's pgh_context_id. When a context is already
    active, it is kept and the received metadata is merged into it.

    Usage:
        pgh_context = kwargs.pop("_pgh_context", None)
        with PgHistoryContextFromTask(pgh_context):
            # Task body runs here with context applied
    """

    def __init__(self, context_data):
        """
        Initialize with context data from Celery kwargs.

        Args:
            context_data: Dict with "id" (UUID string) and optionally
                "metadata" (dict), or None/empty for no-op behavior.

        """
        self.context_data = context_data
        # Execute-wrapper context manager; set only while this instance
        # owns the active pghistory context.
        self._pre_execute_hook = None
        self._owns_context = False

    def __enter__(self):
        """
        Apply the received context.

        Returns the active pghistory context, or None when there is no
        context data to apply. Raises KeyError if "id" is missing and
        ValueError if it is not a valid UUID string.
        """
        if not self.context_data:
            return None

        from django.db import connection  # noqa: PLC0415

        context_id = uuid.UUID(self.context_data["id"])
        # Copy the metadata so that updates made to the live context while
        # the task runs never mutate the payload dict that was handed in
        # (the same payload may be reused, e.g. on retry). A missing or
        # None "metadata" entry is treated as empty.
        metadata = dict(self.context_data.get("metadata") or {})

        tracker = pghistory_runtime._tracker
        if hasattr(tracker, "value"):
            # Context already exists, just merge metadata
            tracker.value.metadata.update(metadata)
        else:
            # No active context: install the execute wrapper first, then
            # publish a context carrying the dispatcher's id.
            self._pre_execute_hook = connection.execute_wrapper(
                pghistory_runtime._inject_history_context,
            )
            self._pre_execute_hook.__enter__()
            tracker.value = pghistory_runtime.Context(
                id=context_id,
                metadata=metadata,
            )
            self._owns_context = True

        return tracker.value

    def __exit__(self, *exc):
        """Tear down the context only if this instance created it."""
        if self._owns_context and self._pre_execute_hook:
            hook = self._pre_execute_hook
            # Reset ownership before tearing down so that a reused instance
            # can never remove a context it did not create.
            self._pre_execute_hook = None
            self._owns_context = False
            delattr(pghistory_runtime._tracker, "value")
            hook.__exit__(*exc)
88+
89+
90+
def get_pghistory_context_manager(context_data):
    """
    Pick the context manager to wrap a task body with.

    Truthy context_data yields a PgHistoryContextFromTask built from it;
    anything falsy (None, empty dict) yields a no-op nullcontext.
    """
    return PgHistoryContextFromTask(context_data) if context_data else nullcontext()

0 commit comments

Comments
 (0)