44
55from __future__ import annotations
66
7+ from enum import Enum
8+
79try :
810 import pytest
911except ImportError as exc :
@@ -108,19 +110,40 @@ def client_pagination(
108110 ) == set ()
109111
110112
111- def client_sort (search_method : Callable , sort_fields : str | Sequence [str ]):
113+ def client_sort (
114+ search_method : Callable ,
115+ sort_fields : str | Sequence [str ],
116+ aux_query : dict [str , Any ] | None = None ,
117+ ):
112118 """Test sorting on an endpoint.
113119
114120 Args:
115121 search_method (Callable) : Client search method to use
116122 sort_fields (str or Sequence of str) : fields to sort on
123+ aux_query (dict) : auxiliary query needed to filter documents
117124
118125 Raises:
119126 AssertionError if sorting in ascending or descending order does not work.
120127 """
128+
129+ def _normalize (doc , field : str ):
130+ v = getattr (doc , field )
131+ # serialize enums
132+ return v .value if isinstance (v , Enum ) else v
133+
134+ user_query = {
135+ k : v
136+ for k , v in (aux_query or {}).items ()
137+ if k not in ("_page" , "_sort_fields" , "chunk_size" , "fields" )
138+ }
121139 for sort_field in [sort_fields ] if isinstance (sort_fields , str ) else sort_fields :
140+
122141 asc = search_method (
123- _page = 1 , _sort_fields = sort_field , chunk_size = NUM_DOCS , fields = [sort_field ]
142+ _page = 1 ,
143+ _sort_fields = sort_field ,
144+ chunk_size = NUM_DOCS ,
145+ fields = [sort_field , "deprecated" , "material_id" ],
146+ ** user_query ,
124147 )
125148 desc = search_method (
126149 _page = 1 ,
@@ -130,12 +153,12 @@ def client_sort(search_method: Callable, sort_fields: str | Sequence[str]):
130153 )
131154
132155 idxs = list (range (NUM_DOCS ))
133- assert sorted (idxs , key = lambda idx : getattr (asc [idx ], sort_field )) == idxs
156+ assert sorted (idxs , key = lambda idx : _normalize (asc [idx ], sort_field )) == idxs
134157
135158 assert (
136159 sorted (
137160 idxs ,
138- key = lambda idx : getattr (desc [idx ], sort_field ),
161+ key = lambda idx : _normalize (desc [idx ], sort_field ),
139162 reverse = True ,
140163 )
141164 == idxs
0 commit comments