Skip to content

Commit 2ed1533

Browse files
replace dojo_async_task decorator with class+helper
1 parent dd5a95f commit 2ed1533

24 files changed

Lines changed: 174 additions & 186 deletions

File tree

dojo/api_v2/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from dojo.api_v2.prefetch.prefetcher import _Prefetcher
4848
from dojo.authorization.roles_permissions import Permissions
49+
from dojo.celery_dispatch import dojo_dispatch_task
4950
from dojo.cred.queries import get_authorized_cred_mappings
5051
from dojo.endpoint.queries import (
5152
get_authorized_endpoint_status,
@@ -678,13 +679,13 @@ def update_jira_epic(self, request, pk=None):
678679
try:
679680

680681
if engagement.has_jira_issue:
681-
jira_helper.update_epic(engagement.id, **request.data)
682+
dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data)
682683
response = Response(
683684
{"info": "Jira Epic update query sent"},
684685
status=status.HTTP_200_OK,
685686
)
686687
else:
687-
jira_helper.add_epic(engagement.id, **request.data)
688+
dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data)
688689
response = Response(
689690
{"info": "Jira Epic create query sent"},
690691
status=status.HTTP_200_OK,

dojo/celery.py

Lines changed: 5 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class DojoAsyncTask(Task):
5252
This class:
5353
- Injects user context into task kwargs
5454
- Tracks task calls for performance testing
55-
- Handles sync/async execution based on user settings
5655
- Supports all Celery features (signatures, chords, groups, chains)
5756
"""
5857

@@ -68,97 +67,18 @@ def apply_async(self, args=None, kwargs=None, **options):
6867
if "async_user" not in kwargs:
6968
kwargs["async_user"] = get_current_user()
7069

71-
# Track task call (only if not already tracked by __call__)
72-
# Check if this is a direct call to apply_async (not from __call__)
73-
# by checking if _dojo_tracked is not set
74-
if not getattr(self, "_dojo_tracked", False):
75-
dojo_async_task_counter.incr(
76-
self.name,
77-
args=args,
78-
kwargs=kwargs,
79-
)
70+
# Control flag used for sync/async decision; never pass into the task itself
71+
kwargs.pop("sync", None)
8072

81-
# Call parent to execute async
82-
return super().apply_async(args=args, kwargs=kwargs, **options)
83-
84-
def s(self, *args, **kwargs):
85-
"""Create a mutable signature with injected user context."""
86-
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
87-
from dojo.utils import get_current_user # noqa: PLC0415 circular import
88-
89-
if "async_user" not in kwargs:
90-
kwargs["async_user"] = get_current_user()
91-
92-
# Track task call
93-
dojo_async_task_counter.incr(
94-
self.name,
95-
args=args,
96-
kwargs=kwargs,
97-
)
98-
99-
return super().s(*args, **kwargs)
100-
101-
def si(self, *args, **kwargs):
102-
"""Create an immutable signature with injected user context."""
103-
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
104-
from dojo.utils import get_current_user # noqa: PLC0415 circular import
105-
106-
if "async_user" not in kwargs:
107-
kwargs["async_user"] = get_current_user()
108-
109-
# Track task call
110-
dojo_async_task_counter.incr(
111-
self.name,
112-
args=args,
113-
kwargs=kwargs,
114-
)
115-
116-
return super().si(*args, **kwargs)
117-
118-
def __call__(self, *args, **kwargs):
119-
"""
120-
Override __call__ to handle direct task calls with sync/async logic.
121-
122-
This replicates the behavior of the dojo_async_task decorator wrapper.
123-
"""
124-
# In Celery worker execution, __call__ is how tasks actually run.
125-
# We only want the sync/async decision when tasks are called directly
126-
# from application code (task(...)), not when the worker is executing a message.
127-
if not getattr(self.request, "called_directly", True):
128-
return super().__call__(*args, **kwargs)
129-
130-
from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import
131-
from dojo.utils import get_current_user # noqa: PLC0415 circular import
132-
133-
# Inject user context if not already present
134-
if "async_user" not in kwargs:
135-
kwargs["async_user"] = get_current_user()
136-
137-
# Track task call
73+
# Track dispatch
13874
dojo_async_task_counter.incr(
13975
self.name,
14076
args=args,
14177
kwargs=kwargs,
14278
)
14379

144-
# Extract countdown if present (don't pass to sync execution)
145-
countdown = kwargs.pop("countdown", 0)
146-
147-
# Check if we should run async or sync
148-
if we_want_async(*args, func=self, **kwargs):
149-
# Mark as tracked to avoid double tracking in apply_async
150-
self._dojo_tracked = True
151-
try:
152-
# Run asynchronously
153-
return self.apply_async(args=args, kwargs=kwargs, countdown=countdown)
154-
finally:
155-
# Clean up the flag
156-
delattr(self, "_dojo_tracked")
157-
else:
158-
# Run synchronously in-process, matching the original decorator behavior: func(*args, **kwargs)
159-
# Remove sync from kwargs as it's a control flag, not a task argument.
160-
kwargs.pop("sync", None)
161-
return self.run(*args, **kwargs)
80+
# Call parent to execute async
81+
return super().apply_async(args=args, kwargs=kwargs, **options)
16282

16383

16484
@app.task(bind=True)

dojo/celery_dispatch.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol, cast
4+
5+
from celery.canvas import Signature
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Mapping
9+
10+
11+
class _SupportsSi(Protocol):
12+
def si(self, *args: Any, **kwargs: Any) -> Signature: ...
13+
14+
15+
class _SupportsApplyAsync(Protocol):
16+
def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ...
17+
18+
19+
def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
20+
result: dict[str, Any] = dict(kwargs or {})
21+
if "async_user" not in result:
22+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
23+
24+
result["async_user"] = get_current_user()
25+
return result
26+
27+
28+
def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature:
29+
"""
30+
Build a Celery signature with DefectDojo user context injected.
31+
32+
- If passed a task, returns `task_or_sig.si(*args, **kwargs)`.
33+
- If passed an existing signature, returns a cloned signature with merged kwargs.
34+
"""
35+
injected = _inject_async_user(kwargs)
36+
injected.pop("countdown", None)
37+
38+
if isinstance(task_or_sig, Signature):
39+
merged_kwargs = {**(task_or_sig.kwargs or {}), **injected}
40+
return task_or_sig.clone(kwargs=merged_kwargs)
41+
42+
return task_or_sig.si(*args, **injected)
43+
44+
45+
def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any:
46+
"""
47+
Dispatch a task/signature using DefectDojo semantics.
48+
49+
- Inject `async_user` if missing.
50+
- Respect `sync=True` (foreground execution) and user `block_execution`.
51+
- Support `countdown=<seconds>` for async dispatch.
52+
53+
Returns:
54+
- async: AsyncResult-like return from Celery
55+
- sync: underlying return value of the task
56+
57+
"""
58+
from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import
59+
60+
countdown = cast("int", kwargs.pop("countdown", 0))
61+
injected = _inject_async_user(kwargs)
62+
63+
sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected)
64+
sig_kwargs = dict(sig.kwargs or {})
65+
66+
if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs):
67+
# DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here.
68+
return sig.apply_async(countdown=countdown)
69+
70+
# Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior)
71+
dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs)
72+
73+
sig_kwargs.pop("sync", None)
74+
sig = sig.clone(kwargs=sig_kwargs)
75+
eager = sig.apply()
76+
return eager.get(propagate=True)

dojo/endpoint/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dojo.authorization.authorization import user_has_permission_or_403
1919
from dojo.authorization.authorization_decorators import user_is_authorized
2020
from dojo.authorization.roles_permissions import Permissions
21+
from dojo.celery_dispatch import dojo_dispatch_task
2122
from dojo.endpoint.queries import get_authorized_endpoints
2223
from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import
2324
from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups
@@ -373,7 +374,7 @@ def endpoint_bulk_update_all(request, pid=None):
373374
product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct())
374375
endpoints.delete()
375376
for prod in product_calc:
376-
calculate_grade(prod.id)
377+
dojo_dispatch_task(calculate_grade, prod.id)
377378

378379
if skipped_endpoint_count > 0:
379380
add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.")

dojo/engagement/services.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.dispatch import receiver
66

77
import dojo.jira_link.helper as jira_helper
8+
from dojo.celery_dispatch import dojo_dispatch_task
89
from dojo.models import Engagement
910

1011
logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ def close_engagement(eng):
1617
eng.save()
1718

1819
if jira_helper.get_jira_project(eng):
19-
jira_helper.close_epic(eng.id, push_to_jira=True)
20+
dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True)
2021

2122

2223
def reopen_engagement(eng):

dojo/engagement/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from dojo.authorization.authorization import user_has_permission_or_403
3838
from dojo.authorization.authorization_decorators import user_is_authorized
3939
from dojo.authorization.roles_permissions import Permissions
40+
from dojo.celery_dispatch import dojo_dispatch_task
4041
from dojo.endpoint.utils import save_endpoints_to_add
4142
from dojo.engagement.queries import get_authorized_engagements
4243
from dojo.engagement.services import close_engagement, reopen_engagement
@@ -390,7 +391,7 @@ def copy_engagement(request, eid):
390391
form = DoneForm(request.POST)
391392
if form.is_valid():
392393
engagement_copy = engagement.copy()
393-
calculate_grade(product.id)
394+
dojo_dispatch_task(calculate_grade, product.id)
394395
messages.add_message(
395396
request,
396397
messages.SUCCESS,

dojo/finding/helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
434434

435435
if product_grading_option:
436436
if system_settings.enable_product_grade:
437-
calculate_grade(finding.test.engagement.product.id)
437+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
438+
439+
dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id)
438440
else:
439441
deduplicationLogger.debug("skipping product grading because it's disabled in system settings")
440442

@@ -493,7 +495,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op
493495
tool_issue_updater.async_tool_issue_update(finding)
494496

495497
if product_grading_option and system_settings.enable_product_grade:
496-
calculate_grade(findings[0].test.engagement.product.id)
498+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
499+
500+
dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id)
497501

498502
if push_to_jira:
499503
for finding in findings:

dojo/finding/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
user_is_authorized,
3939
)
4040
from dojo.authorization.roles_permissions import Permissions
41+
from dojo.celery_dispatch import dojo_dispatch_task
4142
from dojo.filters import (
4243
AcceptedFindingFilter,
4344
AcceptedFindingFilterWithoutObjectLookups,
@@ -1082,7 +1083,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict):
10821083
product = finding.test.engagement.product
10831084
finding.delete()
10841085
# Update the grade of the product async
1085-
calculate_grade(product.id)
1086+
dojo_dispatch_task(calculate_grade, product.id)
10861087
# Add a message to the request that the finding was successfully deleted
10871088
messages.add_message(
10881089
request,
@@ -1353,7 +1354,7 @@ def copy_finding(request, fid):
13531354
test = form.cleaned_data.get("test")
13541355
product = finding.test.engagement.product
13551356
finding_copy = finding.copy(test=test)
1356-
calculate_grade(product.id)
1357+
dojo_dispatch_task(calculate_grade, product.id)
13571358
messages.add_message(
13581359
request,
13591360
messages.SUCCESS,

dojo/finding_group/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from dojo.authorization.authorization import user_has_permission_or_403
1717
from dojo.authorization.authorization_decorators import user_is_authorized
1818
from dojo.authorization.roles_permissions import Permissions
19+
from dojo.celery_dispatch import dojo_dispatch_task
1920
from dojo.filters import (
2021
FindingFilter,
2122
FindingFilterWithoutObjectLookups,
@@ -100,7 +101,7 @@ def view_finding_group(request, fgid):
100101
elif not finding_group.has_jira_issue:
101102
jira_helper.finding_group_link_jira(request, finding_group, jira_issue)
102103
elif push_to_jira:
103-
jira_helper.push_to_jira(finding_group, sync=True)
104+
dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True)
104105

105106
finding_group.save()
106107
return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,)))
@@ -200,7 +201,7 @@ def push_to_jira(request, fgid):
200201

201202
# it may look like success here, but the push_to_jira are swallowing exceptions
202203
# but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors.
203-
if jira_helper.push_to_jira(group, sync=True):
204+
if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True):
204205
messages.add_message(
205206
request,
206207
messages.SUCCESS,

0 commit comments

Comments
 (0)