diff --git a/.gitignore b/.gitignore index c738b02b..f98e7d76 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ build/* *.egg-info dist/* .vscode + +.venv/ +_version.py diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 296a2947..e7ddc5ce 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from copy import deepcopy from typing import TYPE_CHECKING, Any import numpy as np @@ -8,7 +9,11 @@ from dash import dcc, html from dash.dependencies import Component, Input, Output from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitScene +from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene + +# crystal animation algo +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis.local_env import CrystalNN from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos @@ -17,23 +22,14 @@ from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.dos import CompleteDos -# Author: Jason Munro, Janosh Riebesell -# Contact: jmunro@lbl.gov, janosh@lbl.gov - +DISPLACE_COEF = [0, 1, 0, -1, 0] # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -64,26 +60,32 @@ def __init__( **kwargs, ) + bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) + fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=True, + responsive=False, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(bs) - zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -105,9 +107,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"maxWidth": "200", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"maxWidth": "200", "display": "none"} + ), id=self.id("path-container"), ) @@ -122,9 +126,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"width": "200px", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"width": "200px", "display": "none"} + ), id=self.id("label-container"), ) @@ -138,9 +144,29 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(bs, dos) + summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) + # crystal visualization + + tip = html.P( + "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", + id=self.id("crystal-tip"), + style={ + "margin": "0 0 12px", + "fontSize": "16px", + "color": "#555", + "textAlign": "center", + }, + ) + + crystal_animation = CrystalToolkitAnimationScene( + data={}, + sceneSize="200px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.5}, + ) + return { "graph": graph, "convention": convention, @@ -148,10 +174,15 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts + crystal_animation = Columns( + [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -166,11 +197,134 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]]), + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - return html.Div([graph, controls, brillouin_zone]) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = 225, + ) -> dict: + if not ph_bs or not json_data: + return {} + + assert json_data["contents"][0]["name"] == "atoms" + assert json_data["contents"][1]["name"] == "bonds" + rdata = deepcopy(json_data) + + def calc_max_displacement(idx: int) -> list: + """ + Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. + + Parameters: + idx (int): The atom index. + + Returns: + list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] + + This function extracts the real component of the atom's eigendisplacement, + scales it by the specified magnitude, and returns the resulting vector. + """ + return [ + round(complex(vec).real * magnitude, precision) + for vec in ph_bs.eigendisplacements[band][qpoint][idx] + ] + + def calc_animation_step(max_displacement: list, coef: int) -> list: + """ + Calculate the displacement for an animation frame based on the given coefficient. + + Parameters: + max_displacement (list): A list of maximum displacements along each axis, + formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. + coef (int): A coefficient indicating the motion direction. + - 0: no movement + - 1: forward movement + - -1: backward movement + + Returns: + list: The displacement vector [x_displacement, y_displacement, z_displacement]. + + This function generates oscillatory motion by scaling the maximum displacement + with the provided coefficient. + """ + return [round(coef * md, precision) for md in max_displacement] + + # Compute per-frame atomic motion. + # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. + contents0 = json_data["contents"][0]["contents"] + for cidx, content in enumerate(contents0): + max_displacement = calc_max_displacement(content["_meta"][0]) + rcontent = rdata["contents"][0]["contents"][cidx] + # put animation frame to the given atom index + rcontent["animate"] = [ + calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF + ] + rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) + rcontent["animateType"] = "displacement" + # Compute per-frame bonding motion. + # Explanation: + # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) + # To model the bond motion, it is divided into two segments: + # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) + # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). + # For each cylinder, displacements are assigned to the endpoints — for example, + # the (u)--(mid) cylinder uses: + # [ + # [u_x_displacement, u_y_displacement, u_z_displacement], + # [mid_x_displacement, mid_y_displacement, mid_z_displacement] + # ]. + contents1 = json_data["contents"][1]["contents"] + + for cidx, content in enumerate(contents1): + bond_animation = [] + assert len(content["_meta"]) == len(content["positionPairs"]) + + for atom_idx_pair in content["_meta"]: + max_displacements = list( + map(calc_max_displacement, atom_idx_pair) + ) # max displacement for u and v + + u_to_middle_bond_animation = [] + + for coef in DISPLACE_COEF: + # Calculate the midpoint displacement between atom u and v for each animation frame. + u_displacement, v_displacement = [ + np.array(calc_animation_step(max_displacement, coef)) + for max_displacement in max_displacements + ] + middle_end_displacement = np.add(u_displacement, v_displacement) / 2 + + u_to_middle_bond_animation.append( + [ + u_displacement, # u atom displacement + [ + round(dis, precision) for dis in middle_end_displacement + ], # middle point displacement + ] + ) + + bond_animation.append(u_to_middle_bond_animation) + + rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation + rdata["contents"][1]["contents"][cidx]["keyframes"] = list( + range(len(DISPLACE_COEF)) + ) + rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" + + # remove unused sense + for i in range(2, 4): + rdata["contents"][i]["visible"] = False + + return rdata @staticmethod def _get_ph_bs_dos( @@ -303,6 +457,7 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -348,6 +503,9 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -373,7 +531,7 @@ def _get_data_list_dict( target="blank", ), ] - ): "Yes" if bs.has_nac else "No", + ): ("Yes" if bs.has_nac else "No"), "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", "Min frequency": min_freq_report, @@ -443,14 +601,9 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if (not ph_dos) and (not ph_bs): empty_plot_style = { + "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -459,6 +612,12 @@ def get_figure( return go.Figure(layout=empty_plot_style) + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if ph_bs: ( bs_traces, @@ -555,7 +714,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - # clickmode="event+select" + clickmode="event+select", ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -580,124 +739,25 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Input(self.id("traces"), "data"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), ) - def update_graph(traces): - if traces == "error": - msg_body = MessageBody( - dcc.Markdown( - "Band structure and density of states not available for this selection." - ) - ) - return (MessageContainer([msg_body], kind="warning"),) - - if traces is None: - raise PreventUpdate - - bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) + def update_graph(bs, dos): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - return dcc.Graph( - figure=figure, config={"displayModeBar": False}, responsive=True - ) - @app.callback( - Output(self.id("label-select"), "value"), - Output(self.id("label-container"), "style"), - Input(self.id("mpid"), "data"), - Input(self.id("path-convention"), "value"), - ) - def update_label_select(mpid, path_convention): - if not mpid: - raise PreventUpdate - label_value = path_convention - label_style = {"maxWidth": "200"} + zone_scene = self.get_brillouin_zone_scene(bs) - return label_value, label_style + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) - @app.callback( - Output(self.id("dos-select"), "options"), - Output(self.id("path-convention"), "options"), - Output(self.id("path-container"), "style"), - Input(self.id("elements"), "data"), - Input(self.id("mpid"), "data"), - ) - def update_select(elements, mpid): - if elements is None: - raise PreventUpdate - if not mpid: - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [{"label": "N/A", "value": "sc"}] - path_style = {"maxWidth": "200", "display": "none"} - - return dos_options, path_options, path_style - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [ - {"label": "Setyawan-Curtarolo", "value": "sc"}, - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, - ] - - path_style = {"maxWidth": "200"} - - return dos_options, path_options, path_style - - @app.callback( - Output(self.id("traces"), "data"), - Output(self.id("elements"), "data"), - Input(self.id(), "data"), - Input(self.id("path-convention"), "value"), - Input(self.id("dos-select"), "value"), - Input(self.id("label-select"), "value"), - ) - def bs_dos_data(data, dos_select, label_select): - # Obtain bands to plot over and generate traces for bs data: - energy_window = (-6.0, 10.0) - - traces = [] - - bsml, density_of_states = self._get_ph_bs_dos(data) - - if self.bandstructure_symm_line: - bs_traces = self.get_ph_bandstructure_traces( - bsml, freq_range=energy_window - ) - traces.append(bs_traces) - - if self.density_of_states: - dos_traces = self.get_ph_dos_traces( - density_of_states, freq_range=energy_window - ) - traces.append(dos_traces) - - # traces = [bs_traces, dos_traces, bs_data] - - # TODO: not tested if this is correct way to get element list - elements = list(map(str, density_of_states.get_element_dos())) - - return traces, elements + return figure, zone_scene.to_json(), summary_table @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -711,8 +771,42 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + # prevent_initial_call=True + ) + def update_crystal_animation(cd, bs): + if not bs: + raise PreventUpdate + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struc_graph = StructureGraph.from_local_env_strategy( + bs.structure, CrystalNN() + ) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + ) class PhononBandstructureAndDosPanelComponent(PanelComponent):