Skip to content

Commit 0dfa50b

Browse files
authored
Merge pull request #514 from materialsproject/phonon-update
fix click point index
2 parents 3824816 + 7ae1fe4 commit 0dfa50b

1 file changed

Lines changed: 124 additions & 33 deletions

File tree

crystal_toolkit/components/phonon.py

Lines changed: 124 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _sub_layouts(self) -> dict[str, Component]:
240240
label="x",
241241
min=1,
242242
max=10,
243-
style={"height": "15px"},
243+
style={"height": "10px"},
244244
label_style={"textAlign": "center"},
245245
),
246246
self.get_numerical_input(
@@ -251,7 +251,7 @@ def _sub_layouts(self) -> dict[str, Component]:
251251
label="y",
252252
min=1,
253253
max=10,
254-
style={"height": "15px"},
254+
style={"height": "10px"},
255255
label_style={"textAlign": "center"},
256256
),
257257
self.get_numerical_input(
@@ -262,14 +262,13 @@ def _sub_layouts(self) -> dict[str, Component]:
262262
label="z",
263263
min=1,
264264
max=10,
265-
style={"height": "15px"},
265+
style={"height": "10px"},
266266
label_style={"textAlign": "center"},
267267
),
268268
],
269269
style={
270270
"display": "flex",
271271
"justify-content": "center",
272-
"gap": "16px",
273272
},
274273
),
275274
hr,
@@ -281,12 +280,14 @@ def _sub_layouts(self) -> dict[str, Component]:
281280
domain=[0, 1],
282281
label="Vibration magnitude",
283282
styleInput={
284-
"height": "32px",
283+
"height": "28px",
285284
"box-sizing": "border-box",
286285
"borderRadius": "4px",
287286
"width": "5rem",
287+
"margin": "0",
288288
},
289-
label_style={"textAlign": "center"},
289+
styleSlider={"margin": "0"},
290+
label_style={"textAlign": "center", "margin": "0"},
290291
),
291292
),
292293
hr,
@@ -298,15 +299,52 @@ def _sub_layouts(self) -> dict[str, Component]:
298299
domain=[0, 1],
299300
label="Velocity",
300301
styleInput={
301-
"height": "32px",
302+
"height": "28px",
302303
"box-sizing": "border-box",
303304
"borderRadius": "4px",
304305
"width": "5rem",
306+
"margin": "0",
305307
},
306-
label_style={"textAlign": "center"},
308+
styleSlider={"margin": "0"},
309+
label_style={"textAlign": "center", "margin": "0"},
307310
),
308311
),
309312
hr,
313+
html.Div(
314+
[
315+
html.H6(
316+
"Color scheme",
317+
style={
318+
"textAlign": "center",
319+
},
320+
),
321+
dcc.Dropdown(
322+
id=self.id("color-scheme"),
323+
value="VESTA",
324+
options=[
325+
{
326+
"label": "VESTA",
327+
"value": "VESTA",
328+
},
329+
{
330+
"label": "Jmol",
331+
"value": "Jmol",
332+
},
333+
],
334+
style={
335+
"width": "10rem",
336+
"height": "30px",
337+
"fontSize": "12px",
338+
"display": "inline-block",
339+
},
340+
),
341+
],
342+
style={
343+
"textAlign": "center",
344+
"marginBottom": "0",
345+
},
346+
),
347+
hr,
310348
html.Div(
311349
html.Button(
312350
"Update",
@@ -320,7 +358,8 @@ def _sub_layouts(self) -> dict[str, Component]:
320358
"width": "100%",
321359
},
322360
),
323-
]
361+
],
362+
open=True,
324363
)
325364

326365
return {
@@ -345,13 +384,25 @@ def _get_animation_panel(self):
345384
html.Br(),
346385
Columns(
347386
[
348-
sub_layouts["crystal-animation"],
349-
sub_layouts["crystal-animation-controls"],
387+
html.Div(
388+
sub_layouts["crystal-animation"],
389+
style={
390+
"display": "flex",
391+
"justify-content": "center",
392+
},
393+
),
394+
html.Div(
395+
sub_layouts["crystal-animation-controls"],
396+
style={
397+
"display": "flex",
398+
"justify-content": "flex-end",
399+
"paddingRight": "5%",
400+
},
401+
),
350402
],
351403
style={
352404
"display": "flex",
353405
"justify-content": "center",
354-
"gap": "10px",
355406
},
356407
),
357408
],
@@ -444,35 +495,66 @@ def _get_time_function_json(
444495
rdata["app"] = "phonon"
445496

446497
# omega (ω)
447-
rdata["omega"] = ph_bs.frequencies[band][qpoint]
498+
rdata["omega"] = ph_bs.frequencies[band][
499+
qpoint
500+
] # * 2 * np.pi # should include 2pi, but omitted here to achieve a better visualization
448501

449-
# Take mp-149 as an example:
502+
# The spatial dependence has been simplified:
503+
# The real calculation should be:
504+
# 1.
450505
# ph_bs.qpoints is "frac_coords of the given lattice by default (from Pymatgen)"
451506
# transfer from frac_coords to cart_coords
452507
# the size of ph_bs.structure.lattice.matrix: (3, 3) (lattice size)
453508
# the size of ph_bs.qpoints: (149, 3) (wave vector for each qpoint)
454509
# the size of q: (149, 3)
455-
# q:
456-
q = np.einsum(
457-
"ij,kj->ik",
458-
ph_bs.structure.lattice.reciprocal_lattice.matrix,
459-
np.array(ph_bs.qpoints),
460-
).T
461-
462-
# phases (q⋅R): should be a number
463-
# we calculate the phase with all atoms and qpoints here
464-
# the size of q: (149, 3)
510+
# q = np.einsum(
511+
# "ij,kj->ik",
512+
# ph_bs.structure.lattice.reciprocal_lattice.matrix,
513+
# np.array(ph_bs.qpoints),
514+
# ).T
515+
#
516+
# 2.
465517
# the size of ph_bs.structure.cart_coords: (2, 3) (the coordinate of two atoms in the unit cell)
466-
# the size of phase: (149, 2)
467-
phases = np.einsum(
468-
"ij,kj->ik",
469-
q,
470-
ph_bs.structure.cart_coords,
518+
# R = ph_bs.structure.cart_coords,
519+
#
520+
# 3.
521+
# phases = np.einsum(
522+
# "ij,kj->ik",
523+
# q,
524+
# R,
525+
# )
526+
527+
# Simplified:
528+
# q is fractional (reduced) coordinates:
529+
# q = q1*b1 + q2*b2 + q3*b3
530+
#
531+
# R is a lattice translation written in direct lattice coordinates:
532+
# R = n1*a1 + n2*a2 + n3*a3
533+
#
534+
# Reciprocal/direct basis satisfy:
535+
# ai · bj = 2π δij (δij = 1 if i==j else 0)
536+
#
537+
# Therefore the phase is:
538+
# q · R = 2π (q1*n1 + q2*n2 + q3*n3)
539+
#
540+
# compute:
541+
# phases = 2π * dot(q_frac, R_frac)
542+
543+
phases = (
544+
np.einsum(
545+
"ij,kj->ik",
546+
np.array(ph_bs.qpoints),
547+
ph_bs.structure.frac_coords,
548+
)
549+
* 2
550+
* np.pi
471551
)
472552
rdata["phases"] = phases[qpoint].tolist()
473553

474554
# amplitude (A)
475-
rdata["amplitude"] = magnitude
555+
rdata["amplitude"] = 1 / np.linalg.norm(
556+
ph_bs.eigendisplacements[0][0]
557+
) # magnitude
476558

477559
# eigenVectors
478560
rdata["eigenVectors"] = (
@@ -560,6 +642,7 @@ def get_ph_bandstructure_traces(bs, freq_range):
560642

561643
bs_traces = []
562644

645+
last_di = 0
563646
for d, dist_val in enumerate(bs_data["distances"]):
564647
x_dat = dist_val
565648

@@ -575,7 +658,9 @@ def get_ph_bandstructure_traces(bs, freq_range):
575658
"line": {"color": "#1f77b4"},
576659
"hoverinfo": "skip",
577660
"name": "Total",
578-
"customdata": [[di, band_num] for di in range(len(x_dat))],
661+
"customdata": [
662+
[di + last_di, band_num] for di in range(len(x_dat))
663+
],
579664
"hovertemplate": "%{y:.2f} THz",
580665
"showlegend": False,
581666
"xaxis": "x",
@@ -585,6 +670,7 @@ def get_ph_bandstructure_traces(bs, freq_range):
585670
]
586671

587672
bs_traces += traces_for_segment
673+
last_di += len(x_dat)
588674

589675
for entry_num in range(len(bs_data["ticks"]["label"])):
590676
for key in pretty_labels:
@@ -973,6 +1059,7 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
9731059
State(self.get_kwarg_id("scale-y"), "value"),
9741060
State(self.get_kwarg_id("scale-z"), "value"),
9751061
State(self.get_kwarg_id("velocity"), "value"),
1062+
State(self.id("color-scheme"), "value"),
9761063
# prevent_initial_call=True
9771064
)
9781065
def update_crystal_animation(
@@ -984,6 +1071,7 @@ def update_crystal_animation(
9841071
scale_y,
9851072
scale_z,
9861073
velocity,
1074+
color_scheme,
9871075
):
9881076
# Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
9891077
# ensuring updates occur only after the `supercell-controls-btn`` is clicked.
@@ -999,6 +1087,7 @@ def update_crystal_animation(
9991087
scale_y = kwargs.get("scale-y")
10001088
scale_z = kwargs.get("scale-z")
10011089
velocity = kwargs.get("velocity")
1090+
# color_scheme = kwargs.get("color-scheme")
10021091

10031092
if isinstance(bs, dict):
10041093
bs = PhononBS.from_pmg(bs)
@@ -1025,7 +1114,7 @@ def update_crystal_animation(
10251114
# legend
10261115
legend = Legend(
10271116
struc_graph.structure,
1028-
color_scheme=DEFAULTS["color_scheme"],
1117+
color_scheme=color_scheme,
10291118
# radius_scheme=radius_strategy,
10301119
cmap_range=None,
10311120
)
@@ -1056,7 +1145,9 @@ def update_crystal_animation(
10561145

10571146
if cd and cd.get("points"):
10581147
pt = cd["points"][0]
1059-
qpoint, band_num = pt.get("customdata", [0, 0])
1148+
qpoint, band_num = pt.get("customdata", [-1, -1])
1149+
if qpoint == -1 or band_num == -1:
1150+
raise ValueError("qpoint and band_num are invalid")
10601151

10611152
# magnitude
10621153
magnitude = (

0 commit comments

Comments
 (0)