Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions dojo/importers/default_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Test_Import,
)
from dojo.notifications.helper import create_notification
from dojo.tag_utils import bulk_apply_parser_tags
from dojo.utils import get_full_url, perform_product_grading
from dojo.validators import clean_tags

Expand Down Expand Up @@ -179,6 +180,7 @@ def process_findings(
at import time
"""
new_findings = []
findings_with_parser_tags: list[tuple] = []
logger.debug("starting import of %i parsed findings.", len(parsed_findings) if parsed_findings else 0)
group_names_to_findings_dict = {}

Expand Down Expand Up @@ -245,12 +247,13 @@ def process_findings(
# TODO: Delete this after the move to Locations
# Process any endpoints on the finding, or added on the form
self.process_endpoints(finding, self.endpoints_to_add)
# Parsers must use unsaved_tags to store tags, so we can clean them
# Parsers must use unsaved_tags to store tags, so we can clean them.
# Accumulate for bulk application after the loop (O(unique_tags) instead of O(N·T)).
cleaned_tags = clean_tags(finding.unsaved_tags)
if isinstance(cleaned_tags, list):
finding.tags.add(*cleaned_tags)
findings_with_parser_tags.append((finding, cleaned_tags))
elif isinstance(cleaned_tags, str):
finding.tags.add(cleaned_tags)
findings_with_parser_tags.append((finding, [cleaned_tags]))
# Process any files
self.process_files(finding)
# Process vulnerability IDs
Expand All @@ -268,6 +271,12 @@ def process_findings(
if len(batch_finding_ids) >= batch_max_size or is_final_finding:
if not settings.V3_FEATURE_LOCATIONS:
self.endpoint_manager.persist(user=self.user)

# Apply parser-supplied tags for this batch before post-processing starts,
# so rules/deduplication tasks see the tags already on the findings.
bulk_apply_parser_tags(findings_with_parser_tags)
findings_with_parser_tags.clear()

finding_ids_batch = list(batch_finding_ids)
batch_finding_ids.clear()
logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)",
Expand Down
17 changes: 16 additions & 1 deletion dojo/importers/default_reimporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Test,
Test_Import,
)
from dojo.tag_utils import bulk_apply_parser_tags
from dojo.utils import perform_product_grading
from dojo.validators import clean_tags

Expand Down Expand Up @@ -310,6 +311,7 @@ def process_findings(
cleaned_findings.append(sanitized)

batch_finding_ids: list[int] = []
findings_with_parser_tags: list[tuple] = []
# Batch size for deduplication/post-processing (only new findings)
dedupe_batch_max_size = getattr(settings, "IMPORT_REIMPORT_DEDUPE_BATCH_SIZE", 1000)
# Batch size for candidate matching (all findings, before matching)
Expand Down Expand Up @@ -417,6 +419,7 @@ def process_findings(
finding,
unsaved_finding,
is_matched_finding=bool(matched_findings),
tag_accumulator=findings_with_parser_tags,
)
# all data is already saved on the finding, we only need to trigger post processing in batches
push_to_jira = self.push_to_jira and ((not self.findings_groups_enabled or not self.group_by) or not finding_will_be_grouped)
Expand All @@ -440,6 +443,12 @@ def process_findings(
if len(batch_finding_ids) >= dedupe_batch_max_size or is_final:
if not settings.V3_FEATURE_LOCATIONS:
self.endpoint_manager.persist(user=self.user)

# Apply parser-supplied tags for this batch before post-processing starts,
# so rules/deduplication tasks see the tags already on the findings.
bulk_apply_parser_tags(findings_with_parser_tags)
findings_with_parser_tags.clear()

finding_ids_batch = list(batch_finding_ids)
batch_finding_ids.clear()
dojo_dispatch_task(
Expand Down Expand Up @@ -976,6 +985,7 @@ def finding_post_processing(
finding_from_report: Finding,
*,
is_matched_finding: bool = False,
tag_accumulator: list | None = None,
) -> Finding:
"""
Save all associated objects to the finding after it has been saved
Expand Down Expand Up @@ -1006,7 +1016,12 @@ def finding_post_processing(
finding_from_report.unsaved_tags = merged_tags
if finding_from_report.unsaved_tags:
cleaned_tags = clean_tags(finding_from_report.unsaved_tags)
if isinstance(cleaned_tags, list):
if tag_accumulator is not None:
if isinstance(cleaned_tags, list):
tag_accumulator.append((finding, cleaned_tags))
elif isinstance(cleaned_tags, str):
tag_accumulator.append((finding, [cleaned_tags]))
elif isinstance(cleaned_tags, list):
finding.tags.add(*cleaned_tags)
elif isinstance(cleaned_tags, str):
finding.tags.add(cleaned_tags)
Expand Down
186 changes: 185 additions & 1 deletion dojo/tag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,190 @@ def bulk_add_tags_to_instances(tag_or_tags, instances, tag_field_name: str = "ta
return total_created


def bulk_add_tag_mapping(
    tag_to_instances: dict[str, list],
    tag_field_name: str = "tags",
    batch_size: int | None = None,
) -> int:
    """
    Add different tags to different sets of instances in ~5 queries regardless of tag count.

    Unlike calling ``bulk_add_tags_to_instances`` once per unique tag — which issues
    O(unique_tags) queries — this function batches all work:

    1. Fetch all existing tag objects in one query.
    2. Create any missing tag objects (one ``get_or_create`` per missing name, so
       tagulous can generate slugs via ``save()``).
    3. Fetch all pre-existing through-model rows for these (instance, tag) pairs in one query.
    4. Bulk-create all new relationships in one query (batched by ``batch_size``).
    5. Update all tag counts in one ``UPDATE … CASE WHEN …`` query.

    Args:
        tag_to_instances: mapping of tag_name -> list of instances that should receive
            that tag. All instances must be of the same model type.
        tag_field_name: name of the TagField on the model (default: ``"tags"``).
        batch_size: ``bulk_create`` batch size; defaults to ``TAG_BULK_ADD_BATCH_SIZE``
            setting (1000).

    Returns:
        Total number of new tag relationships created.

    Raises:
        ValueError: if instances are ``Product``s, the field is missing or not a
            TagField, or the through-model FK fields cannot be resolved.

    """
    from collections import defaultdict  # noqa: PLC0415

    from django.db.models import Case, IntegerField, When  # noqa: PLC0415
    from django.db.models.functions import Lower  # noqa: PLC0415

    if not tag_to_instances:
        return 0

    if batch_size is None:
        batch_size = getattr(settings, "TAG_BULK_ADD_BATCH_SIZE", 1000)

    all_instances = [inst for insts in tag_to_instances.values() for inst in insts]
    if not all_instances:
        return 0

    model_class = all_instances[0].__class__

    if model_class is Product:
        msg = "bulk_add_tag_mapping: Product instances are not supported; use Product.tags.add() or a propagation-aware helper"
        raise ValueError(msg)

    try:
        tag_field = model_class._meta.get_field(tag_field_name)
    except Exception as err:
        msg = f"Model {model_class.__name__} does not have field '{tag_field_name}'"
        # Chain the original FieldDoesNotExist so the root cause stays visible.
        raise ValueError(msg) from err

    if not hasattr(tag_field, "tag_options"):
        msg = f"Field '{tag_field_name}' is not a TagField"
        raise ValueError(msg)

    tag_model = tag_field.related_model
    through_model = tag_field.remote_field.through
    case_sensitive = tag_field.tag_options.case_sensitive

    # Resolve the through model's FK field names for the source model and the tag model.
    source_field_name = None
    target_field_name = None
    for field in through_model._meta.fields:
        if hasattr(field, "remote_field") and field.remote_field:
            if field.remote_field.model == model_class:
                source_field_name = field.name
            elif field.remote_field.model == tag_model:
                target_field_name = field.name

    # Fail loudly instead of building broken "None__in" filter kwargs below.
    if source_field_name is None or target_field_name is None:
        msg = f"Could not resolve through-model FK fields for {model_class.__name__}.{tag_field_name}"
        raise ValueError(msg)

    all_tag_names = list(tag_to_instances.keys())

    def _key(name: str) -> str:
        return name if case_sensitive else name.lower()

    # --- Query 1: fetch existing tag objects ---
    if case_sensitive:
        existing_tags: dict[str, object] = {
            t.name: t
            for t in tag_model.objects.filter(name__in=all_tag_names)
        }
        missing_names = [n for n in all_tag_names if n not in existing_tags]
    else:
        # Annotate with lowercased name for a case-insensitive IN lookup
        existing_tags = {
            t.name_lower: t
            for t in tag_model.objects.annotate(name_lower=Lower("name")).filter(
                name_lower__in=[n.lower() for n in all_tag_names],
            )
        }
        missing_names = [n for n in all_tag_names if n.lower() not in existing_tags]

    # --- Query 2: create missing tag objects ---
    # Use get_or_create to call model.save(), which lets tagulous generate the slug field.
    # bulk_create bypasses save() so slug is never set, causing unique constraint failures.
    if missing_names:
        for n in missing_names:
            if case_sensitive:
                tag, _ = tag_model.objects.get_or_create(name=n, defaults={"protected": False})
            else:
                tag, _ = tag_model.objects.get_or_create(name__iexact=n, defaults={"name": n, "protected": False})
            existing_tags[_key(n)] = tag

    # --- Query 3: fetch all pre-existing (instance, tag) through-model rows ---
    all_instance_ids = {inst.pk for inst in all_instances}
    all_tag_pks = {tag.pk for tag in existing_tags.values()}

    existing_pairs: set[tuple] = set(
        through_model.objects.filter(
            **{f"{source_field_name}__in": all_instance_ids},
            **{f"{target_field_name}__in": all_tag_pks},
        ).values_list(source_field_name, target_field_name),
    )

    new_relationships = []
    created_per_tag: dict[int, int] = defaultdict(int)
    # Track pairs added in THIS call as well: with case-insensitive tags, distinct
    # input keys (e.g. "Foo" and "foo") resolve to the same tag object, and the same
    # instance may appear under both. Without dedup, ignore_conflicts would hide the
    # duplicate row but total_created and the count UPDATE below would overcount.
    seen_pairs: set[tuple] = set(existing_pairs)

    for tag_name, instances in tag_to_instances.items():
        tag = existing_tags.get(_key(tag_name))
        if tag is None:
            continue
        for instance in instances:
            pair = (instance.pk, tag.pk)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            new_relationships.append(
                through_model(**{source_field_name: instance, target_field_name: tag}),
            )
            created_per_tag[tag.pk] += 1

    if not new_relationships:
        return 0

    # --- Query 4: bulk-create all new relationships (batched for memory) ---
    # Use len(new_relationships) for the count: existing and in-call duplicates were
    # already filtered out above, so every entry here is new. bulk_create return value
    # is unreliable with ignore_conflicts.
    total_created = len(new_relationships)
    with transaction.atomic():
        for i in range(0, len(new_relationships), batch_size):
            batch = new_relationships[i : i + batch_size]
            through_model.objects.bulk_create(batch, ignore_conflicts=True)

        # --- Query 5: update all tag counts in one UPDATE … CASE WHEN … ---
        tag_model.objects.filter(pk__in=list(created_per_tag.keys())).update(
            count=Case(
                *[
                    When(pk=pk, then=models.F("count") + delta)
                    for pk, delta in created_per_tag.items()
                ],
                output_field=IntegerField(),
            ),
        )

    # Invalidate any prefetched tag caches so subsequent reads see the new tags.
    for instance in all_instances:
        prefetch_cache = getattr(instance, "_prefetched_objects_cache", None)
        if prefetch_cache is not None:
            prefetch_cache.pop(tag_field_name, None)

    return total_created


def bulk_apply_parser_tags(findings_with_tags: list) -> None:
    """
    Bulk-apply per-finding parser tags collected during an import loop.

    Inverts the accumulated ``(finding, tags)`` pairs into a tag -> findings mapping
    and hands it to ``bulk_add_tag_mapping``, which applies everything in ~5 queries
    total no matter how many unique tag values the parser produced.

    Args:
        findings_with_tags: list of ``(finding, [tag_str, ...])`` pairs accumulated
            during the import loop (only for findings whose parser supplied tags).

    """
    tag_to_findings: dict = {}
    for finding, tag_values in findings_with_tags:
        for tag_value in tag_values:
            # Skip empty/falsy tag strings; they carry no information.
            if not tag_value:
                continue
            tag_to_findings.setdefault(tag_value, []).append(finding)

    bulk_add_tag_mapping(tag_to_findings)


def bulk_remove_all_tags(model_class, instance_ids_qs):
"""
Remove all tags from instances identified by the given ID subquery.
Expand Down Expand Up @@ -226,4 +410,4 @@ def bulk_remove_all_tags(model_class, instance_ids_qs):
)


__all__ = ["bulk_add_tags_to_instances", "bulk_remove_all_tags"]
__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags"]
Loading
Loading