Skip to content

Commit 4c1d8f2

Browse files
committed
more elegant handling of reverse relations (so the data remains available for filtering)
1 parent c411a9a commit 4c1d8f2

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

drf_writable_nested/mixins.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -516,17 +516,22 @@ def _make_reverse_relations_valid(self):
516516
# found the matching field, move on
517517
break
518518

519+
@property
520+
def validated_data(self):
521+
"""If mixed into a standard Serializer, prevents `save` from accessing reverse relations"""
522+
return {k: v for k, v in super(RelatedSaveMixin, self).validated_data.items()
523+
if self.field_types[k] != self.TYPE_REVERSE}
524+
519525
def save(self, **kwargs):
520526
"""We already converted the inputs into a model so we need to save that model"""
527+
# prevent recursion when we save a reverse (which tries to save self as a direct)
521528
if self._is_saved:
522-
# prevent recursion when we save a reverse (which tries to save self as a direct)
523529
return
524530
# Create or update direct relations (foreign key, one-to-one)
525-
reverse_relations = self._extract_reverse_relations(kwargs)
526531
self._save_direct_relations(kwargs)
527532
instance = super(RelatedSaveMixin, self).save(**kwargs)
528533
self._is_saved = True
529-
self._save_reverse_relations(reverse_relations, instance=instance)
534+
self._save_reverse_relations(instance=instance, kwargs=kwargs)
530535
return instance
531536

532537
def _save_direct_relations(self, kwargs):
@@ -560,21 +565,23 @@ def _extract_reverse_relations(self, kwargs):
560565
))
561566
return related_objects
562567

563-
def _save_reverse_relations(self, related_objects, instance):
568+
def _save_reverse_relations(self, instance, kwargs):
564569
"""Inject the current object as the FK in the reverse related objects and save them"""
565-
for field, data, kwargs in related_objects:
566-
# inject the PK from the instance
570+
for field_name, field in self.fields.items():
571+
if self.field_types[field_name] != self.TYPE_REVERSE:
572+
continue
573+
# inject the instance into validated_data so the *_id field is valid when saved
567574
related_field = self._get_model_field(field.source).field
568575
if isinstance(field, serializers.ListSerializer):
569-
for obj in data:
576+
for obj in self._validated_data[field_name]:
570577
obj[related_field.name] = instance
571578
elif isinstance(field, serializers.ModelSerializer):
572-
data[related_field.name] = instance
579+
self._validated_data[field_name][related_field.name] = instance
573580
else:
574581
raise Exception("unexpected serializer type")
575582

576-
# reinject validated_data
577-
field._validated_data = data
583+
# (re)inject validated_data to field
584+
field._validated_data = self._validated_data.get(field_name)
578585
field.save(**kwargs)
579586

580587

@@ -588,18 +595,18 @@ def save(self, **kwargs):
588595
create_values = {}
589596
for field_name, field in self.fields.items():
590597
if self.match_on == '__all__' or field_name in self.match_on:
591-
# add to match_on dict
598+
# build match_on dict
592599
match_on[field.source or field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
593600
if isinstance(field, ManyRelatedField):
594601
# we can't provide m2m values as kwargs; must use set() instead
595602
m2m_relations[field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
596603
elif self.field_types[field_name] == self.TYPE_LOCAL:
597-
# need to check kwargs since there's no pre-processing
604+
# need to check kwargs dict since there's no pre-processing
598605
create_values[field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
599606
elif self.field_types[field_name] == self.TYPE_DIRECT:
600-
# kwargs should have been injected when direct relations were saved
607+
# kwargs should have been injected into _validated_data when direct relations were saved
601608
create_values[field_name] = self._validated_data.get(field_name)
602-
# make reverse relations aren't sent to a create
609+
# reverse relations aren't sent to a create
603610
# a parent serializer may inject a value that isn't among the fields, but is in `match_on`
604611
for key in self.match_on:
605612
if key not in self.fields.keys():

0 commit comments

Comments
 (0)