Skip to content

Commit 11391f0

Browse files
load projected dos / bs
1 parent 9181df7 commit 11391f0

1 file changed

Lines changed: 48 additions & 3 deletions

File tree

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def get_bandstructure_from_task_id(
264264
task_id: str,
265265
run_type: str | RunType | None = None,
266266
path_type: str | BSPathType | None = None,
267+
load_projections: bool = False,
267268
) -> BandStructure:
268269
"""Get the band structure pymatgen object associated with a given task ID.
269270
@@ -273,6 +274,9 @@ def get_bandstructure_from_task_id(
273274
will speed up query due to delta table partitioning.
274275
path_type (str, BSPathType, or None) : Optional path type to
275276
speed up query
277+
load_projections (bool) : Optionally load atom- and spin-projected
278+
bandstructure, if available.
279+
276280
Returns:
277281
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
278282
"""
@@ -296,6 +300,20 @@ def get_bandstructure_from_task_id(
296300

297301
table = self._query_delta_single(query)
298302
if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0:
303+
if load_projections:
304+
proj_bs_label, _ = self._get_delta_table(
305+
"materialsproject-parsed",
306+
"core/electronic-structure/projected-bandstructures/",
307+
label="bandstructure_projections",
308+
)
309+
proj_table = self._query_delta_single(
310+
query.replace(bs_lbl, proj_bs_label)
311+
)
312+
if (
313+
len(deser_proj := proj_table.to_pylist(maps_as_pydicts="strict"))
314+
> 0
315+
):
316+
deser[0]["projections"] = deser_proj[0]
299317
emmet_bs = ElectronicBS(**deser[0])
300318
return emmet_bs.to_pmg(
301319
pmg_cls=BandStructureSymmLine if emmet_bs.labels_dict else BandStructure
@@ -310,13 +328,16 @@ def get_bandstructure_from_material_id(
310328
material_id: str,
311329
path_type: BSPathType = BSPathType.setyawan_curtarolo,
312330
line_mode=True,
331+
load_projections: bool = False,
313332
):
314333
"""Get the band structure pymatgen object associated with a Materials Project ID.
315334
316335
Arguments:
317336
material_id (str): Materials Project ID for a material
318337
path_type (BSPathType): k-point path selection convention
319338
line_mode (bool): Whether to return data for a line-mode calculation
339+
load_projections (bool) : Optionally load atom- and spin-projected
340+
bandstructure, if available.
320341
321342
Returns:
322343
bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
@@ -371,6 +392,7 @@ def get_bandstructure_from_material_id(
371392
bs_obj = self.get_bandstructure_from_task_id(
372393
bs_task_id,
373394
path_type=path_type if line_mode else BSPathType.unknown,
395+
load_projections=load_projections,
374396
)
375397

376398
if bs_obj:
@@ -495,14 +517,19 @@ def search(
495517
)
496518

497519
def get_dos_from_task_id(
498-
self, task_id: str, run_type: str | RunType | None = None
520+
self,
521+
task_id: str,
522+
run_type: str | RunType | None = None,
523+
load_projections: bool = False,
499524
) -> Dos:
500525
"""Get the density of states pymatgen object associated with a given calculation ID.
501526
502527
Arguments:
503528
task_id (str): Task ID for the density of states calculation
504529
run_type (str, RunType, or None): Optional run type to query by.
505530
Will speed up query due to delta table partitioning.
531+
load_projections (bool) : Optionally load atom- and spin-projected
532+
bandstructure, if available.
506533
507534
Returns:
508535
pymatgen Dos
@@ -525,16 +552,34 @@ def get_dos_from_task_id(
525552

526553
table = self._query_delta_single(query)
527554
if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0:
555+
if load_projections:
556+
proj_dos_label, _ = self._get_delta_table(
557+
"materialsproject-parsed",
558+
"core/electronic-structure/projected-dos/",
559+
label="dos_projections",
560+
)
561+
proj_table = self._query_delta_single(
562+
query.replace(dos_lbl, proj_dos_label)
563+
)
564+
if (
565+
len(deser_proj := proj_table.to_pylist(maps_as_pydicts="strict"))
566+
> 0
567+
):
568+
deser[0]["projected_densities"] = deser_proj[0]
528569
return ElectronicDos(**deser[0]).to_pmg()
529570
raise MPRestError(
530571
f"No DOS data found for {task_id=}" + (f"run_type={rt}" if run_type else "")
531572
)
532573

533-
def get_dos_from_material_id(self, material_id: str) -> Dos:
574+
def get_dos_from_material_id(
575+
self, material_id: str, load_projections: bool = False
576+
) -> Dos:
534577
"""Get the complete density of states pymatgen object associated with a Materials Project ID.
535578
536579
Arguments:
537580
material_id (str): Materials Project ID for a material
581+
load_projections (bool) : Optionally load atom- and spin-projected
582+
bandstructure, if available.
538583
539584
Returns:
540585
pymatgen Dos
@@ -552,4 +597,4 @@ def get_dos_from_material_id(self, material_id: str) -> Dos:
552597
dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[
553598
"task_id"
554599
]
555-
return self.get_dos_from_task_id(dos_task_id)
600+
return self.get_dos_from_task_id(dos_task_id, load_projections=load_projections)

0 commit comments

Comments
 (0)