|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +"""MTG visualization helpers. |
| 4 | +
|
| 5 | +The compact MTG view is a timeline diagnostic. Step panels reuse the molecule- |
| 6 | +like ITS renderer so each reconstructed ITS step is inspected with the same |
| 7 | +visual language as normal tuple ITS drawings. |
| 8 | +""" |
| 9 | + |
| 10 | +from typing import Any, Iterable, Optional |
| 11 | + |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +import networkx as nx |
| 14 | + |
| 15 | +from synkit.Vis.its_drawer import draw_its_only |
| 16 | +from synkit.Vis.visual_drawer import draw_graph |
| 17 | + |
| 18 | + |
| 19 | +def draw_mtg_graph( |
| 20 | + mtg: Any, |
| 21 | + *, |
| 22 | + ax: Optional[plt.Axes] = None, |
| 23 | + title: Optional[str] = None, |
| 24 | + mode: str = "timeline", |
| 25 | + layout: str = "kamada_kawai", |
| 26 | + show_atom_map: bool = True, |
| 27 | + show_edge_labels: bool = True, |
| 28 | + show_node_badges: bool = True, |
| 29 | +) -> tuple[plt.Figure, plt.Axes]: |
| 30 | + """Draw a compact MTG timeline graph. |
| 31 | +
|
| 32 | + ``mtg`` may be a :class:`synkit.Graph.MTG.mtg.MTG` instance or a raw |
| 33 | + compact MTG ``networkx.Graph`` from ``MTG.get_mtg()``. |
| 34 | +
|
| 35 | + :param mtg: MTG object or compact MTG graph. |
| 36 | + :type mtg: Any |
| 37 | + :param ax: Optional Matplotlib axes. |
| 38 | + :type ax: Optional[plt.Axes] |
| 39 | + :param title: Optional title. |
| 40 | + :type title: Optional[str] |
| 41 | + :param mode: Visual adapter mode. ``"timeline"`` is the recommended MTG |
| 42 | + view; ``"sigma_pi"`` gives a shorter electron-bond diagnostic. |
| 43 | + :type mode: str |
| 44 | + :param layout: NetworkX layout name passed to ``draw_graph``. |
| 45 | + :type layout: str |
| 46 | + :returns: ``(figure, axes)``. |
| 47 | + :rtype: tuple[plt.Figure, plt.Axes] |
| 48 | + """ |
| 49 | + |
| 50 | + graph = _as_mtg_graph(mtg) |
| 51 | + return draw_graph( |
| 52 | + graph, |
| 53 | + ax=ax, |
| 54 | + mode=mode, |
| 55 | + title=title or "MTG timeline", |
| 56 | + show_atom_map=show_atom_map, |
| 57 | + layout=layout, |
| 58 | + show_edge_labels=show_edge_labels, |
| 59 | + show_node_badges=show_node_badges, |
| 60 | + ) |
| 61 | + |
| 62 | + |
| 63 | +def draw_mtg_steps( |
| 64 | + mtg: Any, |
| 65 | + *, |
| 66 | + steps: Optional[Iterable[int]] = None, |
| 67 | + include_composed: bool = False, |
| 68 | + title: Optional[str] = None, |
| 69 | + max_columns: int = 3, |
| 70 | + show_atom_map: bool = True, |
| 71 | + label_mode: str = "hetero", |
| 72 | + edge_label_mode: str = "kekule", |
| 73 | + show_edge_labels: bool = False, |
| 74 | + show_electron_labels: bool = False, |
| 75 | + electron_label_mode: str = "charge", |
| 76 | +) -> tuple[plt.Figure, list[plt.Axes]]: |
| 77 | + """Draw reconstructed MTG ITS steps as ordered panels. |
| 78 | +
|
| 79 | + :param mtg: MTG object exposing ``get_its_steps``. |
| 80 | + :type mtg: Any |
| 81 | + :param steps: Optional zero-based step indices to draw. |
| 82 | + :type steps: Optional[Iterable[int]] |
| 83 | + :param include_composed: Append the composed outer-state ITS panel. |
| 84 | + :type include_composed: bool |
| 85 | + :param title: Optional figure title. |
| 86 | + :type title: Optional[str] |
| 87 | + :param max_columns: Maximum subplot columns. |
| 88 | + :type max_columns: int |
| 89 | + :returns: ``(figure, axes)``. |
| 90 | + :rtype: tuple[plt.Figure, list[plt.Axes]] |
| 91 | + """ |
| 92 | + |
| 93 | + if not hasattr(mtg, "get_its_steps"): |
| 94 | + raise TypeError("draw_mtg_steps expects an MTG object with get_its_steps().") |
| 95 | + |
| 96 | + all_steps = list(mtg.get_its_steps()) |
| 97 | + selected = list(range(len(all_steps))) if steps is None else list(steps) |
| 98 | + for step in selected: |
| 99 | + if step < 0 or step >= len(all_steps): |
| 100 | + raise IndexError(f"MTG step index out of range: {step}") |
| 101 | + |
| 102 | + panels = [(f"Step {step + 1}", all_steps[step]) for step in selected] |
| 103 | + if include_composed: |
| 104 | + if not hasattr(mtg, "get_compose_its"): |
| 105 | + raise TypeError("include_composed requires an MTG object with get_compose_its().") |
| 106 | + panels.append(("Composed", mtg.get_compose_its())) |
| 107 | + |
| 108 | + if not panels: |
| 109 | + raise ValueError("No MTG steps selected for drawing.") |
| 110 | + |
| 111 | + ncols = min(max(1, max_columns), len(panels)) |
| 112 | + nrows = (len(panels) + ncols - 1) // ncols |
| 113 | + fig, axes_grid = plt.subplots( |
| 114 | + nrows, |
| 115 | + ncols, |
| 116 | + figsize=(4.8 * ncols, 4.2 * nrows), |
| 117 | + squeeze=False, |
| 118 | + facecolor="white", |
| 119 | + ) |
| 120 | + axes = [ax for row in axes_grid for ax in row] |
| 121 | + if title: |
| 122 | + fig.suptitle(title, fontsize=13, fontweight="bold") |
| 123 | + |
| 124 | + for ax, (panel_title, its) in zip(axes, panels): |
| 125 | + draw_its_only( |
| 126 | + its, |
| 127 | + ax=ax, |
| 128 | + title=panel_title, |
| 129 | + show_atom_map=show_atom_map, |
| 130 | + label_mode=label_mode, |
| 131 | + edge_label_mode=edge_label_mode, |
| 132 | + show_edge_labels=show_edge_labels, |
| 133 | + show_electron_labels=show_electron_labels, |
| 134 | + electron_label_mode=electron_label_mode, |
| 135 | + ) |
| 136 | + |
| 137 | + for ax in axes[len(panels):]: |
| 138 | + ax.set_axis_off() |
| 139 | + |
| 140 | + fig.tight_layout() |
| 141 | + return fig, axes[: len(panels)] |
| 142 | + |
| 143 | + |
| 144 | +def _as_mtg_graph(mtg: Any) -> nx.Graph: |
| 145 | + if isinstance(mtg, nx.Graph): |
| 146 | + return mtg |
| 147 | + if hasattr(mtg, "get_mtg"): |
| 148 | + graph = mtg.get_mtg() |
| 149 | + if isinstance(graph, nx.Graph): |
| 150 | + return graph |
| 151 | + raise TypeError("Expected an MTG object or a NetworkX compact MTG graph.") |
0 commit comments