Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions dojo/importers/location_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@

from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import signals
from django.utils import timezone

from dojo import tag_inheritance
from dojo.importers.base_location_manager import BaseLocationManager
from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference
from dojo.location.status import FindingLocationStatus, ProductLocationStatus
from dojo.models import Product, _manage_inherited_tags
from dojo.tags_signals import make_inherited_tags_sticky
from dojo.tools.locations import LocationData
from dojo.url.models import URL
from dojo.utils import get_system_setting
Expand Down Expand Up @@ -551,10 +550,17 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]:
existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags)
existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags)

# Perform the bulk updates. First, though, disconnect the make_inherited_tags_sticky signal on Location.tags
# while updating, otherwise each (inherited_)tags.set() will trigger, defeating the purpose of this bulk update.
disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=Location.tags.through)
try:
# Perform the bulk updates inside a `tag_inheritance.batch_mode()` context.
# While the batch is active, signal handlers in `dojo/tags_signals.py`
# short-circuit per-row inheritance work that would otherwise fire on
# every `(inherited_)tags.set()` and defeat the bulk update.
#
# This replaces a previous `signals.m2m_changed.disconnect(...)` /
# `connect(...)` dance which was process-global and therefore unsafe
# under threaded gunicorn / Celery thread pools / ASGI threadpools:
# while disconnected, every thread in the process lost sticky
# enforcement. Thread-local batch state avoids that hazard.
with tag_inheritance.batch_mode():
for location in locations:
target_tag_names: set[str] = set()
for pid in product_ids_by_location[location.id]:
Expand All @@ -573,6 +579,3 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]:
list(target_tag_names),
potentially_existing_tags=existing_tags_by_location[location.id],
)
finally:
if disconnected:
signals.m2m_changed.connect(make_inherited_tags_sticky, sender=Location.tags.through)
54 changes: 54 additions & 0 deletions dojo/tag_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Tag inheritance — central coordination module.

Provides a thread-local ``batch_mode()`` context manager that suppresses
per-instance inheritance work driven by ``m2m_changed`` and ``post_save``
signals. While inside a batch, the signal handlers in
``dojo/tags_signals.py`` early-return; the calling code is responsible for
applying inheritance in bulk (e.g. via the importer's existing
``_bulk_inherit_tags`` path or ``propagate_tags_on_product_sync``).

This replaces the previous pattern of ``signals.m2m_changed.disconnect(...)``
in importer hot loops, which was process-global and unsafe under threaded
gunicorn / Celery thread pools / ASGI threadpools (see PR description for
the full rationale).
"""
from __future__ import annotations

import contextlib
import threading
from contextlib import contextmanager

# Per-thread bookkeeping for ``batch_mode()``. The ``depth`` attribute holds
# the number of ``batch_mode()`` blocks the current thread is nested inside and
# only exists while that number is positive.
_state = threading.local()


def is_in_batch_mode() -> bool:
    """
    Return True when the current thread is inside an active ``batch_mode()``.

    Only the calling thread's state is consulted; a batch opened by another
    thread does not make this return True.
    """
    return bool(getattr(_state, "depth", 0))


@contextmanager
def batch_mode():
    """
    Suppress per-instance inheritance signals for the calling thread.

    Usage:
        with tag_inheritance.batch_mode():
            # Bulk operations that would otherwise fire `make_inherited_tags_sticky`
            # or `inherit_tags_on_instance` per row.
            ...

    While the block is active, ``is_in_batch_mode()`` returns True on this
    thread; signal handlers are expected to check it and return early, and
    the caller is responsible for applying inheritance in bulk.

    The context is reentrant; nested ``with`` blocks share the suppression
    until the outermost block exits. State lives in ``threading.local()``,
    so concurrent threads (and Celery workers in non-prefork pools) are
    unaffected by other threads' batches. The depth is restored in a
    ``finally`` clause, so an exception raised inside the block still ends
    the batch.
    """
    # Entering: bump this thread's nesting depth (a missing attribute means 0).
    _state.depth = getattr(_state, "depth", 0) + 1
    try:
        yield
    finally:
        # Leaving: undo exactly one level of nesting, even on error.
        _state.depth -= 1
        if _state.depth <= 0:
            # Clean up the attribute so leak-free thread reuse stays simple.
            with contextlib.suppress(AttributeError):
                del _state.depth
7 changes: 7 additions & 0 deletions dojo/tags_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db.models import signals
from django.dispatch import receiver

from dojo import tag_inheritance
from dojo.celery_dispatch import dojo_dispatch_task
from dojo.location.models import Location, LocationFindingReference, LocationProductReference
from dojo.models import Endpoint, Engagement, Finding, Product, Test
Expand Down Expand Up @@ -32,6 +33,12 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs):
@receiver(signals.m2m_changed, sender=Location.tags.through)
def make_inherited_tags_sticky(sender, instance, action, **kwargs):
"""Make sure inherited tags are added back in if they are removed"""
# Inside a `tag_inheritance.batch_mode()` block the caller takes responsibility
# for applying inheritance in bulk; per-row signal work would defeat the
# purpose. This replaces the old `signals.m2m_changed.disconnect(...)`
# pattern, which was process-global and unsafe under threaded workers.
if tag_inheritance.is_in_batch_mode():
return
if action in {"post_add", "post_remove"}:
if inherit_product_tags(instance):
tag_list = [tag.name for tag in instance.tags.all()]
Expand Down
6 changes: 5 additions & 1 deletion unittests/test_tag_inheritance_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,11 @@ def test_baseline_zap_scan_reimport_no_change_v3(self):
# Phase A nudges these slightly downward (post_save gated on created=True
# avoids re-running inheritance on no-op finding updates during reimport).
# Pre-Phase-A: 1461/1319 import, 77/95 reimport.
# Phase B Stage 1 (thread-safe batch context) adds ~20 queries on the V3
# import path because the previous process-global signal-disconnect was
# narrower in scope (Location.tags.through only). Net-positive trade for
# eliminating the threading bug; full Phase B reductions land in Stage 2.
EXPECTED_ZAP_IMPORT_V2 = 1385
EXPECTED_ZAP_IMPORT_V3 = 1243
EXPECTED_ZAP_IMPORT_V3 = 1263
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87
Loading