Skip to content

Commit 9b98005

Browse files
committed
allow for sorting/filtering of JSON objects when using PostgreSQL #247
1 parent 33eba35 commit 9b98005

1 file changed

Lines changed: 53 additions & 16 deletions

File tree

dynamic_rest/filters.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from django.core.exceptions import ValidationError as InternalValidationError
44
from django.core.exceptions import ImproperlyConfigured
55
from django.db.models import Q, Prefetch, Manager
6+
from django.db.models.expressions import RawSQL, OrderBy
67
from django.utils import six
78
from rest_framework import serializers
89
from rest_framework.exceptions import ValidationError
9-
from rest_framework.fields import BooleanField, NullBooleanField
10+
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
1011
from rest_framework.filters import BaseFilterBackend, OrderingFilter
1112

1213
from dynamic_rest.utils import is_truthy
@@ -127,6 +128,15 @@ def generate_query_key(self, serializer):
127128

128129
# Recurse into nested field
129130
s = getattr(field, 'serializer', None)
131+
if isinstance(field, JSONField):
132+
# If a json field is found, append any terms following
133+
j = i+1
134+
while j < len(self.field):
135+
rewritten.append(self.field[j])
136+
j += 1
137+
if self.operator:
138+
rewritten.append(self.operator)
139+
return ('__'.join(rewritten), self.field)
130140
if isinstance(s, serializers.ListSerializer):
131141
s = s.child
132142
if not s:
@@ -192,14 +202,12 @@ def filter_queryset(self, request, queryset, view):
192202
# after this is called may not behave as expected
193203
extra_filters = self.view.get_extra_filters(request)
194204

195-
disable_prefetches = self.view.is_update()
196-
197205
self.DEBUG = settings.DEBUG
198206

199207
return self._build_queryset(
200208
queryset=queryset,
201209
extra_filters=extra_filters,
202-
disable_prefetches=disable_prefetches,
210+
disable_prefetches=False,
203211
)
204212

205213
"""
@@ -643,7 +651,16 @@ def filter_queryset(self, request, queryset, view):
643651
"""
644652
self.ordering_param = view.SORT
645653

646-
ordering = self.get_ordering(request, queryset, view)
654+
ordering, nested = self.get_ordering(request, queryset, view)
655+
if ordering and nested:
656+
ordering_str = ''.join(ordering)
657+
if ordering_str.startswith('-'):
658+
return queryset.order_by(
659+
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
660+
descending=True))
661+
return queryset.order_by(
662+
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
663+
descending=False))
647664
if ordering:
648665
return queryset.order_by(*ordering)
649666

@@ -656,11 +673,13 @@ def get_ordering(self, request, queryset, view):
656673
This method overwrites the DRF default so it can parse the array.
657674
"""
658675
params = view.get_request_feature(view.SORT)
676+
nested = []
659677
if params:
660678
fields = [param.strip() for param in params]
661-
valid_ordering, invalid_ordering = self.remove_invalid_fields(
662-
queryset, fields, view
663-
)
679+
valid_ordering, invalid_ordering, nested = \
680+
self.remove_invalid_fields(
681+
queryset, fields, view
682+
)
664683

665684
# if any of the sort fields are invalid, throw an error.
666685
# else return the ordering
@@ -669,10 +688,10 @@ def get_ordering(self, request, queryset, view):
669688
"Invalid filter field: %s" % invalid_ordering
670689
)
671690
else:
672-
return valid_ordering
691+
return valid_ordering, nested
673692

674693
# No sorting was included
675-
return self.get_default_ordering(view)
694+
return self.get_default_ordering(view), nested
676695

677696
def remove_invalid_fields(self, queryset, fields, view):
678697
"""Remove invalid fields from an ordering.
@@ -690,14 +709,14 @@ def remove_invalid_fields(self, queryset, fields, view):
690709
stripped_term = term.lstrip('-')
691710
# add back the '-' add the end if necessary
692711
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
693-
ordering = self.ordering_for(stripped_term, view)
712+
ordering, nested = self.ordering_for(stripped_term, view)
694713

695714
if ordering:
696715
valid_orderings.append(reverse_sort_term + ordering)
697716
else:
698717
invalid_orderings.append(term)
699718

700-
return valid_orderings, invalid_orderings
719+
return valid_orderings, invalid_orderings, nested
701720

702721
def ordering_for(self, term, view):
703722
"""
@@ -707,7 +726,7 @@ def ordering_for(self, term, view):
707726
Raise ImproperlyConfigured if serializer_class not set on view
708727
"""
709728
if not self._is_allowed_term(term, view):
710-
return None
729+
return None, None
711730

712731
serializer = self._get_serializer_class(view)()
713732
serializer_chain = term.split('.')
@@ -717,9 +736,27 @@ def ordering_for(self, term, view):
717736
for segment in serializer_chain[:-1]:
718737
field = serializer.get_all_fields().get(segment)
719738

739+
# If its a JSONField, construct a RawSQL command in the form
740+
# of 'jsonField->{}'.format('nestedField')' or
741+
# 'jsonField->{}->>{}'.format('nested','doubleNested')
742+
if field and isinstance(field, JSONField):
743+
json_chain_start = str(segment)
744+
json_chain = ''
745+
nested = []
746+
first = True
747+
for nterm in serializer_chain[1:]:
748+
if first:
749+
json_chain += '->>%s'
750+
first = False
751+
else:
752+
json_chain = '->%s'+json_chain
753+
nested.append(nterm)
754+
json_chain = json_chain_start + json_chain
755+
return json_chain, nested
756+
720757
if not (field and field.source != '*' and
721758
isinstance(field, DynamicRelationField)):
722-
return None
759+
return None, None
723760

724761
model_chain.append(field.source or segment)
725762

@@ -729,11 +766,11 @@ def ordering_for(self, term, view):
729766
last_field = serializer.get_all_fields().get(last_segment)
730767

731768
if not last_field or last_field.source == '*':
732-
return None
769+
return None, None
733770

734771
model_chain.append(last_field.source or last_segment)
735772

736-
return '__'.join(model_chain)
773+
return '__'.join(model_chain), None
737774

738775
def _is_allowed_term(self, term, view):
739776
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

0 commit comments

Comments
 (0)