Skip to content

Commit 28da883

Browse files
allow for version setting via client
1 parent 80e73bb commit 28da883

5 files changed

Lines changed: 115 additions & 64 deletions

File tree

mp_api/client/core/client.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
session: requests.Session | None = None,
135135
headers: dict | None = None,
136136
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
137+
db_version: str | None = None,
137138
local_dataset_cache: (
138139
str | os.PathLike
139140
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
@@ -167,6 +168,9 @@ def __init__(
167168
advanced usage only.
168169
headers: Custom headers for localhost connections.
169170
mute_progress_bars: Whether to disable progress bars.
171+
db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database
172+
than what is currently deployed. The Materials Project cannot guarantee that all
173+
features will still work.
170174
local_dataset_cache: Target directory for downloading full datasets. Defaults
171175
to 'mp_datasets' in the user's home directory
172176
force_renew: Option to overwrite existing local dataset
@@ -192,6 +196,7 @@ def __init__(
192196

193197
self.use_document_model = use_document_model
194198
self.mute_progress_bars = mute_progress_bars
199+
self.db_version: str = db_version or ""
195200
self.local_dataset_cache = Path(local_dataset_cache)
196201
self.force_renew = force_renew
197202
self._query_builder = (
@@ -260,6 +265,45 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover
260265
self.session.close()
261266
self._session = None
262267

268+
@staticmethod
269+
@cache
270+
def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]:
271+
"""DB version:
272+
The Materials Project database is periodically updated and has a
273+
database version associated with it. When the database is updated,
274+
consolidated data (information about "a material") may and does
275+
change, while calculation data about a specific calculation task
276+
remains unchanged and available for querying via its task_id.
277+
278+
The database version is set as a date in the format YYYY_MM_DD,
279+
where "_DD" may be optional. An additional numerical or `postN` suffix
280+
might be added if multiple releases happen on the same day.
281+
282+
Access Controlled Datasets:
283+
Certain contributions to the Materials Project have access
284+
control restrictions that require explicit agreement to the
285+
Terms of Use for the respective datasets prior to access being
286+
granted.
287+
288+
A full list of the Terms of Use for all contributions in the
289+
Materials Project are available at:
290+
291+
https://next-gen.materialsproject.org/about/terms
292+
293+
Returns:
294+
tuple with database version as a string and a list of strings
295+
with all calculation batch identifiers that have access
296+
restrictions
297+
"""
298+
if (get_resp := requests.get(url=endpoint + "heartbeat")).status_code == 403:
299+
_emit_status_warning()
300+
return (
301+
"",
302+
[],
303+
) # Cautiously do not allow access to any access controlled `batch_id`s
304+
response = get_resp.json()
305+
return response["db_version"], response["access_controlled_batch_ids"]
306+
263307

264308
class BaseRester(_Rester):
265309
"""Base client class with core stubs."""
@@ -278,6 +322,7 @@ def __init__(
278322
session: requests.Session | None = None,
279323
headers: dict | None = None,
280324
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
325+
db_version: str | None = None,
281326
local_dataset_cache: (
282327
str | os.PathLike
283328
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
@@ -316,6 +361,9 @@ def __init__(
316361
and will not give auto-complete for available fields.
317362
headers: Custom headers for localhost connections.
318363
mute_progress_bars: Whether to disable progress bars.
364+
db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database
365+
than what is currently deployed. The Materials Project cannot guarantee that all
366+
features will still work.
319367
local_dataset_cache: Target directory for downloading full datasets. Defaults
320368
to 'mp_datasets' in the user's home directory
321369
force_renew: Option to overwrite existing local dataset
@@ -333,6 +381,7 @@ def __init__(
333381
session=session,
334382
headers=headers,
335383
mute_progress_bars=mute_progress_bars,
384+
db_version=db_version,
336385
local_dataset_cache=local_dataset_cache,
337386
force_renew=force_renew,
338387
query_builder=query_builder,
@@ -343,9 +392,11 @@ def __init__(
343392
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)
344393

345394
(
346-
self.db_version,
395+
hb_db_version,
347396
self.access_controlled_batch_ids,
348-
) = BaseRester._get_heartbeat_info(self.base_endpoint)
397+
) = self._get_heartbeat_info(self.base_endpoint)
398+
if not self.db_version:
399+
self.db_version = hb_db_version
349400

350401
self.timeout = timeout
351402
self._s3_client = s3_client
@@ -359,45 +410,6 @@ def s3_client(self):
359410
)
360411
return self._s3_client
361412

362-
@staticmethod
363-
@cache
364-
def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]:
365-
"""DB version:
366-
The Materials Project database is periodically updated and has a
367-
database version associated with it. When the database is updated,
368-
consolidated data (information about "a material") may and does
369-
change, while calculation data about a specific calculation task
370-
remains unchanged and available for querying via its task_id.
371-
372-
The database version is set as a date in the format YYYY_MM_DD,
373-
where "_DD" may be optional. An additional numerical or `postN` suffix
374-
might be added if multiple releases happen on the same day.
375-
376-
Access Controlled Datasets:
377-
Certain contributions to the Materials Project have access
378-
control restrictions that require explicit agreement to the
379-
Terms of Use for the respective datasets prior to access being
380-
granted.
381-
382-
A full list of the Terms of Use for all contributions in the
383-
Materials Project are available at:
384-
385-
https://next-gen.materialsproject.org/about/terms
386-
387-
Returns:
388-
tuple with database version as a string and a comma separated
389-
string with all calculation batch identifiers that have access
390-
restrictions
391-
"""
392-
if (get_resp := requests.get(url=endpoint + "heartbeat")).status_code == 403:
393-
_emit_status_warning()
394-
return (
395-
"",
396-
[],
397-
) # Catiously do not allow access to any access controlled `batch_id`s
398-
response = get_resp.json()
399-
return response["db_version"], response["access_controlled_batch_ids"]
400-
401413
def _post_resource(
402414
self,
403415
body: dict | None = None,
@@ -744,17 +756,23 @@ def _query_delta_backed(
744756

745757
predicate = (
746758
f"WHERE batch_id NOT IN ({controlled_batch_str})"
747-
if not has_gnome_access
759+
if not has_gnome_access and controlled_batch_str
748760
else ""
749761
)
762+
# TODO: do we need something like this?
763+
# predicate += f"{' AND ' if predicate else 'WHERE '}version='{self.db_version}'"
750764

751765
# Setup progress bar
752766
num_docs_needed: int = tbl.count()
753767

754768
if not has_gnome_access:
755-
num_docs_needed = self.count(
756-
{"batch_id_neq_any": self.access_controlled_batch_ids}
757-
)
769+
try:
770+
num_docs_needed = self.count(
771+
{"batch_id_neq_any": self.access_controlled_batch_ids}
772+
)
773+
except MPRestError:
774+
# batch_id isn't a valid field
775+
num_docs_needed = self.count()
758776

759777
pbar = (
760778
tqdm(
@@ -918,15 +936,19 @@ def _query_resource(
918936
elif suffix in STATIC_COLLECTIONS:
919937
bucket_suffix = "build"
920938
prefix = f"static-collections/{suffix}"
939+
elif self.delta_backed:
940+
return self._query_delta_backed(
941+
"materialsproject-build",
942+
f"collections/{suffix}/",
943+
timeout=timeout,
944+
)
921945
else:
946+
# TODO: remove once all collections are migrated to delta-backed format
922947
bucket_suffix = "build"
923948
prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"
924949

925950
bucket = f"materialsproject-{bucket_suffix}"
926951

927-
if self.delta_backed:
928-
return self._query_delta_backed(bucket, prefix, timeout=timeout)
929-
930952
# Paginate over all entries in the bucket.
931953
# TODO: change when a subset of entries needed from DB
932954
paginator = self.s3_client.get_paginator("list_objects_v2")
@@ -1671,6 +1693,7 @@ def __getattr__(self, v: str):
16711693
use_document_model=self.use_document_model,
16721694
headers=self.headers,
16731695
mute_progress_bars=self.mute_progress_bars,
1696+
db_version=self.db_version,
16741697
local_dataset_cache=self.local_dataset_cache,
16751698
force_renew=self.force_renew,
16761699
query_builder=self._query_builder,

mp_api/client/core/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)
1616

1717
_EMMET_SETTINGS = EmmetSettings() # type: ignore[call-arg]
18-
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"
18+
_DEFAULT_ENDPOINT = "http://0.0.0.0:8000" # https://api.materialsproject.org/"
1919

2020

2121
class MAPIClientSettings(BaseSettings):

mp_api/client/mprester.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
session: Session | None = None,
9898
headers: dict | None = None,
9999
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
100+
db_version: str | None = None,
100101
local_dataset_cache: (
101102
str | os.PathLike
102103
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
@@ -130,6 +131,9 @@ def __init__(
130131
session: Session object to use. By default (None), the client will create one.
131132
headers: Custom headers for localhost connections.
132133
mute_progress_bars: Whether to mute progress bars.
134+
db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database
135+
than what is currently deployed. The Materials Project cannot guarantee that all
136+
features will still work.
133137
local_dataset_cache: Target directory for downloading full datasets. Defaults
134138
to "mp_datasets" in the user's home directory
135139
force_renew: Option to overwrite existing local dataset
@@ -152,6 +156,7 @@ def __init__(
152156
session=session,
153157
headers=headers,
154158
mute_progress_bars=mute_progress_bars,
159+
db_version=db_version,
155160
local_dataset_cache=local_dataset_cache,
156161
force_renew=force_renew,
157162
query_builder=query_builder,
@@ -210,6 +215,17 @@ def __init__(
210215
stacklevel=2,
211216
)
212217

218+
if self.db_version:
219+
warnings.warn(
220+
"Specifying an explicit database version is an experimental "
221+
"feature. The Materials Project cannot guarantee "
222+
"functionality at this time, use at your own risk!",
223+
stacklevel=2,
224+
category=MPRestWarning,
225+
)
226+
else:
227+
self.db_version = self._get_heartbeat_info(self.endpoint)[0]
228+
213229
if notify_db_version:
214230
self._db_version_check()
215231

@@ -233,6 +249,7 @@ def __init__(
233249
use_document_model=self.use_document_model,
234250
headers=self.headers,
235251
mute_progress_bars=self.mute_progress_bars,
252+
db_version=self.db_version,
236253
local_dataset_cache=self.local_dataset_cache,
237254
force_renew=self.force_renew,
238255
query_builder=self._query_builder,
@@ -300,8 +317,7 @@ def __dir__(self):
300317
)
301318

302319
def __repr__(self) -> str:
303-
db_version = self.get_database_version()
304-
return f"MPRester({'v' + db_version if db_version else 'unknown version'})"
320+
return f"MPRester({'v' + self.self.db_version if self.db_version else 'unknown version'})"
305321

306322
def get_task_ids_associated_with_material_id(
307323
self, material_id: str, calc_types: list[CalcType] | None = None
@@ -364,7 +380,9 @@ def get_structure_by_material_id(
364380
return structure_data
365381

366382
def get_database_version(self) -> str | None:
367-
"""The Materials Project database is periodically updated and has a
383+
"""DEPRECATED: use `self.db_version` instead.
384+
385+
The Materials Project database is periodically updated and has a
368386
database version associated with it. When the database is updated,
369387
consolidated data (information about "a material") may and does
370388
change, while calculation data about a specific calculation task
@@ -376,10 +394,13 @@ def get_database_version(self) -> str | None:
376394
377395
Returns: database version as a string if accessible, None otherwise
378396
"""
379-
if (get_resp := get(url=self.endpoint + "heartbeat")).status_code == 403:
380-
_emit_status_warning()
381-
return None
382-
return get_resp.json()["db_version"]
397+
warnings.warn(
398+
"`get_database_version` has been deprecated in favor of "
399+
"MPRester().db_version.",
400+
stacklevel=2,
401+
category=MPRestWarning,
402+
)
403+
return self.db_version
383404

384405
@staticmethod
385406
@cache
@@ -1676,7 +1697,6 @@ def _db_version_check(self) -> None:
16761697
"""Check if the database version has drifted."""
16771698
import yaml # type: ignore[import-untyped]
16781699

1679-
db_version = self.get_database_version()
16801700
old_db_version = None
16811701
if MAPI_CLIENT_SETTINGS.LOG_FILE.exists():
16821702
old_db_version = (
@@ -1687,15 +1707,15 @@ def _db_version_check(self) -> None:
16871707
if not isinstance(old_db_version, str):
16881708
old_db_version = None
16891709

1690-
if old_db_version != db_version:
1710+
if old_db_version != self.db_version:
16911711
MAPI_CLIENT_SETTINGS.LOG_FILE.write_text(
1692-
yaml.safe_dump({"MAPI_DB_VERSION": db_version})
1712+
yaml.safe_dump({"MAPI_DB_VERSION": self.db_version})
16931713
)
16941714

16951715
if old_db_version:
16961716
warnings.warn(
16971717
"Materials Project database version has changed "
1698-
f"from v{old_db_version} to v{db_version}.",
1718+
f"from v{old_db_version} to v{self.db_version}.",
16991719
category=MPRestWarning,
17001720
stacklevel=2,
17011721
)

tests/client/materials/test_thermo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def test_client(rester):
6464
def test_get_phase_diagram_from_chemsys():
6565
# Test that a phase diagram is returned
6666

67+
pd = ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm", thermo_type="GGA_GGA+U")
6768
assert isinstance(
68-
ThermoRester().get_phase_diagram_from_chemsys("Hf-Pm", thermo_type="GGA_GGA+U"),
69+
pd,
6970
PhaseDiagram,
7071
)

tests/client/test_mprester.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def mpr():
5454
@requires_api_key
5555
class TestMPRester:
5656
fake_mp_api_key = "12345678901234567890123456789012" # 32 chars
57-
default_endpoint = _DEFAULT_ENDPOINT
57+
default_endpoint = _DEFAULT_ENDPOINT + (
58+
"/" if not _DEFAULT_ENDPOINT.endswith("/") else ""
59+
)
5860

5961
def test_get_structure_by_material_id(self, mpr):
6062
s0 = mpr.get_structure_by_material_id("mp-149")
@@ -67,9 +69,14 @@ def test_get_structure_by_material_id(self, mpr):
6769
assert {s.formula for s in s2} == {"Si2"}
6870

6971
def test_get_database_version(self, mpr):
70-
db_version = mpr.get_database_version()
72+
db_version = mpr.db_version
7173
assert db_version is not None
7274

75+
with pytest.warns(
76+
MPRestWarning, match="`get_database_version` has been deprecated"
77+
):
78+
assert db_version == mpr.get_database_version()
79+
7380
def test_get_material_id_from_task_id(self, mpr):
7481
assert mpr.get_material_id_from_task_id("mp-540081") == "mp-19017"
7582

@@ -710,7 +717,7 @@ def test_db_warning(self, monkeypatch: pytest.MonkeyPatch):
710717
monkeypatch.setattr(MAPI_CLIENT_SETTINGS, "LOG_FILE", Path(tmp_log.name))
711718

712719
with MPRester(notify_db_version=True) as mpr:
713-
db_version = mpr.get_database_version()
720+
db_version = mpr.db_version
714721

715722
parsed_db_ver = yaml.safe_load(Path(tmp_log.name).read_text()).get(
716723
"MAPI_DB_VERSION"

0 commit comments

Comments
 (0)