44
55from django .contrib .contenttypes .fields import GenericRelation
66from django .contrib .contenttypes .models import ContentType
7- from django .db import transaction
7+ from django .core .exceptions import ObjectDoesNotExist
8+ from django .db import transaction , router
89from django .db .models import ProtectedError , FieldDoesNotExist , OneToOneRel
910from django .db .models .fields .related import ForeignObjectRel , ManyToManyField
1011from django .utils .translation import ugettext_lazy as _
1112from rest_framework import serializers
1213from rest_framework .exceptions import ValidationError
1314from rest_framework .fields import empty
1415from rest_framework .relations import ManyRelatedField
15- from rest_framework .serializers import BaseSerializer
16+ from rest_framework .serializers import BaseSerializer , ListSerializer
1617from rest_framework .validators import UniqueValidator , UniqueTogetherValidator
1718
1819# permit writable nested serializers
@@ -557,6 +558,8 @@ def save(self, **kwargs):
557558
558559 def _save_direct_relations (self , kwargs ):
559560 """Save direct relations so FKs exist when committing the base instance"""
561+ if self ._validated_data is None and kwargs == {}:
562+ return # delete-only
560563 for field_name , field in self .fields .items ():
561564 if self .field_types [field_name ] != self .TYPE_DIRECT :
562565 continue
@@ -580,24 +583,47 @@ def _save_reverse_relations(self, instance, kwargs):
580583 for field_name , field in self .fields .items ():
581584 if self .field_types [field_name ] != self .TYPE_REVERSE :
582585 continue
586+ if self ._validated_data is None and kwargs == {}:
587+ return # delete-only
583588 if self ._validated_data .get (field .source , empty ) == empty and kwargs .get (field_name , empty ) == empty :
584589 continue # nothing to save
585- # inject the instance into reverse relations so the <parent>_id ForeignKey field is valid when saved
586- related_field = self ._get_model_field (field .source ).field
587- print ("{} populating reverse field {}" .format (self .__class__ .__name__ , related_field .name ))
590+ model_field = self ._get_model_field (field .source )
591+ print ("{} populating reverse field {}" .format (self .__class__ .__name__ , model_field .field .name ))
588592 if isinstance (field , serializers .ListSerializer ):
593+ # reverse FK, inject the instance into reverse relations so the <parent>_id FK field is valid when saved
589594 for obj in field ._validated_data :
590- obj [related_field .name ] = instance
595+ obj [model_field . field .name ] = instance
591596 elif isinstance (field , serializers .ModelSerializer ):
592- if field ._validated_data is None :
593- field ._validated_data = {} # delete situation, but need a place to put FK
594- field ._validated_data [related_field .name ] = instance
597+ # 1:1
598+ if self ._validated_data [field .source ] is None :
599+ # indicates that we should delete 1:1 relation (if it exists)
600+ try :
601+ getattr (instance , field .source ).delete ()
602+ continue
603+ except ObjectDoesNotExist :
604+ pass
605+ else :
606+ field ._validated_data [model_field .field .name ] = instance
595607 else :
596608 raise Exception ("unexpected serializer type" )
597- # no tests fail if we do not cache this value in _validated_data, but it's consistent with forward relations
609+ # create/update (as appropriate) related objects
598610 self ._validated_data [field .source ] = field .save (** kwargs .get (field_name , {}))
599611 print ("{}._validated_data[{}] to reverse {}" .format (self .__class__ .__name__ , field_name , self ._validated_data [field .source ]))
600612
613+ # eliminate related objects that weren't in the request
614+ if isinstance (field , ListSerializer ):
615+ # due to a bug in Django, calling `set` on a non-nullable reverse relation will only `add`
616+ if model_field .field .null :
617+ getattr (instance , field .source ).set (self ._validated_data [field .source ])
618+ else :
619+ # models should be attached when saved so we only need to delete
620+ obj_field = getattr (instance , field .source )
621+ db = router .db_for_write (obj_field .model , instance = instance )
622+ old_objs = set (obj_field .using (db ).all ())
623+ for obj in old_objs :
624+ if obj not in self ._validated_data [field .source ]:
625+ obj .delete ()
626+
601627
602628class FocalSaveMixin (FieldLookupMixin ):
603629 """Provides a framework for extracting the values needed to get or create the focal object."""
0 commit comments