Skip to content

Commit 1d9ba6c

Browse files
committed
refactor (and test) field type logic
1 parent 3d785fc commit 1d9ba6c

2 files changed

Lines changed: 206 additions & 78 deletions

File tree

drf_writable_nested/mixins.py

Lines changed: 96 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from rest_framework.exceptions import ValidationError
1212
from rest_framework.fields import empty
1313
from rest_framework.relations import ManyRelatedField
14+
from rest_framework.serializers import BaseSerializer
1415
from rest_framework.validators import UniqueValidator, UniqueTogetherValidator
1516

1617
# permit writable nested serializers
@@ -426,7 +427,60 @@ def update(self, instance, validated_data):
426427
return super(UniqueFieldsMixin, self).update(instance, validated_data)
427428

428429

429-
class RelatedSaveMixin(serializers.Serializer):
430+
class FieldLookupMixin(serializers.Serializer):
431+
432+
def _get_model_field(self, source):
433+
"""Returns the field on the model"""
434+
# for serializers like ModelSerializer, the Meta.model can be used to classify fields
435+
if not hasattr(self, 'Meta') or not hasattr(self.Meta, 'model'):
436+
return None
437+
try:
438+
return self.Meta.model._meta.get_field(source)
439+
except FieldDoesNotExist:
440+
pass
441+
try:
442+
# If `related_name` is not set, field name does not include
443+
# `_set` -> remove it and check again
444+
default_postfix = '_set'
445+
if source.endswith(default_postfix):
446+
return self.Meta.model._meta.get_field(source[:-len(default_postfix)])
447+
except FieldDoesNotExist:
448+
pass
449+
return None
450+
451+
TYPE_READ_ONLY = 'read-only'
452+
TYPE_LOCAL = 'local'
453+
TYPE_DIRECT = 'direct'
454+
TYPE_REVERSE = 'reverse'
455+
456+
_cache_field_types = None
457+
458+
@property
459+
def field_types(self):
460+
if self._cache_field_types is None:
461+
self._populate_field_types()
462+
return self._cache_field_types
463+
464+
def _populate_field_types(self):
465+
self._cache_field_types = {}
466+
for field_name, field in self.fields.items():
467+
if field.read_only:
468+
self._cache_field_types[field_name] = self.TYPE_READ_ONLY
469+
continue
470+
if not isinstance(field, BaseSerializer):
471+
self._cache_field_types[field_name] = self.TYPE_LOCAL
472+
continue
473+
if field.source == '*':
474+
self._cache_field_types[field_name] = self.TYPE_DIRECT
475+
continue
476+
related_field = self._get_model_field(field.source)
477+
if isinstance(related_field, ForeignObjectRel):
478+
self._cache_field_types[field_name] = self.TYPE_REVERSE
479+
continue
480+
self._cache_field_types[field_name] = self.TYPE_DIRECT
481+
482+
483+
class RelatedSaveMixin(FieldLookupMixin):
430484
"""
431485
RelatedSaveMixin handes the saving of nested fields, both direct and reverse relations:
432486
- Direct relations needs to be saved first
@@ -435,16 +489,23 @@ class RelatedSaveMixin(serializers.Serializer):
435489
"""
436490
_is_saved = False
437491

492+
def run_validation(self, data=empty):
493+
self._validated_data = super(RelatedSaveMixin, self).run_validation(data)
494+
self._errors = {}
495+
return self._validated_data
496+
438497
def to_internal_value(self, data):
439498
"""Injects the PK of this field into reverse relations so they validate when created in to_internal_value."""
440-
self._make_reverse_relations_valid(data)
499+
self._make_reverse_relations_valid()
441500
return super(RelatedSaveMixin, self).to_internal_value(data)
442501

443-
def _make_reverse_relations_valid(self, data):
502+
def _make_reverse_relations_valid(self):
444503
"""Make the reverse field optional since we may not have a key for the base object."""
445-
for field_name, (field, related_field) in self._get_reverse_fields().items():
446-
if data.get(field.source) is None:
504+
for field_name, field in self.fields.items():
505+
if self.field_types[field_name] != self.TYPE_REVERSE:
447506
continue
507+
# we know this is a reverse so reverse_field.field is valid
508+
related_field = self._get_model_field(field.source).field
448509
if isinstance(field, serializers.ListSerializer):
449510
field = field.child
450511
if isinstance(field, serializers.ModelSerializer):
@@ -455,11 +516,6 @@ def _make_reverse_relations_valid(self, data):
455516
# found the matching field, move on
456517
break
457518

458-
def run_validation(self, data=empty):
459-
self._validated_data = super(RelatedSaveMixin, self).run_validation(data)
460-
self._errors = {}
461-
return self._validated_data
462-
463519
def save(self, **kwargs):
464520
"""We already converted the inputs into a model so we need to save that model"""
465521
if self._is_saved:
@@ -473,59 +529,13 @@ def save(self, **kwargs):
473529
self._save_reverse_relations(reverse_relations, instance=instance)
474530
return instance
475531

476-
def _get_reverse_fields(self):
477-
reverse_fields = OrderedDict()
478-
if not hasattr(self, 'Meta') or not hasattr(self.Meta, 'model'):
479-
# No model means no reverse fields (without the need to iterate)
480-
return reverse_fields
481-
for field_name, field in self.fields.items():
482-
if field.read_only:
483-
continue
484-
try:
485-
related_field, direct = self._get_related_field(field)
486-
except FieldDoesNotExist:
487-
continue
488-
if direct:
489-
continue
490-
reverse_fields[field_name] = (field, related_field)
491-
return reverse_fields
492-
493-
def _get_related_field(self, field):
494-
model_class = self.Meta.model
495-
try:
496-
related_field = model_class._meta.get_field(field.source)
497-
except FieldDoesNotExist:
498-
# If `related_name` is not set, field name does not include
499-
# `_set` -> remove it and check again
500-
default_postfix = '_set'
501-
if field.source.endswith(default_postfix):
502-
related_field = model_class._meta.get_field(
503-
field.source[:-len(default_postfix)])
504-
else:
505-
raise
506-
if isinstance(related_field, ForeignObjectRel):
507-
return related_field.field, False
508-
return related_field, True
509-
510532
def _save_direct_relations(self, kwargs):
511-
"""Save direct relations so related objects have FKs when committing the base instance"""
533+
"""Save direct relations so FKs exist when committing the base instance"""
512534
for field_name, field in self.fields.items():
513-
if field.read_only:
535+
if self.field_types[field_name] != self.TYPE_DIRECT:
514536
continue
515-
if not isinstance(field, serializers.BaseSerializer):
537+
if not isinstance(self._validated_data, dict) or field_name not in self._validated_data:
516538
continue
517-
# source='*' will never be found in _validated_data
518-
if isinstance(self._validated_data, dict) and self._validated_data.get(field_name) is None:
519-
continue
520-
# this serializer looks like a ModelSerializer
521-
if hasattr(self, 'Meta') and hasattr(self.Meta, 'model'):
522-
try:
523-
_, direct = self._get_related_field(field)
524-
# don't try to process reverse relations
525-
if not direct:
526-
continue
527-
except FieldDoesNotExist:
528-
pass
529539
# reinject validated_data
530540
field._validated_data = self._validated_data[field_name]
531541
self._validated_data[field_name] = field.save(**kwargs.pop(field_name, {}))
@@ -534,25 +544,27 @@ def _extract_reverse_relations(self, kwargs):
534544
"""Removes revere relations from _validated_data to avoid FK integrity issues"""
535545
# Remove related fields from validated data for future manipulations
536546
related_objects = []
537-
for field_name, (field, related_field) in self._get_reverse_fields().items():
538-
if self._validated_data.get(field.source) is None:
547+
for field_name, field in self.fields.items():
548+
if self.field_types[field_name] != self.TYPE_REVERSE:
549+
continue
550+
if not isinstance(self._validated_data, dict) or field_name not in self._validated_data:
539551
continue
540552
serializer = field
541553
if isinstance(serializer, serializers.ListSerializer):
542554
serializer = serializer.child
543555
if isinstance(serializer, serializers.ModelSerializer):
544556
related_objects.append((
545557
field,
546-
related_field,
547558
self._validated_data.pop(field.source),
548-
kwargs.pop(field_name, {}),
559+
kwargs.get(field_name, {}),
549560
))
550561
return related_objects
551562

552563
def _save_reverse_relations(self, related_objects, instance):
553564
"""Inject the current object as the FK in the reverse related objects and save them"""
554-
for field, related_field, data, kwargs in related_objects:
565+
for field, data, kwargs in related_objects:
555566
# inject the PK from the instance
567+
related_field = self._get_model_field(field.source).field
556568
if isinstance(field, serializers.ListSerializer):
557569
for obj in data:
558570
obj[related_field.name] = instance
@@ -600,24 +612,31 @@ def run_validation(self, data=empty):
600612
return self._validated_data
601613

602614

603-
class FocalSaveMixin(serializers.Serializer):
615+
class FocalSaveMixin(FieldLookupMixin):
604616

605617
@transaction.atomic
606618
def save(self, **kwargs):
607619
match_on = {}
608620
m2m_relations = {}
609-
defaults = self.validated_data.copy()
610-
for k, v in kwargs.items():
611-
defaults[k] = v
612-
for field_name, field in self.get_fields().items():
621+
defaults = {}
622+
for field_name, field in self.fields.items():
613623
if self.match_on == '__all__' or field_name in self.match_on:
614-
match_on[field.source or field_name] = defaults.pop(field_name, None)
615-
elif isinstance(field, ManyRelatedField):
616-
m2m_relations[field_name] = defaults.pop(field_name, None)
624+
# add to match_on dict
625+
match_on[field.source or field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
626+
if isinstance(field, ManyRelatedField):
627+
# we can't provide m2m values as kwargs; must use set() instead
628+
m2m_relations[field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
629+
elif self.field_types[field_name] == self.TYPE_LOCAL:
630+
# need to check kwargs since there's no pre-processing
631+
defaults[field_name] = kwargs.get(field_name, self._validated_data.get(field_name))
632+
elif self.field_types[field_name] == self.TYPE_DIRECT:
633+
# kwargs should have been injected when direct relations were saved
634+
defaults[field_name] = self._validated_data.get(field_name)
635+
# make reverse relations aren't sent to a create
617636
# a parent serializer may inject a value that isn't among the fields, but is in `match_on`
618637
for key in self.match_on:
619-
if key not in self.get_fields().keys():
620-
match_on[key] = defaults.pop(key, None)
638+
if key not in self.fields.keys():
639+
match_on[key] = kwargs.get(key, None)
621640
try:
622641
match, updated = self.do_save(match_on, defaults)
623642
if not updated:
@@ -655,7 +674,6 @@ def many_init(cls, *args, **kwargs):
655674
if meta is None:
656675
class Meta:
657676
pass
658-
659677
meta = Meta
660678
setattr(cls, 'Meta', meta)
661679
list_serializer_class = getattr(meta, 'list_serializer_class', None)
@@ -687,7 +705,7 @@ def run_validation(self, data=empty):
687705
self._validated_data = super(NestedSaveSerializer, self).run_validation(data)
688706
# restore Unique or UniqueTogether
689707
self.restore_validation_unique(validators)
690-
return self.validated_data
708+
return self._validated_data
691709

692710
def remove_validation_unique(self):
693711
"""
@@ -763,7 +781,7 @@ def do_save(self, match_on, defaults):
763781
try:
764782
return super(GetOrCreateNestedSerializerMixin, self).do_save(match_on, defaults)
765783
except self.queryset.model.DoesNotExist:
766-
return self.queryset.model(**match_on, **defaults), True
784+
return self.queryset.model(**defaults), True
767785

768786

769787
class UpdateOrCreateNestedSerializerMixin(UpdateDoSaveMixin, GetOrCreateNestedSerializerMixin):
@@ -774,4 +792,4 @@ class CreateOnlyNestedSerializerMixin(NestedSaveSerializer):
774792
"""Creates requested object or fails."""
775793

776794
def do_save(self, match_on, defaults):
777-
return self.queryset.model(**match_on, **defaults), True
795+
return self.queryset.model(**defaults), True

tests/test_field_lookup.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from django.db import models
2+
from django.test import TestCase
3+
from rest_framework import serializers
4+
5+
from drf_writable_nested import mixins
6+
7+
8+
class LookupChild(models.Model):
9+
name = models.TextField()
10+
11+
12+
class LookupParent(models.Model):
13+
child = models.ForeignKey(LookupChild, on_delete=models.CASCADE)
14+
15+
16+
class LookupGrandParent(models.Model):
17+
child = models.ForeignKey(LookupParent, on_delete=models.CASCADE)
18+
19+
20+
class LookupReverseChild(models.Model):
21+
name = models.TextField()
22+
parent = models.ForeignKey(LookupParent, on_delete=models.CASCADE, related_name='children')
23+
24+
25+
class ChildSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer):
26+
class Meta:
27+
model = LookupChild
28+
fields = '__all__'
29+
30+
31+
class ReverseChildSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer):
32+
class Meta:
33+
model = LookupReverseChild
34+
fields = '__all__'
35+
36+
37+
class ParentSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer):
38+
class Meta:
39+
model = LookupParent
40+
fields = '__all__'
41+
# source of a 1:many relationship
42+
child = ChildSerializer()
43+
children = ReverseChildSerializer(many=True)
44+
45+
46+
class GrandParentSerializer(mixins.FieldLookupMixin, serializers.ModelSerializer):
47+
class Meta:
48+
model = LookupGrandParent
49+
fields = '__all__'
50+
# source of a 1:many relationship
51+
child = ParentSerializer()
52+
53+
54+
class FieldTypesTest(TestCase):
55+
56+
def test_field_types_grandparent(self):
57+
serializer = GrandParentSerializer()
58+
self.assertEqual(
59+
{
60+
'id': serializer.TYPE_READ_ONLY,
61+
'child': serializer.TYPE_DIRECT,
62+
},
63+
serializer.field_types
64+
)
65+
66+
def test_field_types_parent(self):
67+
serializer = GrandParentSerializer()
68+
self.assertEqual(
69+
{
70+
'id': serializer.TYPE_READ_ONLY,
71+
'child': serializer.TYPE_DIRECT,
72+
'children': serializer.TYPE_REVERSE,
73+
},
74+
serializer.fields['child'].field_types
75+
)
76+
77+
def test_field_types_child(self):
78+
serializer = GrandParentSerializer()
79+
self.assertEqual(
80+
{
81+
'id': serializer.TYPE_READ_ONLY,
82+
'name': serializer.TYPE_LOCAL,
83+
},
84+
serializer.fields['child'].fields['child'].field_types
85+
)
86+
87+
def test_field_types_reversechild(self):
88+
serializer = GrandParentSerializer()
89+
self.assertEqual(
90+
{
91+
'id': serializer.TYPE_READ_ONLY,
92+
'name': serializer.TYPE_LOCAL,
93+
# must have a nested serializer to be "direct" otherwise it's just a local value
94+
'parent': serializer.TYPE_LOCAL,
95+
},
96+
serializer.fields['child'].fields['children'].child.field_types
97+
)
98+
99+
100+
class GetModelFieldTest(TestCase):
101+
102+
def test_reverse(self):
103+
serializer = ParentSerializer()
104+
model_field = serializer._get_model_field('children')
105+
print(type(model_field))
106+
# opposite side of a ForeignKey is a ManyToOne
107+
self.assertIsInstance(
108+
model_field,
109+
models.ManyToOneRel,
110+
)

0 commit comments

Comments
 (0)