Skip to content

Commit c9ace10

Browse files
authored
avoid errors if custom field exists and reconcile its params (#197)
1 parent d8e9047 commit c9ace10

2 files changed

Lines changed: 63 additions & 6 deletions

File tree

validity/tests/test_utils/test_orm.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import pytest
2-
from dcim.models import Device
2+
from dcim.models import Device, DeviceType
3+
from django.contrib.contenttypes.models import ContentType
34
from django.db import connection
45
from django.db.models import BigIntegerField
56
from django.db.models.fields.json import KeyTextTransform
67
from django.db.models.functions import Cast
8+
from extras.models import CustomField
79
from factories import DeviceFactory, NameSetDBFactory, SerializerDBFactory
810

911
from validity.models import NameSet
1012
from validity.models.serializer import Serializer
11-
from validity.utils.orm import CustomPrefetchMixin, QuerySetMap
13+
from validity.utils.orm import CustomFieldBuilder, CustomPrefetchMixin, QuerySetMap
1214

1315

1416
@pytest.mark.parametrize("attrib", ["pk", "name"])
@@ -55,3 +57,53 @@ def test_custom_postfetch(monkeypatch):
5557
NameSetDBFactory(name=f"ns{i}")
5658
for device in custom_qs:
5759
assert device.name.replace("dev", "ns") == device.nameset.name
60+
61+
62+
@pytest.mark.django_db
63+
def test_custom_field_builder_creates_object_field():
64+
serializer_ct = ContentType.objects.get_for_model(Serializer)
65+
device_ct = ContentType.objects.get_for_model(Device)
66+
cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType)
67+
68+
custom_field = cf_builder.create(
69+
name="validity_test_serializer",
70+
label="Validity Test Serializer",
71+
type="object",
72+
required=False,
73+
object_type=serializer_ct,
74+
bind_to=[Device],
75+
)
76+
77+
custom_field.refresh_from_db()
78+
assert custom_field.name == "validity_test_serializer"
79+
assert custom_field.related_object_type == serializer_ct
80+
assert list(custom_field.object_types.all()) == [device_ct]
81+
82+
83+
@pytest.mark.django_db
84+
def test_custom_field_builder_reuses_existing_field():
85+
cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType)
86+
device_type_ct = ContentType.objects.get_for_model(DeviceType)
87+
88+
original = cf_builder.create(
89+
name="validity_test_reused",
90+
label="Original Label",
91+
type="text",
92+
required=False,
93+
bind_to=[Device],
94+
)
95+
reused = cf_builder.create(
96+
name="validity_test_reused",
97+
label="Changed Label",
98+
type="boolean",
99+
required=True,
100+
bind_to=[DeviceType],
101+
)
102+
103+
original.refresh_from_db()
104+
assert reused == original
105+
assert CustomField.objects.filter(name="validity_test_reused").count() == 1
106+
assert original.label == "Changed Label"
107+
assert original.type == "boolean"
108+
assert original.required is True
109+
assert list(original.object_types.all()) == [device_type_ct]

validity/utils/orm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,15 @@ class CustomFieldBuilder:
206206
content_type_model: type
207207
db_alias: str = ""
208208

209-
def create(self, *, bind_to, object_type=None, **cf_params):
209+
def create(self, *, bind_to, name, object_type=None, **cf_params):
210210
db = self.db_alias or self.cf_model.objects.db
211-
if object_type is not None:
212-
cf_params["related_object_type"] = object_type
213-
custom_field = self.cf_model.objects.using(db).create(**cf_params)
211+
cf_params["related_object_type"] = object_type
212+
213+
# get_or_create handles #182 - compatibility with netbox-branching weird behaviour
214+
custom_field, created = self.cf_model.objects.using(db).get_or_create(name=name, defaults=cf_params)
215+
if not created:
216+
for field, value in cf_params.items():
217+
setattr(custom_field, field, value)
218+
custom_field.save(force_update=True, update_fields=cf_params.keys())
214219
custom_field.object_types.set(self.content_type_model.objects.get_for_model(model).pk for model in bind_to)
215220
return custom_field

0 commit comments

Comments
 (0)