1515# crystal animation algo
1616from pymatgen .analysis .graphs import StructureGraph
1717from pymatgen .analysis .local_env import CrystalNN
18+ from pymatgen .core import Species
1819from pymatgen .ext .matproj import MPRester
1920from pymatgen .phonon .bandstructure import PhononBandStructureSymmLine
2021from pymatgen .phonon .dos import CompletePhononDos
2122from pymatgen .phonon .plotter import PhononBSPlotter
2223from pymatgen .transformations .standard_transformations import SupercellTransformation
2324
25+ from crystal_toolkit .core .legend import Legend
2426from crystal_toolkit .core .mpcomponent import MPComponent
2527from crystal_toolkit .core .panelcomponent import PanelComponent
2628from crystal_toolkit .core .scene import Convex , Cylinders , Lines , Scene , Spheres
3840MAX_MAGNITUDE = 500
3941MIN_MAGNITUDE = 0
4042
43+ DEFAULTS : dict [str , str | bool ] = {
44+ "color_scheme" : "VESTA" ,
45+ }
46+
4147
4248# TODOs:
4349# - look for additional projection methods in phonon DOS (currently only atom
@@ -194,6 +200,8 @@ def _sub_layouts(self) -> dict[str, Component]:
194200 hr = html .Hr (
195201 style = {
196202 "backgroundColor" : "#C5C5C6" ,
203+ "border" : "none" ,
204+ "margin" : "8px 0" ,
197205 }
198206 )
199207
@@ -298,9 +306,14 @@ def _get_animation_panel(self):
298306 [
299307 sub_layouts ["crystal-animation" ],
300308 sub_layouts ["crystal-animation-controls" ],
301- ]
309+ ],
310+ style = {
311+ "display" : "flex" ,
312+ "justify-content" : "center" ,
313+ "gap" : "10px" ,
314+ },
302315 ),
303- ]
316+ ],
304317 ),
305318 ],
306319 )
@@ -844,6 +857,44 @@ def get_figure(
844857
845858 return figure
846859
860+ def _make_legend (self , legend ):
861+ # this is copied and customized from crystal_toolkit.components.structure.StructureMoleculeComponent
862+ # in order to get the consistent legend with the structure viewer
863+ if not legend :
864+ return html .Div (id = self .id ("legend" ))
865+
866+ def get_font_color (hex_code ):
867+ # ensures contrasting font color for background color
868+ c = tuple (int (hex_code [1 :][i : i + 2 ], 16 ) for i in (0 , 2 , 4 ))
869+ return (
870+ "black"
871+ if 1 - (c [0 ] * 0.299 + c [1 ] * 0.587 + c [2 ] * 0.114 ) / 255 < 0.5
872+ else "white"
873+ )
874+
875+ legend_colors = {
876+ key : self ._legend .get_color (Species (key ))
877+ for key , val in legend ["composition" ].items ()
878+ }
879+
880+ legend_elements = [
881+ html .Span (
882+ html .Span (
883+ name , className = "icon" , style = {"color" : get_font_color (color )}
884+ ),
885+ className = "button is-static is-rounded" ,
886+ style = {"backgroundColor" : color },
887+ )
888+ for name , color in legend_colors .items ()
889+ ]
890+
891+ return html .Div (
892+ legend_elements ,
893+ id = self .id ("legend" ),
894+ style = {"display" : "flex" },
895+ className = "buttons" ,
896+ )
897+
847898 def generate_callbacks (self , app , cache ) -> None :
848899 @app .callback (
849900 Output (self .id ("ph-bsdos-graph" ), "figure" ),
@@ -914,6 +965,7 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
914965
915966 @app .callback (
916967 Output (self .id ("crystal-animation" ), "data" ),
968+ Output (self .id ("crystal-animation" ), "children" ),
917969 Input (self .id ("ph-bsdos-graph" ), "clickData" ),
918970 Input (self .id ("ph_bs" ), "data" ),
919971 Input (self .id ("supercell-controls-btn" ), "n_clicks" ),
@@ -955,6 +1007,7 @@ def update_crystal_animation(
9551007
9561008 struct = bs .structure
9571009 total_repeat_cell_cnt = 1
1010+
9581011 # update structure if the controls got triggered
9591012 if sueprcell_update :
9601013 total_repeat_cell_cnt = scale_x * scale_y * scale_z
@@ -964,15 +1017,36 @@ def update_crystal_animation(
9641017 ((scale_x , 0 , 0 ), (0 , scale_y , 0 ), (0 , 0 , scale_z ))
9651018 )
9661019 struct = trans .apply_transformation (struct )
1020+
9671021 struc_graph = StructureGraph .from_local_env_strategy (struct , CrystalNN ())
1022+
1023+ # legend
1024+ legend = Legend (
1025+ struc_graph .structure ,
1026+ color_scheme = DEFAULTS ["color_scheme" ],
1027+ # radius_scheme=radius_strategy,
1028+ cmap_range = None ,
1029+ )
1030+ self ._legend = legend
1031+ legend_layout = html .Div (self ._make_legend (legend .get_legend ()))
1032+
1033+ # scene
9681034 scene = struc_graph .get_scene (
9691035 draw_image_atoms = False ,
9701036 bonded_sites_outside_unit_cell = False ,
9711037 site_get_scene_kwargs = {
9721038 "retain_atom_idx" : True ,
9731039 "total_repeat_cell_cnt" : total_repeat_cell_cnt ,
9741040 },
1041+ legend = legend ,
9751042 )
1043+
1044+ # axis
1045+ axes = struct .lattice ._axes_from_lattice ()
1046+ axes .visible = True
1047+ scene .contents .append (axes )
1048+
1049+ #
9761050 json_data = scene .to_json ()
9771051
9781052 qpoint = 0
@@ -995,7 +1069,7 @@ def update_crystal_animation(
9951069 total_repeat_cell_cnt = total_repeat_cell_cnt ,
9961070 magnitude = magnitude ,
9971071 velocity = velocity ,
998- )
1072+ ), [ None , legend_layout ]
9991073
10001074
10011075class PhononBandstructureAndDosPanelComponent (PanelComponent ):
0 commit comments