Skip to content

Commit dba504d

Browse files
committed
handle 1:1 and "manually" remove items from reverse FKs since set is bugged
1 parent 12151c8 commit dba504d

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

drf_writable_nested/mixins.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
from django.contrib.contenttypes.fields import GenericRelation
66
from django.contrib.contenttypes.models import ContentType
7-
from django.db import transaction
7+
from django.core.exceptions import ObjectDoesNotExist
8+
from django.db import transaction, router
89
from django.db.models import ProtectedError, FieldDoesNotExist, OneToOneRel
910
from django.db.models.fields.related import ForeignObjectRel, ManyToManyField
1011
from django.utils.translation import ugettext_lazy as _
1112
from rest_framework import serializers
1213
from rest_framework.exceptions import ValidationError
1314
from rest_framework.fields import empty
1415
from rest_framework.relations import ManyRelatedField
15-
from rest_framework.serializers import BaseSerializer
16+
from rest_framework.serializers import BaseSerializer, ListSerializer
1617
from rest_framework.validators import UniqueValidator, UniqueTogetherValidator
1718

1819
# permit writable nested serializers
@@ -557,6 +558,8 @@ def save(self, **kwargs):
557558

558559
def _save_direct_relations(self, kwargs):
559560
"""Save direct relations so FKs exist when committing the base instance"""
561+
if self._validated_data is None and kwargs == {}:
562+
return # delete-only
560563
for field_name, field in self.fields.items():
561564
if self.field_types[field_name] != self.TYPE_DIRECT:
562565
continue
@@ -580,24 +583,47 @@ def _save_reverse_relations(self, instance, kwargs):
580583
for field_name, field in self.fields.items():
581584
if self.field_types[field_name] != self.TYPE_REVERSE:
582585
continue
586+
if self._validated_data is None and kwargs == {}:
587+
return # delete-only
583588
if self._validated_data.get(field.source, empty) == empty and kwargs.get(field_name, empty) == empty:
584589
continue # nothing to save
585-
# inject the instance into reverse relations so the <parent>_id ForeignKey field is valid when saved
586-
related_field = self._get_model_field(field.source).field
587-
print("{} populating reverse field {}".format(self.__class__.__name__, related_field.name))
590+
model_field = self._get_model_field(field.source)
591+
print("{} populating reverse field {}".format(self.__class__.__name__, model_field.field.name))
588592
if isinstance(field, serializers.ListSerializer):
593+
# reverse FK, inject the instance into reverse relations so the <parent>_id FK field is valid when saved
589594
for obj in field._validated_data:
590-
obj[related_field.name] = instance
595+
obj[model_field.field.name] = instance
591596
elif isinstance(field, serializers.ModelSerializer):
592-
if field._validated_data is None:
593-
field._validated_data = {} # delete situation, but need a place to put FK
594-
field._validated_data[related_field.name] = instance
597+
# 1:1
598+
if self._validated_data[field.source] is None:
599+
# indicates that we should delete 1:1 relation (if it exists)
600+
try:
601+
getattr(instance, field.source).delete()
602+
continue
603+
except ObjectDoesNotExist:
604+
pass
605+
else:
606+
field._validated_data[model_field.field.name] = instance
595607
else:
596608
raise Exception("unexpected serializer type")
597-
# no tests fail if we do not cache this value in _validated_data, but it's consistent with forward relations
609+
# create/update (as appropriate) related objects
598610
self._validated_data[field.source] = field.save(**kwargs.get(field_name, {}))
599611
print("{}._validated_data[{}] to reverse {}".format(self.__class__.__name__, field_name, self._validated_data[field.source]))
600612

613+
# eliminate related objects that weren't in the request
614+
if isinstance(field, ListSerializer):
615+
# due to a bug in Django, calling `set` on a non-nullable reverse relation will only `add`
616+
if model_field.field.null:
617+
getattr(instance, field.source).set(self._validated_data[field.source])
618+
else:
619+
# models should be attached when saved so we only need to delete
620+
obj_field = getattr(instance, field.source)
621+
db = router.db_for_write(obj_field.model, instance=instance)
622+
old_objs = set(obj_field.using(db).all())
623+
for obj in old_objs:
624+
if obj not in self._validated_data[field.source]:
625+
obj.delete()
626+
601627

602628
class FocalSaveMixin(FieldLookupMixin):
603629
"""Provides a framework for extracting the values needed to get or create the focal object."""

0 commit comments

Comments
 (0)