@@ -516,17 +516,22 @@ def _make_reverse_relations_valid(self):
516516 # found the matching field, move on
517517 break
518518
519+ @property
520+ def validated_data (self ):
521+ """If mixed into a standard Serializer, prevents `save` from accessing reverse relations"""
522+ return {k : v for k , v in super (RelatedSaveMixin , self ).validated_data .items ()
523+ if self .field_types [k ] != self .TYPE_REVERSE }
524+
519525 def save (self , ** kwargs ):
520526 """We already converted the inputs into a model so we need to save that model"""
527+ # prevent recursion when we save a reverse (which tries to save self as a direct)
521528 if self ._is_saved :
522- # prevent recursion when we save a reverse (which tries to save self as a direct)
523529 return
524530 # Create or update direct relations (foreign key, one-to-one)
525- reverse_relations = self ._extract_reverse_relations (kwargs )
526531 self ._save_direct_relations (kwargs )
527532 instance = super (RelatedSaveMixin , self ).save (** kwargs )
528533 self ._is_saved = True
529- self ._save_reverse_relations (reverse_relations , instance = instance )
534+ self ._save_reverse_relations (instance = instance , kwargs = kwargs )
530535 return instance
531536
532537 def _save_direct_relations (self , kwargs ):
@@ -560,21 +565,23 @@ def _extract_reverse_relations(self, kwargs):
560565 ))
561566 return related_objects
562567
563- def _save_reverse_relations (self , related_objects , instance ):
568+ def _save_reverse_relations (self , instance , kwargs ):
564569 """Inject the current object as the FK in the reverse related objects and save them"""
565- for field , data , kwargs in related_objects :
566- # inject the PK from the instance
570+ for field_name , field in self .fields .items ():
571+ if self .field_types [field_name ] != self .TYPE_REVERSE :
572+ continue
573+ # inject the instance into validated_data so the *_id field is valid when saved
567574 related_field = self ._get_model_field (field .source ).field
568575 if isinstance (field , serializers .ListSerializer ):
569- for obj in data :
576+ for obj in self . _validated_data [ field_name ] :
570577 obj [related_field .name ] = instance
571578 elif isinstance (field , serializers .ModelSerializer ):
572- data [related_field .name ] = instance
579+ self . _validated_data [ field_name ] [related_field .name ] = instance
573580 else :
574581 raise Exception ("unexpected serializer type" )
575582
576- # reinject validated_data
577- field ._validated_data = data
583+ # (re)inject validated_data to field
584+ field ._validated_data = self . _validated_data . get ( field_name )
578585 field .save (** kwargs )
579586
580587
@@ -588,18 +595,18 @@ def save(self, **kwargs):
588595 create_values = {}
589596 for field_name , field in self .fields .items ():
590597 if self .match_on == '__all__' or field_name in self .match_on :
591- # add to match_on dict
598+ # build match_on dict
592599 match_on [field .source or field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
593600 if isinstance (field , ManyRelatedField ):
594601 # we can't provide m2m values as kwargs; must use set() instead
595602 m2m_relations [field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
596603 elif self .field_types [field_name ] == self .TYPE_LOCAL :
597- # need to check kwargs since there's no pre-processing
604+ # need to check kwargs dict since there's no pre-processing
598605 create_values [field_name ] = kwargs .get (field_name , self ._validated_data .get (field_name ))
599606 elif self .field_types [field_name ] == self .TYPE_DIRECT :
600- # kwargs should have been injected when direct relations were saved
607+ # kwargs should have been injected into _validated_data when direct relations were saved
601608 create_values [field_name ] = self ._validated_data .get (field_name )
602- # make reverse relations aren't sent to a create
609+ # reverse relations aren't sent to a create
603610 # a parent serializer may inject a value that isn't among the fields, but is in `match_on`
604611 for key in self .match_on :
605612 if key not in self .fields .keys ():
0 commit comments