Skip to content

Commit c1cb1f0

Browse files
no copy-paste errors to be had
1 parent 11391f0 commit c1cb1f0

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,22 +326,25 @@ def get_bandstructure_from_task_id(
326326
def get_bandstructure_from_material_id(
327327
self,
328328
material_id: str,
329-
path_type: BSPathType = BSPathType.setyawan_curtarolo,
329+
path_type: str | BSPathType = BSPathType.setyawan_curtarolo,
330330
line_mode=True,
331331
load_projections: bool = False,
332332
):
333333
"""Get the band structure pymatgen object associated with a Materials Project ID.
334334
335335
Arguments:
336336
material_id (str): Materials Project ID for a material
337-
path_type (BSPathType): k-point path selection convention
337+
path_type (BSPathType or its value as a str): k-point path selection convention
338338
line_mode (bool): Whether to return data for a line-mode calculation
339339
load_projections (bool) : Optionally load atom- and spin-projected
340340
bandstructure, if available.
341341
342342
Returns:
343343
bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
344344
"""
345+
pt: BSPathType = (
346+
BSPathType(path_type) if isinstance(path_type, str) else path_type
347+
)
345348
if line_mode:
346349
bs_doc = self.es_rester.search(
347350
material_ids=material_id, fields=["bandstructure"]
@@ -353,18 +356,18 @@ def get_bandstructure_from_material_id(
353356

354357
if (_bs_data := bs_doc[0]["bandstructure"]) is None:
355358
raise MPRestError(
356-
f"No {path_type.value} band structure data found for {material_id}"
359+
f"No {pt.value} band structure data found for {material_id}"
357360
)
358361

359362
bs_data = (
360363
_bs_data.model_dump() if self.use_document_model else _bs_data # type: ignore
361364
)
362365

363-
if bs_data.get(path_type.value, None) is None:
366+
if bs_data.get(pt.value, None) is None:
364367
raise MPRestError(
365-
f"No {path_type.value} band structure data found for {material_id}"
368+
f"No {pt.value} band structure data found for {material_id}"
366369
)
367-
bs_task_id = bs_data[path_type.value]["task_id"]
370+
bs_task_id = bs_data[pt.value]["task_id"]
368371

369372
else:
370373
if not (
@@ -391,7 +394,7 @@ def get_bandstructure_from_material_id(
391394

392395
bs_obj = self.get_bandstructure_from_task_id(
393396
bs_task_id,
394-
path_type=path_type if line_mode else BSPathType.unknown,
397+
path_type=pt if line_mode else BSPathType.unknown,
395398
load_projections=load_projections,
396399
)
397400

@@ -529,7 +532,7 @@ def get_dos_from_task_id(
529532
run_type (str, RunType, or None): Optional run type to query by.
530533
Will speed up query due to delta table partitioning.
531534
load_projections (bool) : Optionally load atom- and spin-projected
532-
bandstructure, if available.
535+
DOS, if available.
533536
534537
Returns:
535538
pymatgen Dos
@@ -579,7 +582,7 @@ def get_dos_from_material_id(
579582
Arguments:
580583
material_id (str): Materials Project ID for a material
581584
load_projections (bool) : Optionally load atom- and spin-projected
582-
bandstructure, if available.
585+
DOS, if available.
583586
584587
Returns:
585588
pymatgen Dos

0 commit comments

Comments
 (0)