Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions netbox_diode_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class NetBoxDiodePluginConfig(PluginConfig):
# Override the displayed Diode target URL without affecting internal
# communication (e.g. to show the external ingress address).
"diode_target_display": None,

# Max number of retries when the batch apply endpoint hits a
# Postgres deadlock (40P01) or serialization failure (40001).
# 0 disables retries; default 3 means up to three retries after
# the initial attempt (4 attempts total).
"batch_apply_deadlock_retry_max_count": 3,
}


Expand Down
85 changes: 84 additions & 1 deletion netbox_diode_plugin/api/applier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import logging

from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.db import models, transaction
from django.db.utils import IntegrityError
from extras.models import Tag
from rest_framework.exceptions import ValidationError as ValidationError

from .bulk_tags import apply_tags_bulk, supports_tags
from .common import NON_FIELD_ERRORS, Change, ChangeSet, ChangeSetException, ChangeSetResult, ChangeType, error_from_validation_error
from .matcher import find_existing_object, invalidate_find_obj_entry
from .plugin_utils import get_object_type_model, legal_fields
Expand All @@ -23,6 +26,12 @@
def apply_changeset(change_set: ChangeSet, request) -> ChangeSetResult:
"""Apply a change set."""
_validate_change_set(change_set)
_preload_changeset_cache(change_set, request)

# Collect (instance, tag_input) pairs as we apply changes; flushed via
# apply_tags_bulk after the main loop so we issue one DELETE+bulk INSERT
# for the whole changeset instead of per-instance m2m_changed re-fires.
request._diode_tag_pairs = []

created = {}
for change in change_set.changes:
Expand Down Expand Up @@ -52,10 +61,71 @@ def apply_changeset(change_set: ChangeSet, request) -> ChangeSetResult:
logger.error(f"Integrity error {object_type}: {e} {data}")
raise _err(f"created a conflict with an existing {object_type}", object_type, "__all__")

# Flush deferred tag writes in one bulk pass.
apply_tags_bulk(request._diode_tag_pairs, request)
request._diode_tag_pairs = []

return ChangeSetResult(
id=change_set.id,
)

def _preload_changeset_cache(change_set: ChangeSet, request) -> dict:
    """
    Resolve models and tag IDs for an entire changeset up front.

    Every non-NOOP change is visited once: each distinct object type is
    resolved to its model class and the tag slugs referenced by the change
    data are gathered. The resolved models are then handed to
    `ContentType.objects.get_for_models` in one call, and the gathered slugs
    are looked up with a single `Tag` query.

    Returns a dict with the keys "tag_ids_by_slug" and
    "models_by_object_type". When `request` is not None the same dict is also
    stored on `request._diode_preload` so later steps can reuse it.
    """
    model_by_type: dict[str, type[models.Model]] = {}
    wanted_slugs: set[str] = set()

    for change in change_set.changes:
        if change.change_type == ChangeType.NOOP:
            continue

        object_type = change.object_type
        if object_type and object_type not in model_by_type:
            try:
                model_by_type[object_type] = get_object_type_model(object_type)
            except Exception:
                # Resolution failed: skip this change here and leave the
                # error reporting to the main apply path.
                continue

        wanted_slugs.update(_iter_tag_slugs(change))

    if model_by_type:
        # One call covering every model touched by the changeset.
        ContentType.objects.get_for_models(*model_by_type.values())

    slug_to_id: dict[str, int] = {}
    if wanted_slugs:
        id_slug_rows = Tag.objects.filter(slug__in=wanted_slugs).values_list("id", "slug")
        slug_to_id = {slug: tag_id for tag_id, slug in id_slug_rows}

    cache = {
        "tag_ids_by_slug": slug_to_id,
        "models_by_object_type": model_by_type,
    }
    if request is not None:
        request._diode_preload = cache
    return cache


def _iter_tag_slugs(change: Change):
    """
    Generate the tag slugs named in `change.data["tags"]`.

    Nothing is yielded when the change has no data or when "tags" is not a
    list. Within the list, a string entry is yielded as-is and a dict entry
    contributes its "slug" value when that value is a string; every other
    entry is ignored.
    """
    payload = change.data
    tag_entries = payload.get("tags") if payload else None
    if isinstance(tag_entries, list):
        for entry in tag_entries:
            if isinstance(entry, str):
                yield entry
                continue
            if isinstance(entry, dict):
                candidate = entry.get("slug")
                if isinstance(candidate, str):
                    yield candidate


def _is_auto_created_component(object_type: str) -> bool:
"""Check if the object type is auto-created from templates."""
auto_created_components = [
Expand Down Expand Up @@ -103,14 +173,23 @@ def _create_or_find_instance(data: dict, object_type: str, serializer_class, req


def _apply_change(data: dict, model_class: models.Model, change: Change, created: dict, request):
# Pull tags out of `data` BEFORE serializer.save() so the serializer's
# `tag.set([...])` side effect — which fires m2m_changed and triggers a
# duplicate ObjectChange + a re-run of serialize_for_event — does not run.
# We buffer (instance, tag_input) and flush via apply_tags_bulk after the
# main apply_changeset loop.
deferred_tags = None
if supports_tags(model_class) and isinstance(data.get("tags"), list):
deferred_tags = data.pop("tags")

serializer_class = get_serializer_for_model(model_class)
change_type = change.change_type
instance = None

if change_type == ChangeType.CREATE:
# For component types that may be auto-created from e.g. DeviceType or ModuleType templates,
# try to find existing object first before attempting to create.
# This prevents duplicates when components are instantiated during Device/Module save()
instance = None
if _is_auto_created_component(change.object_type):
instance = _try_find_and_update_existing_instance(data, change.object_type, serializer_class, request)

Expand All @@ -135,6 +214,10 @@ def _apply_change(data: dict, model_class: models.Model, change: Change, created
serializer.save()
invalidate_find_obj_entry(change.object_type, instance.id)

if deferred_tags is not None and instance is not None:
request._diode_tag_pairs.append((instance, deferred_tags))


def _set_path(data, path, value):
path = path.split(".")
key = path.pop(0)
Expand Down
150 changes: 150 additions & 0 deletions netbox_diode_plugin/api/bulk_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2026 NetBox Labs, Inc.
"""
Bulk TaggedItem writes for diode apply paths.

PR 4 of BULK-ORM Sprint 1.

NetBox's `instance.tags.set([...])` triggers `m2m_changed` (post_clear +
post_add) per instance. The signal handler re-fires
`handle_changed_object`, which (a) writes another ObjectChange UPDATE and
(b) calls `enqueue_event` again — that re-runs the expensive
`serialize_for_event(instance)`. Multiplied by every tagged object in a
changeset this dominates the apply path.

This module bypasses the through-table m2m signal by writing
`extras_taggeditem` rows directly via `bulk_create` and re-emitting a
single `enqueue_event` per instance to refresh the queued payload with
the post-tag state. PR 3's `deferred_changelog` already collapses the
ObjectChange side; this module collapses the TaggedItem side.
"""

from __future__ import annotations

import logging
from collections import defaultdict
from typing import Any

from core.events import OBJECT_UPDATED
from django.contrib.contenttypes.models import ContentType
from django.db.models import Q
from extras.events import enqueue_event
from extras.models import Tag, TaggedItem
from netbox.context import events_queue
from netbox.models.features import TagsMixin

logger = logging.getLogger(__name__)


def supports_tags(model_or_instance) -> bool:
    """Tell whether a model class, or the class of a given instance, derives from TagsMixin."""
    if isinstance(model_or_instance, type):
        candidate = model_or_instance
    else:
        # An instance was passed; inspect its class instead.
        candidate = type(model_or_instance)
    return issubclass(candidate, TagsMixin)


def apply_tags_bulk(instance_tag_pairs: list[tuple[Any, list]], request) -> None:
    """
    Replace the tags of many instances in one pass.

    Each element of `instance_tag_pairs` is `(instance, tag_input)`, where
    `tag_input` is a list whose entries are slug strings, integer tag IDs, or
    dicts carrying a "slug" key. For every instance that has a primary key,
    its existing TaggedItem rows are deleted and the listed tags are written
    back with a single `bulk_create`, after which an update event is enqueued
    per instance.

    Slug-to-ID lookups start from `request._diode_preload["tag_ids_by_slug"]`
    when that attribute is present; slugs missing from it are fetched before
    rows are built. An empty or falsy `instance_tag_pairs` is a no-op.
    """
    if not instance_tag_pairs:
        return

    cached = getattr(request, "_diode_preload", None)
    slug_to_id: dict[str, int] = (cached or {}).get("tag_ids_by_slug", {}) or {}
    _backfill_missing_tag_ids(instance_tag_pairs, slug_to_id)

    new_rows, keys_to_clear, touched = _build_rows(instance_tag_pairs, slug_to_id)

    # Clear first so the written rows become the complete tag set.
    _replace_existing(keys_to_clear)
    if new_rows:
        TaggedItem.objects.bulk_create(new_rows, ignore_conflicts=True, batch_size=500)
    _re_emit_events(touched, request)


def _backfill_missing_tag_ids(pairs, tag_ids_by_slug: dict[str, int]) -> None:
    """
    Add to `tag_ids_by_slug`, in place, any referenced slug it does not hold yet.

    Slugs are extracted from every tag entry in `pairs`; those absent from
    the mapping are looked up with one `Tag` query. Slugs that match no Tag
    row remain absent from the mapping.
    """
    referenced = (
        _slug_from(raw)
        for _instance, tag_input in pairs
        for raw in (tag_input or [])
    )
    unresolved = {slug for slug in referenced if slug and slug not in tag_ids_by_slug}
    if not unresolved:
        return

    found = Tag.objects.filter(slug__in=unresolved).values_list("id", "slug")
    tag_ids_by_slug.update({slug: tag_id for tag_id, slug in found})


def _build_rows(pairs, tag_ids_by_slug: dict[str, int]):
    """
    Turn (instance, tag_input) pairs into unsaved TaggedItem rows.

    Pairs whose instance is None or has no primary key are skipped entirely.
    For every remaining instance its (content type, pk) is recorded so the
    caller can clear the existing tags, and the instance is queued for an
    update event, even when its tag list is empty or nothing resolves.

    Tag entries that cannot be resolved to an ID are logged and dropped.
    Repeated entries that resolve to the same tag for the same object yield a
    single row, so the outcome does not depend on the database rejecting
    duplicates.

    Returns a 3-tuple `(rows, target_keys, instances_to_event)`:
      * rows: unsaved TaggedItem objects, in first-seen order;
      * target_keys: content-type id -> set of object pks to clear;
      * instances_to_event: the instances that were not skipped, in order.
    """
    ct_by_model: dict[type, ContentType] = {}
    rows: list[TaggedItem] = []
    # (content_type_id, object_id, tag_id) triples already emitted.
    seen: set[tuple[int, Any, int]] = set()
    target_keys: dict[int, set[int]] = defaultdict(set)
    instances_to_event: list[Any] = []

    for instance, tag_input in pairs:
        if instance is None or instance.pk is None:
            continue
        ct = _ct_for(instance, ct_by_model)
        target_keys[ct.id].add(instance.pk)
        instances_to_event.append(instance)
        for raw in tag_input or []:
            tag_id = _resolve_tag_id(raw, tag_ids_by_slug)
            if tag_id is None:
                logger.warning("apply_tags_bulk: could not resolve tag %r for %s", raw, instance)
                continue
            key = (ct.id, instance.pk, tag_id)
            if key in seen:
                # Same tag listed twice (e.g. once by slug, once by ID).
                continue
            seen.add(key)
            rows.append(TaggedItem(content_type_id=ct.id, object_id=instance.pk, tag_id=tag_id))

    return rows, target_keys, instances_to_event


def _ct_for(instance, ct_by_model: dict[type, ContentType]) -> ContentType:
    """
    Return the ContentType for `instance`'s class, memoised in `ct_by_model`.

    On a miss the ContentType is fetched via
    `ContentType.objects.get_for_model` and stored in `ct_by_model`, keyed by
    the class, before being returned.
    """
    model = type(instance)
    cached = ct_by_model.get(model)
    if cached is not None:
        return cached

    resolved = ContentType.objects.get_for_model(model)
    ct_by_model[model] = resolved
    return resolved


def _replace_existing(target_keys: dict[int, set[int]]) -> None:
    """
    Delete the TaggedItem rows of every object listed in `target_keys`.

    `target_keys` maps a content-type id to the set of object ids of that
    type. The per-type conditions are OR-ed together so that a single
    filtered delete covers all of them. An empty mapping is a no-op.
    """
    if target_keys:
        condition = Q()
        for content_type_id in target_keys:
            condition = condition | Q(
                content_type_id=content_type_id,
                object_id__in=target_keys[content_type_id],
            )
        TaggedItem.objects.filter(condition).delete()


def _re_emit_events(instances, request) -> None:
    """
    Enqueue an OBJECT_UPDATED event for each of `instances`.

    The current events queue is read once, `enqueue_event` is called on it
    for every instance, and the queue is then written back. A failure for one
    instance is logged as a warning and does not stop the remaining ones.
    Nothing happens when `instances` is empty.
    """
    if instances:
        pending = events_queue.get()
        for obj in instances:
            try:
                enqueue_event(pending, obj, request, OBJECT_UPDATED)
            except Exception as exc:
                # Best effort: one bad event must not block the others.
                logger.warning("apply_tags_bulk: enqueue_event failed for %s: %s", obj, exc)
        events_queue.set(pending)


def _slug_from(raw) -> str | None:
    """
    Extract a tag slug from one tag entry.

    A string is returned unchanged; a dict contributes its "slug" value when
    that value is a string. Any other input, or a dict without a string
    "slug", gives None.
    """
    candidate = raw.get("slug") if isinstance(raw, dict) else raw
    return candidate if isinstance(candidate, str) else None


def _resolve_tag_id(raw, tag_ids_by_slug: dict[str, int]) -> int | None:
    """
    Resolve one tag entry to a tag ID.

    An integer is taken to already be a tag ID and is returned unchanged.
    A string, or a dict with a string "slug", is looked up in
    `tag_ids_by_slug`. Returns None when the entry has an unsupported shape
    or its slug is not present in the mapping.

    Booleans are rejected explicitly: `bool` is a subclass of `int`, so
    without the guard `True`/`False` would be read as tag IDs 1/0.
    """
    if isinstance(raw, bool):
        return None
    if isinstance(raw, int):
        return raw
    slug = _slug_from(raw)
    if slug is not None:
        return tag_ids_by_slug.get(slug)
    return None
Loading
Loading