1111from rest_framework .exceptions import ValidationError
1212from rest_framework .fields import empty
1313from rest_framework .relations import ManyRelatedField
14+ from rest_framework .serializers import BaseSerializer
1415from rest_framework .validators import UniqueValidator , UniqueTogetherValidator
1516
1617# permit writable nested serializers
@@ -426,7 +427,60 @@ def update(self, instance, validated_data):
426427 return super (UniqueFieldsMixin , self ).update (instance , validated_data )
427428
428429
429- class RelatedSaveMixin (serializers .Serializer ):
430+ class FieldLookupMixin (serializers .Serializer ):
431+
432+ def _get_model_field (self , source ):
433+ """Returns the field on the model"""
434+ # for serializers like ModelSerializer, the Meta.model can be used to classify fields
435+ if not hasattr (self , 'Meta' ) or not hasattr (self .Meta , 'model' ):
436+ return None
437+ try :
438+ return self .Meta .model ._meta .get_field (source )
439+ except FieldDoesNotExist :
440+ pass
441+ try :
442+ # If `related_name` is not set, field name does not include
443+ # `_set` -> remove it and check again
444+ default_postfix = '_set'
445+ if source .endswith (default_postfix ):
446+ return self .Meta .model ._meta .get_field (source [:- len (default_postfix )])
447+ except FieldDoesNotExist :
448+ pass
449+ return None
450+
451+ TYPE_READ_ONLY = 'read-only'
452+ TYPE_LOCAL = 'local'
453+ TYPE_DIRECT = 'direct'
454+ TYPE_REVERSE = 'reverse'
455+
456+ _cache_field_types = None
457+
458+ @property
459+ def field_types (self ):
460+ if self ._cache_field_types is None :
461+ self ._populate_field_types ()
462+ return self ._cache_field_types
463+
464+ def _populate_field_types (self ):
465+ self ._cache_field_types = {}
466+ for field_name , field in self .fields .items ():
467+ if field .read_only :
468+ self ._cache_field_types [field_name ] = self .TYPE_READ_ONLY
469+ continue
470+ if not isinstance (field , BaseSerializer ):
471+ self ._cache_field_types [field_name ] = self .TYPE_LOCAL
472+ continue
473+ if field .source == '*' :
474+ self ._cache_field_types [field_name ] = self .TYPE_DIRECT
475+ continue
476+ related_field = self ._get_model_field (field .source )
477+ if isinstance (related_field , ForeignObjectRel ):
478+ self ._cache_field_types [field_name ] = self .TYPE_REVERSE
479+ continue
480+ self ._cache_field_types [field_name ] = self .TYPE_DIRECT
481+
482+
483+ class RelatedSaveMixin (FieldLookupMixin ):
430484 """
431485 RelatedSaveMixin handes the saving of nested fields, both direct and reverse relations:
432486 - Direct relations needs to be saved first
@@ -435,16 +489,23 @@ class RelatedSaveMixin(serializers.Serializer):
435489 """
436490 _is_saved = False
437491
492+ def run_validation (self , data = empty ):
493+ self ._validated_data = super (RelatedSaveMixin , self ).run_validation (data )
494+ self ._errors = {}
495+ return self ._validated_data
496+
438497 def to_internal_value (self , data ):
439498 """Injects the PK of this field into reverse relations so they validate when created in to_internal_value."""
440- self ._make_reverse_relations_valid (data )
499+ self ._make_reverse_relations_valid ()
441500 return super (RelatedSaveMixin , self ).to_internal_value (data )
442501
443- def _make_reverse_relations_valid (self , data ):
502+ def _make_reverse_relations_valid (self ):
444503 """Make the reverse field optional since we may not have a key for the base object."""
445- for field_name , ( field , related_field ) in self ._get_reverse_fields () .items ():
446- if data . get ( field . source ) is None :
504+ for field_name , field in self .fields .items ():
505+ if self . field_types [ field_name ] != self . TYPE_REVERSE :
447506 continue
507+ # we know this is a reverse so reverse_field.field is valid
508+ related_field = self ._get_model_field (field .source ).field
448509 if isinstance (field , serializers .ListSerializer ):
449510 field = field .child
450511 if isinstance (field , serializers .ModelSerializer ):
@@ -455,11 +516,6 @@ def _make_reverse_relations_valid(self, data):
455516 # found the matching field, move on
456517 break
457518
458- def run_validation (self , data = empty ):
459- self ._validated_data = super (RelatedSaveMixin , self ).run_validation (data )
460- self ._errors = {}
461- return self ._validated_data
462-
463519 def save (self , ** kwargs ):
464520 """We already converted the inputs into a model so we need to save that model"""
465521 if self ._is_saved :
@@ -473,59 +529,13 @@ def save(self, **kwargs):
473529 self ._save_reverse_relations (reverse_relations , instance = instance )
474530 return instance
475531
476- def _get_reverse_fields (self ):
477- reverse_fields = OrderedDict ()
478- if not hasattr (self , 'Meta' ) or not hasattr (self .Meta , 'model' ):
479- # No model means no reverse fields (without the need to iterate)
480- return reverse_fields
481- for field_name , field in self .fields .items ():
482- if field .read_only :
483- continue
484- try :
485- related_field , direct = self ._get_related_field (field )
486- except FieldDoesNotExist :
487- continue
488- if direct :
489- continue
490- reverse_fields [field_name ] = (field , related_field )
491- return reverse_fields
492-
493- def _get_related_field (self , field ):
494- model_class = self .Meta .model
495- try :
496- related_field = model_class ._meta .get_field (field .source )
497- except FieldDoesNotExist :
498- # If `related_name` is not set, field name does not include
499- # `_set` -> remove it and check again
500- default_postfix = '_set'
501- if field .source .endswith (default_postfix ):
502- related_field = model_class ._meta .get_field (
503- field .source [:- len (default_postfix )])
504- else :
505- raise
506- if isinstance (related_field , ForeignObjectRel ):
507- return related_field .field , False
508- return related_field , True
509-
510532 def _save_direct_relations (self , kwargs ):
511- """Save direct relations so related objects have FKs when committing the base instance"""
533+ """Save direct relations so FKs exist when committing the base instance"""
512534 for field_name , field in self .fields .items ():
513- if field . read_only :
535+ if self . field_types [ field_name ] != self . TYPE_DIRECT :
514536 continue
515- if not isinstance (field , serializers . BaseSerializer ) :
537+ if not isinstance (self . _validated_data , dict ) or field_name not in self . _validated_data :
516538 continue
517- # source='*' will never be found in _validated_data
518- if isinstance (self ._validated_data , dict ) and self ._validated_data .get (field_name ) is None :
519- continue
520- # this serializer looks like a ModelSerializer
521- if hasattr (self , 'Meta' ) and hasattr (self .Meta , 'model' ):
522- try :
523- _ , direct = self ._get_related_field (field )
524- # don't try to process reverse relations
525- if not direct :
526- continue
527- except FieldDoesNotExist :
528- pass
529539 # reinject validated_data
530540 field ._validated_data = self ._validated_data [field_name ]
531541 self ._validated_data [field_name ] = field .save (** kwargs .pop (field_name , {}))
@@ -534,25 +544,27 @@ def _extract_reverse_relations(self, kwargs):
534544 """Removes revere relations from _validated_data to avoid FK integrity issues"""
535545 # Remove related fields from validated data for future manipulations
536546 related_objects = []
537- for field_name , (field , related_field ) in self ._get_reverse_fields ().items ():
538- if self ._validated_data .get (field .source ) is None :
547+ for field_name , field in self .fields .items ():
548+ if self .field_types [field_name ] != self .TYPE_REVERSE :
549+ continue
550+ if not isinstance (self ._validated_data , dict ) or field_name not in self ._validated_data :
539551 continue
540552 serializer = field
541553 if isinstance (serializer , serializers .ListSerializer ):
542554 serializer = serializer .child
543555 if isinstance (serializer , serializers .ModelSerializer ):
544556 related_objects .append ((
545557 field ,
546- related_field ,
547558 self ._validated_data .pop (field .source ),
548- kwargs .pop (field_name , {}),
559+ kwargs .get (field_name , {}),
549560 ))
550561 return related_objects
551562
552563 def _save_reverse_relations (self , related_objects , instance ):
553564 """Inject the current object as the FK in the reverse related objects and save them"""
554- for field , related_field , data , kwargs in related_objects :
565+ for field , data , kwargs in related_objects :
555566 # inject the PK from the instance
567+ related_field = self ._get_model_field (field .source ).field
556568 if isinstance (field , serializers .ListSerializer ):
557569 for obj in data :
558570 obj [related_field .name ] = instance
@@ -600,24 +612,31 @@ def run_validation(self, data=empty):
600612 return self ._validated_data
601613
602614
603- class FocalSaveMixin (serializers . Serializer ):
615+ class FocalSaveMixin (FieldLookupMixin ):
604616
605617 @transaction .atomic
606618 def save (self , ** kwargs ):
607619 match_on = {}
608620 m2m_relations = {}
609- defaults = self .validated_data .copy ()
610- for k , v in kwargs .items ():
611- defaults [k ] = v
612- for field_name , field in self .get_fields ().items ():
621+ defaults = {}
622+ for field_name , field in self .fields .items ():
613623 if self .match_on == '__all__' or field_name in self .match_on :
614- match_on [field .source or field_name ] = defaults .pop (field_name , None )
615- elif isinstance (field , ManyRelatedField ):
616- m2m_relations [field_name ] = defaults .pop (field_name , None )
624+ # add to match_on dict
625+ match_on [field .source or field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
626+ if isinstance (field , ManyRelatedField ):
627+ # we can't provide m2m values as kwargs; must use set() instead
628+ m2m_relations [field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
629+ elif self .field_types [field_name ] == self .TYPE_LOCAL :
630+ # need to check kwargs since there's no pre-processing
631+ defaults [field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
632+ elif self .field_types [field_name ] == self .TYPE_DIRECT :
633+ # kwargs should have been injected when direct relations were saved
634+ defaults [field_name ] = self ._validated_data .get (field_name )
635+ # make reverse relations aren't sent to a create
617636 # a parent serializer may inject a value that isn't among the fields, but is in `match_on`
618637 for key in self .match_on :
619- if key not in self .get_fields () .keys ():
620- match_on [key ] = defaults . pop (key , None )
638+ if key not in self .fields .keys ():
639+ match_on [key ] = kwargs . get (key , None )
621640 try :
622641 match , updated = self .do_save (match_on , defaults )
623642 if not updated :
@@ -655,7 +674,6 @@ def many_init(cls, *args, **kwargs):
655674 if meta is None :
656675 class Meta :
657676 pass
658-
659677 meta = Meta
660678 setattr (cls , 'Meta' , meta )
661679 list_serializer_class = getattr (meta , 'list_serializer_class' , None )
@@ -687,7 +705,7 @@ def run_validation(self, data=empty):
687705 self ._validated_data = super (NestedSaveSerializer , self ).run_validation (data )
688706 # restore Unique or UniqueTogether
689707 self .restore_validation_unique (validators )
690- return self .validated_data
708+ return self ._validated_data
691709
692710 def remove_validation_unique (self ):
693711 """
@@ -763,7 +781,7 @@ def do_save(self, match_on, defaults):
763781 try :
764782 return super (GetOrCreateNestedSerializerMixin , self ).do_save (match_on , defaults )
765783 except self .queryset .model .DoesNotExist :
766- return self .queryset .model (** match_on , ** defaults ), True
784+ return self .queryset .model (** defaults ), True
767785
768786
769787class UpdateOrCreateNestedSerializerMixin (UpdateDoSaveMixin , GetOrCreateNestedSerializerMixin ):
@@ -774,4 +792,4 @@ class CreateOnlyNestedSerializerMixin(NestedSaveSerializer):
774792 """Creates requested object or fails."""
775793
776794 def do_save (self , match_on , defaults ):
777- return self .queryset .model (** match_on , ** defaults ), True
795+ return self .queryset .model (** defaults ), True
0 commit comments