@@ -240,7 +240,7 @@ def _sub_layouts(self) -> dict[str, Component]:
240240 label = "x" ,
241241 min = 1 ,
242242 max = 10 ,
243- style = {"height" : "15px " },
243+ style = {"height" : "10px " },
244244 label_style = {"textAlign" : "center" },
245245 ),
246246 self .get_numerical_input (
@@ -251,7 +251,7 @@ def _sub_layouts(self) -> dict[str, Component]:
251251 label = "y" ,
252252 min = 1 ,
253253 max = 10 ,
254- style = {"height" : "15px " },
254+ style = {"height" : "10px " },
255255 label_style = {"textAlign" : "center" },
256256 ),
257257 self .get_numerical_input (
@@ -262,14 +262,13 @@ def _sub_layouts(self) -> dict[str, Component]:
262262 label = "z" ,
263263 min = 1 ,
264264 max = 10 ,
265- style = {"height" : "15px " },
265+ style = {"height" : "10px " },
266266 label_style = {"textAlign" : "center" },
267267 ),
268268 ],
269269 style = {
270270 "display" : "flex" ,
271271 "justify-content" : "center" ,
272- "gap" : "16px" ,
273272 },
274273 ),
275274 hr ,
@@ -281,12 +280,14 @@ def _sub_layouts(self) -> dict[str, Component]:
281280 domain = [0 , 1 ],
282281 label = "Vibration magnitude" ,
283282 styleInput = {
284- "height" : "32px " ,
283+ "height" : "28px " ,
285284 "box-sizing" : "border-box" ,
286285 "borderRadius" : "4px" ,
287286 "width" : "5rem" ,
287+ "margin" : "0" ,
288288 },
289- label_style = {"textAlign" : "center" },
289+ styleSlider = {"margin" : "0" },
290+ label_style = {"textAlign" : "center" , "margin" : "0" },
290291 ),
291292 ),
292293 hr ,
@@ -298,15 +299,52 @@ def _sub_layouts(self) -> dict[str, Component]:
298299 domain = [0 , 1 ],
299300 label = "Velocity" ,
300301 styleInput = {
301- "height" : "32px " ,
302+ "height" : "28px " ,
302303 "box-sizing" : "border-box" ,
303304 "borderRadius" : "4px" ,
304305 "width" : "5rem" ,
306+ "margin" : "0" ,
305307 },
306- label_style = {"textAlign" : "center" },
308+ styleSlider = {"margin" : "0" },
309+ label_style = {"textAlign" : "center" , "margin" : "0" },
307310 ),
308311 ),
309312 hr ,
313+ html .Div (
314+ [
315+ html .H6 (
316+ "Color scheme" ,
317+ style = {
318+ "textAlign" : "center" ,
319+ },
320+ ),
321+ dcc .Dropdown (
322+ id = self .id ("color-scheme" ),
323+ value = "VESTA" ,
324+ options = [
325+ {
326+ "label" : "VESTA" ,
327+ "value" : "VESTA" ,
328+ },
329+ {
330+ "label" : "Jmol" ,
331+ "value" : "Jmol" ,
332+ },
333+ ],
334+ style = {
335+ "width" : "10rem" ,
336+ "height" : "30px" ,
337+ "fontSize" : "12px" ,
338+ "display" : "inline-block" ,
339+ },
340+ ),
341+ ],
342+ style = {
343+ "textAlign" : "center" ,
344+ "marginBottom" : "0" ,
345+ },
346+ ),
347+ hr ,
310348 html .Div (
311349 html .Button (
312350 "Update" ,
@@ -320,7 +358,8 @@ def _sub_layouts(self) -> dict[str, Component]:
320358 "width" : "100%" ,
321359 },
322360 ),
323- ]
361+ ],
362+ open = True ,
324363 )
325364
326365 return {
@@ -345,13 +384,25 @@ def _get_animation_panel(self):
345384 html .Br (),
346385 Columns (
347386 [
348- sub_layouts ["crystal-animation" ],
349- sub_layouts ["crystal-animation-controls" ],
387+ html .Div (
388+ sub_layouts ["crystal-animation" ],
389+ style = {
390+ "display" : "flex" ,
391+ "justify-content" : "center" ,
392+ },
393+ ),
394+ html .Div (
395+ sub_layouts ["crystal-animation-controls" ],
396+ style = {
397+ "display" : "flex" ,
398+ "justify-content" : "flex-end" ,
399+ "paddingRight" : "5%" ,
400+ },
401+ ),
350402 ],
351403 style = {
352404 "display" : "flex" ,
353405 "justify-content" : "center" ,
354- "gap" : "10px" ,
355406 },
356407 ),
357408 ],
@@ -444,35 +495,66 @@ def _get_time_function_json(
444495 rdata ["app" ] = "phonon"
445496
446497 # omega (ω)
447- rdata ["omega" ] = ph_bs .frequencies [band ][qpoint ]
498+ rdata ["omega" ] = ph_bs .frequencies [band ][
499+ qpoint
500+ ] # * 2 * np.pi # should include 2pi, but omitted here to achieve a better visualization
448501
449- # Take mp-149 as an example:
502+ # The spatial dependence has been simplified:
503+ # The real calculation should be:
504+ # 1.
450505 # ph_bs.qpoints is "frac_coords of the given lattice by default (from Pymatgen)"
451506 # transfer from frac_coords to cart_coords
452507 # the size of ph_bs.structure.lattice.matrix: (3, 3) (lattice size)
453508 # the size of ph_bs.qpoints: (149, 3) (wave vector for each qpoint)
454509 # the size of q: (149, 3)
455- # q:
456- q = np .einsum (
457- "ij,kj->ik" ,
458- ph_bs .structure .lattice .reciprocal_lattice .matrix ,
459- np .array (ph_bs .qpoints ),
460- ).T
461-
462- # phases (q⋅R): should be a number
463- # we calculate the phase with all atoms and qpoints here
464- # the size of q: (149, 3)
510+ # q = np.einsum(
511+ # "ij,kj->ik",
512+ # ph_bs.structure.lattice.reciprocal_lattice.matrix,
513+ # np.array(ph_bs.qpoints),
514+ # ).T
515+ #
516+ # 2.
465517 # the size of ph_bs.structure.cart_coords: (2, 3) (the coordinate of two atoms in the unit cell)
466- # the size of phase: (149, 2)
467- phases = np .einsum (
468- "ij,kj->ik" ,
469- q ,
470- ph_bs .structure .cart_coords ,
518+ # R = ph_bs.structure.cart_coords,
519+ #
520+ # 3.
521+ # phases = np.einsum(
522+ # "ij,kj->ik",
523+ # q,
524+ # R,
525+ # )
526+
527+ # Simplified:
528+ # q is fractional (reduced) coordinates:
529+ # q = q1*b1 + q2*b2 + q3*b3
530+ #
531+ # R is a lattice translation written in direct lattice coordinates:
532+ # R = n1*a1 + n2*a2 + n3*a3
533+ #
534+ # Reciprocal/direct basis satisfy:
535+ # ai · bj = 2π δij (δij = 1 if i==j else 0)
536+ #
537+ # Therefore the phase is:
538+ # q · R = 2π (q1*n1 + q2*n2 + q3*n3)
539+ #
540+ # compute:
541+ # phases = 2π * dot(q_frac, R_frac)
542+
543+ phases = (
544+ np .einsum (
545+ "ij,kj->ik" ,
546+ np .array (ph_bs .qpoints ),
547+ ph_bs .structure .frac_coords ,
548+ )
549+ * 2
550+ * np .pi
471551 )
472552 rdata ["phases" ] = phases [qpoint ].tolist ()
473553
474554 # amplitude (A)
475- rdata ["amplitude" ] = magnitude
555+ rdata ["amplitude" ] = 1 / np .linalg .norm (
556+ ph_bs .eigendisplacements [0 ][0 ]
557+ ) # magnitude
476558
477559 # eigenVectors
478560 rdata ["eigenVectors" ] = (
@@ -560,6 +642,7 @@ def get_ph_bandstructure_traces(bs, freq_range):
560642
561643 bs_traces = []
562644
645+ last_di = 0
563646 for d , dist_val in enumerate (bs_data ["distances" ]):
564647 x_dat = dist_val
565648
@@ -575,7 +658,9 @@ def get_ph_bandstructure_traces(bs, freq_range):
575658 "line" : {"color" : "#1f77b4" },
576659 "hoverinfo" : "skip" ,
577660 "name" : "Total" ,
578- "customdata" : [[di , band_num ] for di in range (len (x_dat ))],
661+ "customdata" : [
662+ [di + last_di , band_num ] for di in range (len (x_dat ))
663+ ],
579664 "hovertemplate" : "%{y:.2f} THz" ,
580665 "showlegend" : False ,
581666 "xaxis" : "x" ,
@@ -585,6 +670,7 @@ def get_ph_bandstructure_traces(bs, freq_range):
585670 ]
586671
587672 bs_traces += traces_for_segment
673+ last_di += len (x_dat )
588674
589675 for entry_num in range (len (bs_data ["ticks" ]["label" ])):
590676 for key in pretty_labels :
@@ -973,6 +1059,7 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
9731059 State (self .get_kwarg_id ("scale-y" ), "value" ),
9741060 State (self .get_kwarg_id ("scale-z" ), "value" ),
9751061 State (self .get_kwarg_id ("velocity" ), "value" ),
1062+ State (self .id ("color-scheme" ), "value" ),
9761063 # prevent_initial_call=True
9771064 )
9781065 def update_crystal_animation (
@@ -984,6 +1071,7 @@ def update_crystal_animation(
9841071 scale_y ,
9851072 scale_z ,
9861073 velocity ,
1074+ color_scheme ,
9871075 ):
9881076 # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
9891077 # ensuring updates occur only after the `supercell-controls-btn`` is clicked.
@@ -999,6 +1087,7 @@ def update_crystal_animation(
9991087 scale_y = kwargs .get ("scale-y" )
10001088 scale_z = kwargs .get ("scale-z" )
10011089 velocity = kwargs .get ("velocity" )
1090+ # color_scheme = kwargs.get("color-scheme")
10021091
10031092 if isinstance (bs , dict ):
10041093 bs = PhononBS .from_pmg (bs )
@@ -1025,7 +1114,7 @@ def update_crystal_animation(
10251114 # legend
10261115 legend = Legend (
10271116 struc_graph .structure ,
1028- color_scheme = DEFAULTS [ " color_scheme" ] ,
1117+ color_scheme = color_scheme ,
10291118 # radius_scheme=radius_strategy,
10301119 cmap_range = None ,
10311120 )
@@ -1056,7 +1145,9 @@ def update_crystal_animation(
10561145
10571146 if cd and cd .get ("points" ):
10581147 pt = cd ["points" ][0 ]
1059- qpoint , band_num = pt .get ("customdata" , [0 , 0 ])
1148+ qpoint , band_num = pt .get ("customdata" , [- 1 , - 1 ])
1149+ if qpoint == - 1 or band_num == - 1 :
1150+ raise ValueError ("qpoint and band_num are invalid" )
10601151
10611152 # magnitude
10621153 magnitude = (
0 commit comments