Skip to content

Commit dd5a95f

Browse files
initial base task
1 parent 8b90d52 commit dd5a95f

11 files changed

Lines changed: 166 additions & 102 deletions

File tree

dojo/celery.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,123 @@ def __call__(self, *args, **kwargs):
4444
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
4545

4646

47+
class DojoAsyncTask(Task):
48+
49+
"""
50+
Base task class that provides dojo_async_task functionality without using a decorator.
51+
52+
This class:
53+
- Injects user context into task kwargs
54+
- Tracks task calls for performance testing
55+
- Handles sync/async execution based on user settings
56+
- Supports all Celery features (signatures, chords, groups, chains)
57+
"""
58+
59+
def apply_async(self, args=None, kwargs=None, **options):
60+
"""Override apply_async to inject user context and track tasks."""
61+
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
62+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
63+
64+
if kwargs is None:
65+
kwargs = {}
66+
67+
# Inject user context if not already present
68+
if "async_user" not in kwargs:
69+
kwargs["async_user"] = get_current_user()
70+
71+
# Track task call (only if not already tracked by __call__)
72+
# Check if this is a direct call to apply_async (not from __call__)
73+
# by checking if _dojo_tracked is not set
74+
if not getattr(self, "_dojo_tracked", False):
75+
dojo_async_task_counter.incr(
76+
self.name,
77+
args=args,
78+
kwargs=kwargs,
79+
)
80+
81+
# Call parent to execute async
82+
return super().apply_async(args=args, kwargs=kwargs, **options)
83+
84+
def s(self, *args, **kwargs):
85+
"""Create a mutable signature with injected user context."""
86+
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
87+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
88+
89+
if "async_user" not in kwargs:
90+
kwargs["async_user"] = get_current_user()
91+
92+
# Track task call
93+
dojo_async_task_counter.incr(
94+
self.name,
95+
args=args,
96+
kwargs=kwargs,
97+
)
98+
99+
return super().s(*args, **kwargs)
100+
101+
def si(self, *args, **kwargs):
102+
"""Create an immutable signature with injected user context."""
103+
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
104+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
105+
106+
if "async_user" not in kwargs:
107+
kwargs["async_user"] = get_current_user()
108+
109+
# Track task call
110+
dojo_async_task_counter.incr(
111+
self.name,
112+
args=args,
113+
kwargs=kwargs,
114+
)
115+
116+
return super().si(*args, **kwargs)
117+
118+
def __call__(self, *args, **kwargs):
119+
"""
120+
Override __call__ to handle direct task calls with sync/async logic.
121+
122+
This replicates the behavior of the dojo_async_task decorator wrapper.
123+
"""
124+
# In Celery worker execution, __call__ is how tasks actually run.
125+
# We only want the sync/async decision when tasks are called directly
126+
# from application code (task(...)), not when the worker is executing a message.
127+
if not getattr(self.request, "called_directly", True):
128+
return super().__call__(*args, **kwargs)
129+
130+
from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import
131+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
132+
133+
# Inject user context if not already present
134+
if "async_user" not in kwargs:
135+
kwargs["async_user"] = get_current_user()
136+
137+
# Track task call
138+
dojo_async_task_counter.incr(
139+
self.name,
140+
args=args,
141+
kwargs=kwargs,
142+
)
143+
144+
# Extract countdown if present (don't pass to sync execution)
145+
countdown = kwargs.pop("countdown", 0)
146+
147+
# Check if we should run async or sync
148+
if we_want_async(*args, func=self, **kwargs):
149+
# Mark as tracked to avoid double tracking in apply_async
150+
self._dojo_tracked = True
151+
try:
152+
# Run asynchronously
153+
return self.apply_async(args=args, kwargs=kwargs, countdown=countdown)
154+
finally:
155+
# Clean up the flag
156+
delattr(self, "_dojo_tracked")
157+
else:
158+
# Run synchronously in-process, matching the original decorator behavior: func(*args, **kwargs)
159+
# Remove sync from kwargs as it's a control flag, not a task argument.
160+
kwargs.pop("sync", None)
161+
return self.run(*args, **kwargs)
162+
163+
47164
@app.task(bind=True)
48165
def debug_task(self):
49166
logger.info(f"Request: {self.request!r}")

dojo/finding/deduplication.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from django.db.models import Prefetch
88
from django.db.models.query_utils import Q
99

10-
from dojo.celery import app
11-
from dojo.decorators import dojo_async_task
10+
from dojo.celery import DojoAsyncTask, app
1211
from dojo.models import Finding, System_Settings
1312

1413
logger = logging.getLogger(__name__)
@@ -45,14 +44,12 @@ def get_finding_models_for_deduplication(finding_ids):
4544
)
4645

4746

48-
@dojo_async_task
49-
@app.task
47+
@app.task(base=DojoAsyncTask)
5048
def do_dedupe_finding_task(new_finding_id, *args, **kwargs):
5149
return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs)
5250

5351

54-
@dojo_async_task
55-
@app.task
52+
@app.task(base=DojoAsyncTask)
5653
def do_dedupe_batch_task(finding_ids, *args, **kwargs):
5754
"""
5855
Async task to deduplicate a batch of findings. The findings are assumed to be in the same test.

dojo/finding/helper.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515

1616
import dojo.jira_link.helper as jira_helper
1717
import dojo.risk_acceptance.helper as ra_helper
18-
from dojo.celery import app
19-
from dojo.decorators import dojo_async_task
18+
from dojo.celery import DojoAsyncTask, app
2019
from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add
2120
from dojo.file_uploads.helper import delete_related_files
2221
from dojo.finding.deduplication import (
@@ -391,8 +390,7 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group
391390
finding_group.findings.add(*findings)
392391

393392

394-
@dojo_async_task
395-
@app.task
393+
@app.task(base=DojoAsyncTask)
396394
def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002
397395
issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed
398396
finding = get_object_or_none(Finding, id=finding_id)
@@ -453,8 +451,7 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
453451
jira_helper.push_to_jira(finding.finding_group)
454452

455453

456-
@dojo_async_task
457-
@app.task
454+
@app.task(base=DojoAsyncTask)
458455
def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True,
459456
issue_updater_option=True, push_to_jira=False, user=None, **kwargs):
460457

dojo/importers/endpoint_manager.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from django.urls import reverse
55
from django.utils import timezone
66

7-
from dojo.celery import app
8-
from dojo.decorators import dojo_async_task
7+
from dojo.celery import DojoAsyncTask, app
98
from dojo.endpoint.utils import endpoint_get_or_create
109
from dojo.models import (
1110
Dojo_User,
@@ -18,17 +17,15 @@
1817

1918

2019
class EndpointManager:
21-
@dojo_async_task
22-
@app.task()
20+
@app.task(base=DojoAsyncTask)
2321
def add_endpoints_to_unsaved_finding(
24-
self,
25-
finding: Finding,
22+
finding: Finding, # noqa: N805
2623
endpoints: list[Endpoint],
2724
**kwargs: dict,
2825
) -> None:
2926
"""Creates Endpoint objects for a single finding and creates the link via the endpoint status"""
3027
logger.debug(f"IMPORT_SCAN: Adding {len(endpoints)} endpoints to finding: {finding}")
31-
self.clean_unsaved_endpoints(endpoints)
28+
EndpointManager.clean_unsaved_endpoints(endpoints)
3229
for endpoint in endpoints:
3330
ep = None
3431
eps = []
@@ -41,7 +38,8 @@ def add_endpoints_to_unsaved_finding(
4138
path=endpoint.path,
4239
query=endpoint.query,
4340
fragment=endpoint.fragment,
44-
product=finding.test.engagement.product)
41+
product=finding.test.engagement.product,
42+
)
4543
eps.append(ep)
4644
except (MultipleObjectsReturned):
4745
msg = (
@@ -58,11 +56,9 @@ def add_endpoints_to_unsaved_finding(
5856

5957
logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported")
6058

61-
@dojo_async_task
62-
@app.task()
59+
@app.task(base=DojoAsyncTask)
6360
def mitigate_endpoint_status(
64-
self,
65-
endpoint_status_list: list[Endpoint_Status],
61+
endpoint_status_list: list[Endpoint_Status], # noqa: N805
6662
user: Dojo_User,
6763
**kwargs: dict,
6864
) -> None:
@@ -85,11 +81,9 @@ def mitigate_endpoint_status(
8581
batch_size=1000,
8682
)
8783

88-
@dojo_async_task
89-
@app.task()
84+
@app.task(base=DojoAsyncTask)
9085
def reactivate_endpoint_status(
91-
self,
92-
endpoint_status_list: list[Endpoint_Status],
86+
endpoint_status_list: list[Endpoint_Status], # noqa: N805
9387
**kwargs: dict,
9488
) -> None:
9589
"""Reactivate all endpoint status objects that are supplied"""
@@ -120,8 +114,8 @@ def chunk_endpoints_and_disperse(
120114
) -> None:
121115
self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True)
122116

117+
@staticmethod
123118
def clean_unsaved_endpoints(
124-
self,
125119
endpoints: list[Endpoint],
126120
) -> None:
127121
"""

dojo/jira_link/helper.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
from jira.exceptions import JIRAError
1818
from requests.auth import HTTPBasicAuth
1919

20-
from dojo.celery import app
21-
from dojo.decorators import dojo_async_task
20+
from dojo.celery import DojoAsyncTask, app
2221
from dojo.forms import JIRAEngagementForm, JIRAProjectForm
2322
from dojo.models import (
2423
Engagement,
@@ -773,8 +772,7 @@ def push_to_jira(obj, *args, **kwargs):
773772

774773

775774
# we need thre separate celery tasks due to the decorators we're using to map to/from ids
776-
@dojo_async_task
777-
@app.task
775+
@app.task(base=DojoAsyncTask)
778776
def push_finding_to_jira(finding_id, *args, **kwargs):
779777
finding = get_object_or_none(Finding, id=finding_id)
780778
if not finding:
@@ -786,8 +784,7 @@ def push_finding_to_jira(finding_id, *args, **kwargs):
786784
return add_jira_issue(finding, *args, **kwargs)
787785

788786

789-
@dojo_async_task
790-
@app.task
787+
@app.task(base=DojoAsyncTask)
791788
def push_finding_group_to_jira(finding_group_id, *args, **kwargs):
792789
finding_group = get_object_or_none(Finding_Group, id=finding_group_id)
793790
if not finding_group:
@@ -803,8 +800,7 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs):
803800
return add_jira_issue(finding_group, *args, **kwargs)
804801

805802

806-
@dojo_async_task
807-
@app.task
803+
@app.task(base=DojoAsyncTask)
808804
def push_engagement_to_jira(engagement_id, *args, **kwargs):
809805
engagement = get_object_or_none(Engagement, id=engagement_id)
810806
if not engagement:
@@ -1376,8 +1372,7 @@ def jira_check_attachment(issue, source_file_name):
13761372
return file_exists
13771373

13781374

1379-
@dojo_async_task
1380-
@app.task
1375+
@app.task(base=DojoAsyncTask)
13811376
def close_epic(engagement_id, push_to_jira, **kwargs):
13821377
engagement = get_object_or_none(Engagement, id=engagement_id)
13831378
if not engagement:
@@ -1425,8 +1420,7 @@ def close_epic(engagement_id, push_to_jira, **kwargs):
14251420
return False
14261421

14271422

1428-
@dojo_async_task
1429-
@app.task
1423+
@app.task(base=DojoAsyncTask)
14301424
def update_epic(engagement_id, **kwargs):
14311425
engagement = get_object_or_none(Engagement, id=engagement_id)
14321426
if not engagement:
@@ -1472,8 +1466,7 @@ def update_epic(engagement_id, **kwargs):
14721466
return False
14731467

14741468

1475-
@dojo_async_task
1476-
@app.task
1469+
@app.task(base=DojoAsyncTask)
14771470
def add_epic(engagement_id, **kwargs):
14781471
engagement = get_object_or_none(Engagement, id=engagement_id)
14791472
if not engagement:
@@ -1584,8 +1577,7 @@ def add_comment(obj, note, *, force_push=False, **kwargs):
15841577
return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs)
15851578

15861579

1587-
@dojo_async_task
1588-
@app.task
1580+
@app.task(base=DojoAsyncTask)
15891581
def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs):
15901582
"""Internal Celery task that adds a comment to a JIRA issue."""
15911583
jira_issue = get_object_or_none(JIRA_Issue, id=jira_issue_id)

dojo/notifications/helper.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
from dojo import __version__ as dd_version
1919
from dojo.authorization.roles_permissions import Permissions
20-
from dojo.celery import app
21-
from dojo.decorators import dojo_async_task, we_want_async
20+
from dojo.celery import DojoAsyncTask, app
21+
from dojo.decorators import we_want_async
2222
from dojo.labels import get_labels
2323
from dojo.models import (
2424
Alerts,
@@ -199,8 +199,7 @@ class SlackNotificationManger(NotificationManagerHelpers):
199199

200200
"""Manger for slack notifications and their helpers."""
201201

202-
@dojo_async_task
203-
@app.task
202+
@app.task(base=DojoAsyncTask)
204203
def send_slack_notification(
205204
self,
206205
event: str,
@@ -317,8 +316,7 @@ class MSTeamsNotificationManger(NotificationManagerHelpers):
317316

318317
"""Manger for Microsoft Teams notifications and their helpers."""
319318

320-
@dojo_async_task
321-
@app.task
319+
@app.task(base=DojoAsyncTask)
322320
def send_msteams_notification(
323321
self,
324322
event: str,
@@ -368,8 +366,7 @@ class EmailNotificationManger(NotificationManagerHelpers):
368366

369367
"""Manger for email notifications and their helpers."""
370368

371-
@dojo_async_task
372-
@app.task
369+
@app.task(base=DojoAsyncTask)
373370
def send_mail_notification(
374371
self,
375372
event: str,
@@ -420,8 +417,7 @@ class WebhookNotificationManger(NotificationManagerHelpers):
420417
ERROR_PERMANENT = "permanent"
421418
ERROR_TEMPORARY = "temporary"
422419

423-
@dojo_async_task
424-
@app.task
420+
@app.task(base=DojoAsyncTask)
425421
def send_webhooks_notification(
426422
self,
427423
event: str,

0 commit comments

Comments
 (0)