Skip to content

Commit 0349f01

Browse files
perf: bulk-apply parser-supplied per-finding tags during import (#14701)
* perf: bulk-apply parser-supplied per-finding tags during import

  finding.tags.add() per finding calls tagulous's add(), which does:
  - reload() → SELECT current tags (1 query)
  - _ensure_tags_in_db() → get_or_create per tag (T queries)
  - super().add() → INSERT through-table rows (1 query)
  - tag.increment() → UPDATE count per tag (T queries)

  For N findings with T parser-supplied tags: O(N·T) queries.

  Replace with bulk_apply_parser_tags() in tag_utils, which groups findings by
  tag name and calls bulk_add_tags_to_instances() once per unique tag:
  O(unique_tags) queries regardless of N.

  Tags are accumulated per batch and applied just before the
  post_process_findings_batch task is dispatched, so deduplication and rules
  tasks see the tags already written to the DB. Both default_importer and
  default_reimporter use the same approach. For the reimporter,
  finding_post_processing accepts an optional tag_accumulator list; when
  supplied, tags are accumulated rather than applied inline
  (backward-compatible for any direct callers).

* chore: fix ruff linting errors in bulk-tag import code
* improve bulk add tags for parsers
* ruff
* fix tag creation
* fix tests
* fix tests
1 parent 1fa86bc commit 0349f01

File tree

5 files changed

+394
-6
lines changed

5 files changed

+394
-6
lines changed

dojo/importers/default_importer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Test_Import,
2020
)
2121
from dojo.notifications.helper import create_notification
22+
from dojo.tag_utils import bulk_apply_parser_tags
2223
from dojo.utils import get_full_url, perform_product_grading
2324
from dojo.validators import clean_tags
2425

@@ -179,6 +180,7 @@ def process_findings(
179180
at import time
180181
"""
181182
new_findings = []
183+
findings_with_parser_tags: list[tuple] = []
182184
logger.debug("starting import of %i parsed findings.", len(parsed_findings) if parsed_findings else 0)
183185
group_names_to_findings_dict = {}
184186

@@ -245,12 +247,13 @@ def process_findings(
245247
# TODO: Delete this after the move to Locations
246248
# Process any endpoints on the finding, or added on the form
247249
self.process_endpoints(finding, self.endpoints_to_add)
248-
# Parsers must use unsaved_tags to store tags, so we can clean them
250+
# Parsers must use unsaved_tags to store tags, so we can clean them.
251+
# Accumulate for bulk application after the loop (O(unique_tags) instead of O(N·T)).
249252
cleaned_tags = clean_tags(finding.unsaved_tags)
250253
if isinstance(cleaned_tags, list):
251-
finding.tags.add(*cleaned_tags)
254+
findings_with_parser_tags.append((finding, cleaned_tags))
252255
elif isinstance(cleaned_tags, str):
253-
finding.tags.add(cleaned_tags)
256+
findings_with_parser_tags.append((finding, [cleaned_tags]))
254257
# Process any files
255258
self.process_files(finding)
256259
# Process vulnerability IDs
@@ -268,6 +271,12 @@ def process_findings(
268271
if len(batch_finding_ids) >= batch_max_size or is_final_finding:
269272
if not settings.V3_FEATURE_LOCATIONS:
270273
self.endpoint_manager.persist(user=self.user)
274+
275+
# Apply parser-supplied tags for this batch before post-processing starts,
276+
# so rules/deduplication tasks see the tags already on the findings.
277+
bulk_apply_parser_tags(findings_with_parser_tags)
278+
findings_with_parser_tags.clear()
279+
271280
finding_ids_batch = list(batch_finding_ids)
272281
batch_finding_ids.clear()
273282
logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)",

dojo/importers/default_reimporter.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Test,
2727
Test_Import,
2828
)
29+
from dojo.tag_utils import bulk_apply_parser_tags
2930
from dojo.utils import perform_product_grading
3031
from dojo.validators import clean_tags
3132

@@ -310,6 +311,7 @@ def process_findings(
310311
cleaned_findings.append(sanitized)
311312

312313
batch_finding_ids: list[int] = []
314+
findings_with_parser_tags: list[tuple] = []
313315
# Batch size for deduplication/post-processing (only new findings)
314316
dedupe_batch_max_size = getattr(settings, "IMPORT_REIMPORT_DEDUPE_BATCH_SIZE", 1000)
315317
# Batch size for candidate matching (all findings, before matching)
@@ -417,6 +419,7 @@ def process_findings(
417419
finding,
418420
unsaved_finding,
419421
is_matched_finding=bool(matched_findings),
422+
tag_accumulator=findings_with_parser_tags,
420423
)
421424
# all data is already saved on the finding, we only need to trigger post processing in batches
422425
push_to_jira = self.push_to_jira and ((not self.findings_groups_enabled or not self.group_by) or not finding_will_be_grouped)
@@ -440,6 +443,12 @@ def process_findings(
440443
if len(batch_finding_ids) >= dedupe_batch_max_size or is_final:
441444
if not settings.V3_FEATURE_LOCATIONS:
442445
self.endpoint_manager.persist(user=self.user)
446+
447+
# Apply parser-supplied tags for this batch before post-processing starts,
448+
# so rules/deduplication tasks see the tags already on the findings.
449+
bulk_apply_parser_tags(findings_with_parser_tags)
450+
findings_with_parser_tags.clear()
451+
443452
finding_ids_batch = list(batch_finding_ids)
444453
batch_finding_ids.clear()
445454
dojo_dispatch_task(
@@ -976,6 +985,7 @@ def finding_post_processing(
976985
finding_from_report: Finding,
977986
*,
978987
is_matched_finding: bool = False,
988+
tag_accumulator: list | None = None,
979989
) -> Finding:
980990
"""
981991
Save all associated objects to the finding after it has been saved
@@ -1006,7 +1016,12 @@ def finding_post_processing(
10061016
finding_from_report.unsaved_tags = merged_tags
10071017
if finding_from_report.unsaved_tags:
10081018
cleaned_tags = clean_tags(finding_from_report.unsaved_tags)
1009-
if isinstance(cleaned_tags, list):
1019+
if tag_accumulator is not None:
1020+
if isinstance(cleaned_tags, list):
1021+
tag_accumulator.append((finding, cleaned_tags))
1022+
elif isinstance(cleaned_tags, str):
1023+
tag_accumulator.append((finding, [cleaned_tags]))
1024+
elif isinstance(cleaned_tags, list):
10101025
finding.tags.add(*cleaned_tags)
10111026
elif isinstance(cleaned_tags, str):
10121027
finding.tags.add(cleaned_tags)

dojo/tag_utils.py

Lines changed: 185 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,190 @@ def bulk_add_tags_to_instances(tag_or_tags, instances, tag_field_name: str = "ta
164164
return total_created
165165

166166

167+
def bulk_add_tag_mapping(
    tag_to_instances: dict[str, list],
    tag_field_name: str = "tags",
    batch_size: int | None = None,
) -> int:
    """
    Add different tags to different sets of instances in a handful of queries.

    Unlike calling ``bulk_add_tags_to_instances`` once per unique tag — which issues
    O(unique_tags) queries — this function batches all work:

    1. Fetch all existing tag objects in one query.
    2. ``get_or_create`` each *missing* tag object (one query pair per missing tag
       only; ``save()`` must run so tagulous can populate the slug field).
    3. Fetch all pre-existing through-model rows for these (instance, tag) pairs in one query.
    4. Bulk-create all new relationships in one query (batched by ``batch_size``).
    5. Update all tag counts in one ``UPDATE … CASE WHEN …`` query.

    Args:
        tag_to_instances: mapping of tag_name -> list of instances that should receive
            that tag. All instances must be of the same model type.
        tag_field_name: name of the TagField on the model (default: ``"tags"``).
        batch_size: ``bulk_create`` batch size; defaults to ``TAG_BULK_ADD_BATCH_SIZE``
            setting (1000).

    Returns:
        Total number of new tag relationships created.

    Raises:
        ValueError: if instances are Products, the field does not exist, the field
            is not a TagField, or the through-model FK columns cannot be resolved.

    """
    from collections import defaultdict  # noqa: PLC0415

    from django.db.models import Case, IntegerField, When  # noqa: PLC0415
    from django.db.models.functions import Lower  # noqa: PLC0415

    if not tag_to_instances:
        return 0

    if batch_size is None:
        batch_size = getattr(settings, "TAG_BULK_ADD_BATCH_SIZE", 1000)

    all_instances = [inst for insts in tag_to_instances.values() for inst in insts]
    if not all_instances:
        return 0

    model_class = all_instances[0].__class__

    # Product tag changes have propagation side effects that this low-level
    # helper deliberately bypasses, so refuse them outright.
    if model_class is Product:
        msg = "bulk_add_tag_mapping: Product instances are not supported; use Product.tags.add() or a propagation-aware helper"
        raise ValueError(msg)

    try:
        tag_field = model_class._meta.get_field(tag_field_name)
    except Exception as exc:
        msg = f"Model {model_class.__name__} does not have field '{tag_field_name}'"
        raise ValueError(msg) from exc

    if not hasattr(tag_field, "tag_options"):
        msg = f"Field '{tag_field_name}' is not a TagField"
        raise ValueError(msg)

    tag_model = tag_field.related_model
    through_model = tag_field.remote_field.through
    case_sensitive = tag_field.tag_options.case_sensitive

    # Resolve the FK column names on the through model for both ends of the M2M.
    source_field_name = None
    target_field_name = None
    for field in through_model._meta.fields:
        if hasattr(field, "remote_field") and field.remote_field:
            if field.remote_field.model == model_class:
                source_field_name = field.name
            elif field.remote_field.model == tag_model:
                target_field_name = field.name

    # Fail loudly here instead of letting a later "None__in" lookup raise a
    # confusing FieldError.
    if source_field_name is None or target_field_name is None:
        msg = f"Could not resolve through-model FK fields for '{tag_field_name}' on {model_class.__name__}"
        raise ValueError(msg)

    all_tag_names = list(tag_to_instances.keys())

    def _key(name: str) -> str:
        # Canonical dict key for a tag name under the field's case sensitivity.
        return name if case_sensitive else name.lower()

    # --- Query 1: fetch existing tag objects ---
    if case_sensitive:
        existing_tags: dict[str, object] = {
            t.name: t
            for t in tag_model.objects.filter(name__in=all_tag_names)
        }
        missing_names = [n for n in all_tag_names if n not in existing_tags]
    else:
        # Annotate with lowercased name for a case-insensitive IN lookup
        existing_tags = {
            t.name_lower: t
            for t in tag_model.objects.annotate(name_lower=Lower("name")).filter(
                name_lower__in=[n.lower() for n in all_tag_names],
            )
        }
        missing_names = [n for n in all_tag_names if n.lower() not in existing_tags]

    # --- Query 2: create missing tag objects ---
    # Use get_or_create so model.save() runs, letting tagulous generate the slug.
    # bulk_create would bypass save(), leave slug unset, and trip the unique
    # constraint. This costs one query pair per missing tag (not per tag overall).
    for n in missing_names:
        if case_sensitive:
            tag, _ = tag_model.objects.get_or_create(name=n, defaults={"protected": False})
        else:
            tag, _ = tag_model.objects.get_or_create(name__iexact=n, defaults={"name": n, "protected": False})
        existing_tags[_key(n)] = tag

    # --- Query 3: fetch all pre-existing (instance, tag) through-model rows ---
    all_instance_ids = {inst.pk for inst in all_instances}
    all_tag_pks = {tag.pk for tag in existing_tags.values()}

    existing_pairs: set[tuple] = set(
        through_model.objects.filter(
            **{f"{source_field_name}__in": all_instance_ids},
            **{f"{target_field_name}__in": all_tag_pks},
        ).values_list(source_field_name, target_field_name),
    )

    new_relationships = []
    created_per_tag: dict[int, int] = defaultdict(int)

    for tag_name, instances in tag_to_instances.items():
        tag = existing_tags.get(_key(tag_name))
        if tag is None:
            continue
        for instance in instances:
            pair = (instance.pk, tag.pk)
            if pair in existing_pairs:
                continue
            # Record the pair immediately so a duplicate (instance, tag) entry in
            # the input cannot queue a second row and inflate the counts below
            # (ignore_conflicts would drop the row but not the count).
            existing_pairs.add(pair)
            new_relationships.append(
                through_model(**{source_field_name: instance, target_field_name: tag}),
            )
            created_per_tag[tag.pk] += 1

    if not new_relationships:
        return 0

    # --- Query 4: bulk-create all new relationships (batched for memory) ---
    # Use len(new_relationships) for the count: existing pairs were already filtered
    # out above, so every entry here is new. bulk_create's return value is
    # unreliable with ignore_conflicts.
    total_created = len(new_relationships)
    with transaction.atomic():
        for i in range(0, len(new_relationships), batch_size):
            batch = new_relationships[i : i + batch_size]
            through_model.objects.bulk_create(batch, ignore_conflicts=True)

    # --- Query 5: update all tag counts in one UPDATE … CASE WHEN … ---
    tag_model.objects.filter(pk__in=list(created_per_tag.keys())).update(
        count=Case(
            *[
                When(pk=pk, then=models.F("count") + delta)
                for pk, delta in created_per_tag.items()
            ],
            output_field=IntegerField(),
        ),
    )

    # Invalidate any prefetched tag caches so callers don't read stale tag lists.
    for instance in all_instances:
        prefetch_cache = getattr(instance, "_prefetched_objects_cache", None)
        if prefetch_cache is not None:
            prefetch_cache.pop(tag_field_name, None)

    return total_created
326+
327+
328+
def bulk_apply_parser_tags(findings_with_tags: list) -> None:
    """
    Apply parser-supplied per-finding tags accumulated during an import loop.

    Inverts the accumulated ``(finding, [tag, ...])`` pairs into a
    tag-name -> findings mapping and hands it to ``bulk_add_tag_mapping``, so
    the total query count depends on the number of unique tag values the parser
    produced, not on the number of findings.

    Args:
        findings_with_tags: list of ``(finding, [tag_str, ...])`` pairs collected
            while looping over parsed findings (only findings whose parser
            supplied tags).

    """
    grouped: dict = {}
    for finding, tag_names in findings_with_tags:
        for tag_name in tag_names:
            if not tag_name:
                continue
            grouped.setdefault(tag_name, []).append(finding)

    bulk_add_tag_mapping(grouped)
349+
350+
167351
def bulk_remove_all_tags(model_class, instance_ids_qs):
168352
"""
169353
Remove all tags from instances identified by the given ID subquery.
@@ -226,4 +410,4 @@ def bulk_remove_all_tags(model_class, instance_ids_qs):
226410
)
227411

228412

229-
__all__ = ["bulk_add_tags_to_instances", "bulk_remove_all_tags"]
413+
__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags"]

0 commit comments

Comments
 (0)