|
| 1 | +"""helpers for making animations""" |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import numpy as np |
| 5 | +import logging |
| 6 | + |
| 7 | +from functools import partial |
| 8 | +from matplotlib.animation import FuncAnimation |
| 9 | +from inspect import isclass |
| 10 | + |
| 11 | +from ._helpers import norm_from_channel |
| 12 | +from ._interact import interact2D_fig |
| 13 | +from ._quick import Quick1DIterator, Quick2DIterator |
| 14 | +from ..kit import joint_shape |
| 15 | + |
| 16 | +__all__ = ["animate2D", "animate_interact2D", "animate_quick"] |
| 17 | +logger = logging.getLogger("animation") |
| 18 | + |
| 19 | + |
| 20 | +def animate2D( |
| 21 | + data, |
| 22 | + norm=None, |
| 23 | + channel=0, |
| 24 | + cmap=None, |
| 25 | + back_and_forth: bool = False, |
| 26 | + **ani_kwargs, |
| 27 | +) -> FuncAnimation: |
| 28 | + """ |
| 29 | + animate pcolormesh of a nd dataset (ndim >=2) |
| 30 | + mesh plots last two axes of the dataset (use `Data.transform` if needed) |
| 31 | +
|
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | +
|
| 35 | + data: WrightTools.data |
| 36 | + dataset to animate. take the last two axes as the ones that are plotted; |
| 37 | + other axes compose the frames of the animation |
| 38 | +
|
| 39 | + norm: Normalize instance or callable |
| 40 | + determines the normalization rules to follow. |
| 41 | + If channel is signed, defaults to CenteredNorm with null center. |
| 42 | + If channel is unsigned, defaults to Normalize from null to max. |
| 43 | +
|
| 44 | + channel: string, index, or Channel |
| 45 | + Select which channel to plot |
| 46 | +
|
| 47 | + cmap: str or Colormap (optional) |
| 48 | + colormap used. Defaults to WrightTools default |
| 49 | +
|
| 50 | + back_and_forth: bool = False |
| 51 | + when True, the animation will go in reverse after going forward, |
| 52 | + creating a continuous loop when repeat is on |
| 53 | +
|
| 54 | + **kwargs: dict items |
| 55 | + all extra kwargs are passed to matplotlib.FuncAnimation |
| 56 | +
|
| 57 | + Returns |
| 58 | + ------- |
| 59 | +
|
| 60 | + animation: matplotlib.animation.animation |
| 61 | +
|
| 62 | + Example |
| 63 | + ------- |
| 64 | + General usage (create an animation procedure and then write to file): |
| 65 | + ``` |
| 66 | + norm=CenteredNorm(vcenter=0, halfrange=np.abs(d.channels[0][:]).max()) |
| 67 | + ani = wt.artists.animate2D( |
| 68 | + d, cmap="bwr", norm=norm, interval=100 |
| 69 | + ) |
| 70 | + ``` |
| 71 | + The animation has write to file utilities like `to_html5_video`: |
| 72 | + ``` |
| 73 | + with open(f"{d.natural_name}_animation.html", "w") as f: |
| 74 | + f.write(ani.to_html5_video()) |
| 75 | + ``` |
| 76 | + Alternatively, you can show in the interactive viewer and watch the animation: |
| 77 | + ``` |
| 78 | + plt.show() |
| 79 | + ``` |
| 80 | + For colorbar normalized at each frame, you can use `functools.partial`: |
| 81 | + ``` |
| 82 | + from functools import partial |
| 83 | + norm = partial(CenteredNorm, vcenter=0) # halfrange evaluated for each frame |
| 84 | + ``` |
| 85 | + """ |
| 86 | + |
| 87 | + channel = data.get_channel(channel) |
| 88 | + if norm is None: |
| 89 | + norm = norm_from_channel(channel) |
| 90 | + if cmap is None: |
| 91 | + cmap = "signed" if channel.signed else "default" |
| 92 | + # detect whether to call norm each frame |
| 93 | + # probably not an optimal implementation, but working for now |
| 94 | + call_norm = isclass(norm) or isinstance(norm, partial) |
| 95 | + |
| 96 | + def gen_title(ind): |
| 97 | + parts = [ |
| 98 | + f"{var.natural_name} = {var[:][ind].squeeze():.2f} {var.units}" |
| 99 | + for var in map(lambda a: a.variables[0], data.axes[:-2]) |
| 100 | + ] |
| 101 | + return "\n".join(parts) |
| 102 | + |
| 103 | + frame_shape = joint_shape(*[a[:] for a in data.axes[:-2]]) |
| 104 | + channel_shape = joint_shape(*[a[:] for a in data.axes[-2:]]) |
| 105 | + # mask indices that are spanned by the x and y axes |
| 106 | + mask = [ci > fi for ci, fi in zip(channel_shape, frame_shape)] |
| 107 | + logger.debug(f"{frame_shape=}, {channel_shape=}, {mask=}") |
| 108 | + |
| 109 | + fig, ax = plt.subplots(subplot_kw=dict(projection="wright"), dpi=140, layout="constrained") |
| 110 | + art = ax.pcolormesh( |
| 111 | + data[tuple([0 for i in data.shape[:-2]])], |
| 112 | + cmap=cmap, |
| 113 | + norm=norm() if call_norm else norm, |
| 114 | + ) |
| 115 | + colorbar = fig.colorbar(art, ax=ax) |
| 116 | + colorbar.set_label(channel.label) |
| 117 | + |
| 118 | + ax.set_title(gen_title(tuple([0 for _ in data.shape[:-2]]))) |
| 119 | + # with layout well set, turn off the engine (avoids jittering frames) |
| 120 | + fig.set_layout_engine("none") |
| 121 | + |
| 122 | + def updater(frame): |
| 123 | + frame = tuple(slice(None) if mi else fi for fi, mi in zip(frame, mask)) |
| 124 | + logger.info(f"{frame=}") |
| 125 | + art.set_array(channel[frame]) |
| 126 | + ax.set_title(gen_title(frame)) |
| 127 | + art.set_norm(norm() if call_norm else norm) |
| 128 | + fig.canvas.draw_idle() |
| 129 | + return art |
| 130 | + |
| 131 | + # generate frame sequence |
| 132 | + |
| 133 | + frames = list(np.ndindex(frame_shape)) |
| 134 | + if back_and_forth: |
| 135 | + frames += reversed(frames) |
| 136 | + |
| 137 | + return FuncAnimation( |
| 138 | + fig=fig, |
| 139 | + func=updater, |
| 140 | + frames=frames, |
| 141 | + **ani_kwargs, |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +def animate_quick(q2d: Quick1DIterator | Quick2DIterator, **kwargs) -> FuncAnimation: |
| 146 | + """ |
| 147 | + animate a quick2Ds series |
| 148 | +
|
| 149 | + unlike other animation functions, this enforces repeat=False |
| 150 | +
|
| 151 | + Parameters |
| 152 | + ---------- |
| 153 | +
|
| 154 | + **kwargs: dict items |
| 155 | + all extra kwargs are passed to matplotlib.FuncAnimation |
| 156 | +
|
| 157 | +
|
| 158 | + Example |
| 159 | + ------- |
| 160 | + ```python |
| 161 | + quick_iter = wt.artists.quick1Ds(data, autosave=False, local=False) |
| 162 | + ani = wt.artists.animate_quick(quick_iter, interval=100) |
| 163 | + ``` |
| 164 | +
|
| 165 | + """ |
| 166 | + |
| 167 | + return FuncAnimation(fig=q2d.fig, func=lambda x: None, frames=q2d, **kwargs) |
| 168 | + |
| 169 | + |
| 170 | +def animate_interact2D( |
| 171 | + interact2D: interact2D_fig, back_and_forth=False, **kwargs |
| 172 | +) -> FuncAnimation: |
| 173 | + """ |
| 174 | + Take an interact2D figure and create an animation by moving the sliders. |
| 175 | +
|
| 176 | + Parameters |
| 177 | + ---------- |
| 178 | + interact2D: interact2D_fig |
| 179 | + the output of an interact2D call |
| 180 | +
|
| 181 | + back_and_forth: bool = False |
| 182 | + when True, the animation will go in reverse after going forward, |
| 183 | + creating a continuous steps of variables when repeat is on |
| 184 | +
|
| 185 | + **kwargs: dict items |
| 186 | + all extra kwargs are passed to matplotlib.FuncAnimation |
| 187 | +
|
| 188 | + Example |
| 189 | + ------- |
| 190 | + ```python |
| 191 | + interactive = wt.artists.interact2D(data, local=True) |
| 192 | + ani = wt.artists.animate_interact2D(interactive, back_and_forth=True, interval=500) |
| 193 | + ``` |
| 194 | + """ |
| 195 | + |
| 196 | + def update(frame): |
| 197 | + logger.info(f"{frame=}") |
| 198 | + for ind, slider in zip(frame, interact2D.sliders.values()): |
| 199 | + slider.set_val(ind) |
| 200 | + |
| 201 | + frames = list(np.ndindex(tuple([s.valmax + 1 for s in interact2D.sliders.values()]))) |
| 202 | + if back_and_forth: |
| 203 | + frames += reversed(frames) |
| 204 | + |
| 205 | + return FuncAnimation(fig=interact2D.fig, func=update, frames=frames, **kwargs) |
0 commit comments