Skip to content

Commit 933aad0

Browse files
test fixes
1 parent 1a1d21a commit 933aad0

14 files changed

Lines changed: 57 additions & 29 deletions

File tree

mp_api/_test_utils.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
from enum import Enum
8+
79
try:
810
import pytest
911
except ImportError as exc:
@@ -108,19 +110,40 @@ def client_pagination(
108110
) == set()
109111

110112

111-
def client_sort(search_method: Callable, sort_fields: str | Sequence[str]):
113+
def client_sort(
114+
search_method: Callable,
115+
sort_fields: str | Sequence[str],
116+
aux_query: dict[str, Any] | None = None,
117+
):
112118
"""Test sorting on an endpoint.
113119
114120
Args:
115121
search_method (Callable) : Client search method to use
116122
sort_fields (str or Sequence of str) : fields to sort on
123+
aux_query (dict) : auxiliary query needed to filter documents
117124
118125
Raises:
119126
AssertionError if sorting in ascending or descending order does not work.
120127
"""
128+
129+
def _normalize(doc, field: str):
130+
v = getattr(doc, field)
131+
# serialize enums
132+
return v.value if isinstance(v, Enum) else v
133+
134+
user_query = {
135+
k: v
136+
for k, v in (aux_query or {}).items()
137+
if k not in ("_page", "_sort_fields", "chunk_size", "fields")
138+
}
121139
for sort_field in [sort_fields] if isinstance(sort_fields, str) else sort_fields:
140+
122141
asc = search_method(
123-
_page=1, _sort_fields=sort_field, chunk_size=NUM_DOCS, fields=[sort_field]
142+
_page=1,
143+
_sort_fields=sort_field,
144+
chunk_size=NUM_DOCS,
145+
fields=[sort_field, "deprecated", "material_id"],
146+
**user_query,
124147
)
125148
desc = search_method(
126149
_page=1,
@@ -130,12 +153,12 @@ def client_sort(search_method: Callable, sort_fields: str | Sequence[str]):
130153
)
131154

132155
idxs = list(range(NUM_DOCS))
133-
assert sorted(idxs, key=lambda idx: getattr(asc[idx], sort_field)) == idxs
156+
assert sorted(idxs, key=lambda idx: _normalize(asc[idx], sort_field)) == idxs
134157

135158
assert (
136159
sorted(
137160
idxs,
138-
key=lambda idx: getattr(desc[idx], sort_field),
161+
key=lambda idx: _normalize(desc[idx], sort_field),
139162
reverse=True,
140163
)
141164
== idxs

mp_api/client/core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pyarrow.dataset as ds
1414
from deltalake import DeltaTable
1515
from emmet.core import __version__ as _EMMET_CORE_VER
16-
from emmet.core.mpid_ext import validate_identifier
16+
from emmet.core.mpid import validate_identifier
1717
from monty.json import MontyDecoder
1818
from packaging.version import parse as parse_version
1919

mp_api/client/routes/materials/phonon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def compute_thermo_quantities(
225225

226226
self.use_document_model = True
227227
docs[0]["phonon_dos"] = self.get_dos_from_phonon_id( # type: ignore[index]
228-
phonon_id, phonon_method
228+
phonon_id, phonon_method # type: ignore[arg-type]
229229
)
230230
doc = PhononBSDOSDoc(**docs[0]) # type: ignore[arg-type]
231231
self.use_document_model = use_document_model

mp_api/client/routes/materials/thermo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
class ThermoRester(BaseRester):
1919
suffix = "materials/thermo"
2020
document_model = ThermoDoc # type: ignore
21-
primary_key = "thermo_id"
21+
primary_key = "material_id"
2222

2323
def search(
2424
self,

mp_api/client/routes/materials/xas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class XASRester(BaseRester):
1818
suffix = "materials/xas"
1919
document_model = XASDoc # type: ignore
20-
primary_key = "spectrum_id"
20+
primary_key = "material_id"
2121
delta_backed = False
2222

2323
def search(

tests/client/materials/test_electrodes.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,8 @@ def test_pagination():
9595
client_pagination(rester.search, "material_ids")
9696

9797

98-
@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False)
9998
@requires_api_key
100-
@pytest.mark.parametrize(
101-
"sort_field", ["working_ion", "stability_charge", "average_voltage"]
102-
)
99+
@pytest.mark.parametrize("sort_field", ["stability_charge", "average_voltage"])
103100
def test_sort(sort_field):
104101
with ElectrodeRester() as rester:
105102
client_sort(rester.search, sort_field)

tests/client/materials/test_provenance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def rester():
2525

2626
alt_name_dict: dict = {"material_ids": "material_id"}
2727

28-
custom_field_tests: dict = {"material_ids": ["mp-149"]}
28+
custom_field_tests: dict = {"material_ids": ["mp-13"]}
2929

3030

3131
@requires_api_key

tests/client/materials/test_summary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,4 @@ def test_pagination():
161161
@pytest.mark.parametrize("sort_field", summary_sort_fields)
162162
def test_sort(sort_field: str):
163163
with SummaryRester() as rester:
164-
client_sort(rester.search, sort_field)
164+
client_sort(rester.search, sort_field, aux_query={sort_field: (0, 10)})

tests/client/materials/test_xas.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def test_client(rester):
6464
)
6565

6666

67-
# TODO: how to test pagination now that spectrum_id is computed, not stored?
6867
@requires_api_key
6968
def test_pagination():
7069
with XASRester() as rester:
@@ -73,10 +72,13 @@ def test_pagination():
7372
)
7473

7574

76-
@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False)
7775
@requires_api_key
7876
@pytest.mark.parametrize(
79-
"sort_field", ["spectrum_type", "absorbing_element", "chemsys"]
77+
"sort_field",
78+
[
79+
"spectrum_type",
80+
"absorbing_element",
81+
],
8082
)
8183
def test_sort(sort_field):
8284
with XASRester() as rester:

tests/client/molecules/test_jcesr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def test_pagination():
6868
client_pagination(rester.search, "task_id")
6969

7070

71-
@pytest.mark.xfail(reason="Sort requires API redeployment", strict=False)
7271
@requires_api_key
73-
@pytest.mark.parametrize("sort_field", ["task_id", "IE", "EA"])
72+
@pytest.mark.parametrize(
73+
"sort_field",
74+
[
75+
"task_id",
76+
],
77+
)
7478
def test_sort(sort_field):
7579
with JcesrMoleculesRester() as rester:
7680
client_sort(rester.search, sort_field)

0 commit comments

Comments
 (0)