Skip to content

Commit 581f5fe

Browse files
committed
add legend, axis, and consistent color scheme
1 parent 0627e57 commit 581f5fe

1 file changed

Lines changed: 77 additions & 3 deletions

File tree

crystal_toolkit/components/phonon.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
# crystal animation algo
1616
from pymatgen.analysis.graphs import StructureGraph
1717
from pymatgen.analysis.local_env import CrystalNN
18+
from pymatgen.core import Species
1819
from pymatgen.ext.matproj import MPRester
1920
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
2021
from pymatgen.phonon.dos import CompletePhononDos
2122
from pymatgen.phonon.plotter import PhononBSPlotter
2223
from pymatgen.transformations.standard_transformations import SupercellTransformation
2324

25+
from crystal_toolkit.core.legend import Legend
2426
from crystal_toolkit.core.mpcomponent import MPComponent
2527
from crystal_toolkit.core.panelcomponent import PanelComponent
2628
from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres
@@ -38,6 +40,10 @@
3840
MAX_MAGNITUDE = 500
3941
MIN_MAGNITUDE = 0
4042

43+
DEFAULTS: dict[str, str | bool] = {
44+
"color_scheme": "VESTA",
45+
}
46+
4147

4248
# TODOs:
4349
# - look for additional projection methods in phonon DOS (currently only atom
@@ -194,6 +200,8 @@ def _sub_layouts(self) -> dict[str, Component]:
194200
hr = html.Hr(
195201
style={
196202
"backgroundColor": "#C5C5C6",
203+
"border": "none",
204+
"margin": "8px 0",
197205
}
198206
)
199207

@@ -298,9 +306,14 @@ def _get_animation_panel(self):
298306
[
299307
sub_layouts["crystal-animation"],
300308
sub_layouts["crystal-animation-controls"],
301-
]
309+
],
310+
style={
311+
"display": "flex",
312+
"justify-content": "center",
313+
"gap": "10px",
314+
},
302315
),
303-
]
316+
],
304317
),
305318
],
306319
)
@@ -844,6 +857,44 @@ def get_figure(
844857

845858
return figure
846859

860+
def _make_legend(self, legend):
861+
# this is copied and customized from crystal_toolkit.components.structure.StructureMoleculeComponent
862+
# in order to get the consistent legend with the structure viewer
863+
if not legend:
864+
return html.Div(id=self.id("legend"))
865+
866+
def get_font_color(hex_code):
867+
# ensures contrasting font color for background color
868+
c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4))
869+
return (
870+
"black"
871+
if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5
872+
else "white"
873+
)
874+
875+
legend_colors = {
876+
key: self._legend.get_color(Species(key))
877+
for key, val in legend["composition"].items()
878+
}
879+
880+
legend_elements = [
881+
html.Span(
882+
html.Span(
883+
name, className="icon", style={"color": get_font_color(color)}
884+
),
885+
className="button is-static is-rounded",
886+
style={"backgroundColor": color},
887+
)
888+
for name, color in legend_colors.items()
889+
]
890+
891+
return html.Div(
892+
legend_elements,
893+
id=self.id("legend"),
894+
style={"display": "flex"},
895+
className="buttons",
896+
)
897+
847898
def generate_callbacks(self, app, cache) -> None:
848899
@app.callback(
849900
Output(self.id("ph-bsdos-graph"), "figure"),
@@ -914,6 +965,7 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
914965

915966
@app.callback(
916967
Output(self.id("crystal-animation"), "data"),
968+
Output(self.id("crystal-animation"), "children"),
917969
Input(self.id("ph-bsdos-graph"), "clickData"),
918970
Input(self.id("ph_bs"), "data"),
919971
Input(self.id("supercell-controls-btn"), "n_clicks"),
@@ -955,6 +1007,7 @@ def update_crystal_animation(
9551007

9561008
struct = bs.structure
9571009
total_repeat_cell_cnt = 1
1010+
9581011
# update structure if the controls got triggered
9591012
if sueprcell_update:
9601013
total_repeat_cell_cnt = scale_x * scale_y * scale_z
@@ -964,15 +1017,36 @@ def update_crystal_animation(
9641017
((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
9651018
)
9661019
struct = trans.apply_transformation(struct)
1020+
9671021
struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN())
1022+
1023+
# legend
1024+
legend = Legend(
1025+
struc_graph.structure,
1026+
color_scheme=DEFAULTS["color_scheme"],
1027+
# radius_scheme=radius_strategy,
1028+
cmap_range=None,
1029+
)
1030+
self._legend = legend
1031+
legend_layout = html.Div(self._make_legend(legend.get_legend()))
1032+
1033+
# scene
9681034
scene = struc_graph.get_scene(
9691035
draw_image_atoms=False,
9701036
bonded_sites_outside_unit_cell=False,
9711037
site_get_scene_kwargs={
9721038
"retain_atom_idx": True,
9731039
"total_repeat_cell_cnt": total_repeat_cell_cnt,
9741040
},
1041+
legend=legend,
9751042
)
1043+
1044+
# axis
1045+
axes = struct.lattice._axes_from_lattice()
1046+
axes.visible = True
1047+
scene.contents.append(axes)
1048+
1049+
#
9761050
json_data = scene.to_json()
9771051

9781052
qpoint = 0
@@ -995,7 +1069,7 @@ def update_crystal_animation(
9951069
total_repeat_cell_cnt=total_repeat_cell_cnt,
9961070
magnitude=magnitude,
9971071
velocity=velocity,
998-
)
1072+
), [None, legend_layout]
9991073

10001074

10011075
class PhononBandstructureAndDosPanelComponent(PanelComponent):

0 commit comments

Comments
 (0)