Skip to content

Commit 92f684e

Browse files
committed
add configurable default return fields for sort test
- some collections don't have material_id, deprecated, etc.
1 parent 933aad0 commit 92f684e

5 files changed

Lines changed: 13 additions & 11 deletions

File tree

mp_api/_test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,15 @@ def client_sort(
114114
search_method: Callable,
115115
sort_fields: str | Sequence[str],
116116
aux_query: dict[str, Any] | None = None,
117+
default_fields: tuple[str, ...] = ("deprecated", "material_id"),
117118
):
118119
"""Test sorting on an endpoint.
119120
120121
Args:
121122
search_method (Callable) : Client search method to use
122123
sort_fields (str or Sequence of str) : fields to sort on
123124
aux_query (dict) : auxiliary query needed to filter documents
125+
default_fields (list): default fields to return
124126
125127
Raises:
126128
AssertionError if sorting in ascending or descending order does not work.
@@ -142,7 +144,7 @@ def _normalize(doc, field: str):
142144
_page=1,
143145
_sort_fields=sort_field,
144146
chunk_size=NUM_DOCS,
145-
fields=[sort_field, "deprecated", "material_id"],
147+
fields=[sort_field, *default_fields],
146148
**user_query,
147149
)
148150
desc = search_method(

tests/client/materials/test_electrodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,4 @@ def test_pagination():
9999
@pytest.mark.parametrize("sort_field", ["stability_charge", "average_voltage"])
100100
def test_sort(sort_field):
101101
with ElectrodeRester() as rester:
102-
client_sort(rester.search, sort_field)
102+
client_sort(rester.search, sort_field, default_fields=())

tests/client/materials/test_xas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ def test_pagination():
8282
)
8383
def test_sort(sort_field):
8484
with XASRester() as rester:
85-
client_sort(rester.search, sort_field)
85+
client_sort(rester.search, sort_field, default_fields=())

tests/client/molecules/test_jcesr.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import os
2+
3+
import pytest
4+
from pymatgen.core.periodic_table import Element
5+
26
from mp_api._test_utils import (
3-
client_search_testing,
47
client_pagination,
8+
client_search_testing,
59
client_sort,
610
requires_api_key,
711
)
8-
9-
import pytest
10-
from pymatgen.core.periodic_table import Element
11-
1212
from mp_api.client.core.exceptions import MPRestWarning
1313
from mp_api.client.routes.molecules.jcesr import JcesrMoleculesRester
1414

@@ -77,4 +77,4 @@ def test_pagination():
7777
)
7878
def test_sort(sort_field):
7979
with JcesrMoleculesRester() as rester:
80-
client_sort(rester.search, sort_field)
80+
client_sort(rester.search, sort_field, default_fields=())

tests/client/molecules/test_summary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from emmet.core.mpid import MPculeID
66

77
from mp_api._test_utils import (
8-
client_search_testing,
98
client_pagination,
9+
client_search_testing,
1010
client_sort,
1111
requires_api_key,
1212
)
@@ -70,4 +70,4 @@ def test_pagination():
7070
@pytest.mark.parametrize("sort_field", ["charge", "spin_multiplicity"])
7171
def test_sort(sort_field):
7272
with MoleculesSummaryRester() as rester:
73-
client_sort(rester.search, sort_field)
73+
client_sort(rester.search, sort_field, default_fields=())

0 commit comments

Comments
 (0)