Skip to content

Commit 70fed22

Browse files
committed
allow for sorting/filtering of JSON objects when using PostgreSQL #247
1 parent 33eba35 commit 70fed22

1 file changed

Lines changed: 46 additions & 16 deletions

File tree

dynamic_rest/filters.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from django.core.exceptions import ValidationError as InternalValidationError
44
from django.core.exceptions import ImproperlyConfigured
55
from django.db.models import Q, Prefetch, Manager
6+
from django.db.models.expressions import RawSQL, OrderBy
67
from django.utils import six
78
from rest_framework import serializers
89
from rest_framework.exceptions import ValidationError
9-
from rest_framework.fields import BooleanField, NullBooleanField
10+
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
1011
from rest_framework.filters import BaseFilterBackend, OrderingFilter
1112

1213
from dynamic_rest.utils import is_truthy
@@ -124,9 +125,18 @@ def generate_query_key(self, serializer):
124125

125126
if i == last:
126127
break
127-
128+
128129
# Recurse into nested field
129130
s = getattr(field, 'serializer', None)
131+
if isinstance(field, JSONField):
132+
# If a json field is found, append any terms following
133+
j = i+1
134+
while j < len(self.field):
135+
rewritten.append(self.field[j])
136+
j += 1
137+
if self.operator:
138+
rewritten.append(self.operator)
139+
return ('__'.join(rewritten), self.field)
130140
if isinstance(s, serializers.ListSerializer):
131141
s = s.child
132142
if not s:
@@ -192,14 +202,12 @@ def filter_queryset(self, request, queryset, view):
192202
# after this is called may not behave as expected
193203
extra_filters = self.view.get_extra_filters(request)
194204

195-
disable_prefetches = self.view.is_update()
196-
197205
self.DEBUG = settings.DEBUG
198206

199207
return self._build_queryset(
200208
queryset=queryset,
201209
extra_filters=extra_filters,
202-
disable_prefetches=disable_prefetches,
210+
disable_prefetches=False,
203211
)
204212

205213
"""
@@ -643,7 +651,12 @@ def filter_queryset(self, request, queryset, view):
643651
"""
644652
self.ordering_param = view.SORT
645653

646-
ordering = self.get_ordering(request, queryset, view)
654+
ordering, nested = self.get_ordering(request, queryset, view)
655+
if ordering and nested:
656+
ordering_str = ''.join(ordering)
657+
if ordering_str.startswith('-'):
658+
return queryset.order_by(OrderBy(RawSQL('LOWER(%s)'%(ordering_str[1:]), nested), descending=True))
659+
return queryset.order_by(OrderBy(RawSQL('LOWER(%s)'%(ordering_str), nested), descending=False))
647660
if ordering:
648661
return queryset.order_by(*ordering)
649662

@@ -656,9 +669,10 @@ def get_ordering(self, request, queryset, view):
656669
This method overwrites the DRF default so it can parse the array.
657670
"""
658671
params = view.get_request_feature(view.SORT)
672+
nested = []
659673
if params:
660674
fields = [param.strip() for param in params]
661-
valid_ordering, invalid_ordering = self.remove_invalid_fields(
675+
valid_ordering, invalid_ordering, nested = self.remove_invalid_fields(
662676
queryset, fields, view
663677
)
664678

@@ -669,10 +683,10 @@ def get_ordering(self, request, queryset, view):
669683
"Invalid filter field: %s" % invalid_ordering
670684
)
671685
else:
672-
return valid_ordering
686+
return valid_ordering, nested
673687

674688
# No sorting was included
675-
return self.get_default_ordering(view)
689+
return self.get_default_ordering(view), nested
676690

677691
def remove_invalid_fields(self, queryset, fields, view):
678692
"""Remove invalid fields from an ordering.
@@ -690,14 +704,14 @@ def remove_invalid_fields(self, queryset, fields, view):
690704
stripped_term = term.lstrip('-')
691705
# add back the '-' add the end if necessary
692706
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
693-
ordering = self.ordering_for(stripped_term, view)
707+
ordering, nested = self.ordering_for(stripped_term, view)
694708

695709
if ordering:
696710
valid_orderings.append(reverse_sort_term + ordering)
697711
else:
698712
invalid_orderings.append(term)
699713

700-
return valid_orderings, invalid_orderings
714+
return valid_orderings, invalid_orderings, nested
701715

702716
def ordering_for(self, term, view):
703717
"""
@@ -707,7 +721,7 @@ def ordering_for(self, term, view):
707721
Raise ImproperlyConfigured if serializer_class not set on view
708722
"""
709723
if not self._is_allowed_term(term, view):
710-
return None
724+
return None, None
711725

712726
serializer = self._get_serializer_class(view)()
713727
serializer_chain = term.split('.')
@@ -716,10 +730,26 @@ def ordering_for(self, term, view):
716730

717731
for segment in serializer_chain[:-1]:
718732
field = serializer.get_all_fields().get(segment)
733+
734+
# If its a JSONField, construct a RawSQL command in the form of 'jsonField->{}'.format('nestedField')' or 'jsonField->{}->>{}'.format('nested','doubleNested')
735+
if field and isinstance(field, JSONField):
736+
json_chain_start = str(segment)
737+
json_chain = ''
738+
nested = []
739+
first = True
740+
for nterm in serializer_chain[1:]:
741+
if first:
742+
json_chain += '->>%s'
743+
first = False
744+
else:
745+
json_chain = '->%s'+json_chain
746+
nested.append(nterm)
747+
json_chain = json_chain_start + json_chain
748+
return json_chain, nested
719749

720750
if not (field and field.source != '*' and
721-
isinstance(field, DynamicRelationField)):
722-
return None
751+
isinstance(field, DynamicRelationField)):
752+
return None, None
723753

724754
model_chain.append(field.source or segment)
725755

@@ -729,11 +759,11 @@ def ordering_for(self, term, view):
729759
last_field = serializer.get_all_fields().get(last_segment)
730760

731761
if not last_field or last_field.source == '*':
732-
return None
762+
return None, None
733763

734764
model_chain.append(last_field.source or last_segment)
735765

736-
return '__'.join(model_chain)
766+
return '__'.join(model_chain), None
737767

738768
def _is_allowed_term(self, term, view):
739769
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

0 commit comments

Comments
 (0)