Skip to content

Commit 994d875

Browse files
committed
Added types and docs
1 parent 819c556 commit 994d875

2 files changed

Lines changed: 54 additions & 26 deletions

File tree

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import plotly.colors as pc
1313
import plotly.graph_objects as go
14+
from numpy.typing import NDArray
1415

1516
from tdamapper.core import ATTR_SIZE, aggregate_graph
1617

@@ -304,13 +305,15 @@ def plot(
304305
)
305306
return self.fig
306307

307-
def _node_pos_array(self):
308+
def _node_pos_array(self) -> tuple[list[float], ...]:
308309
return tuple(
309310
[self.positions[n][i] for n in self.graph.nodes()] for i in range(self.dim)
310311
)
311312

312-
def _edge_pos_array(self):
313-
edges_arr = tuple([] for i in range(self.dim))
313+
def _edge_pos_array(self) -> tuple[list[Optional[float]], ...]:
314+
edges_arr: tuple[list[Optional[float]], ...] = tuple(
315+
[] for i in range(self.dim)
316+
)
314317
for edge in self.graph.edges():
315318
pos0, pos1 = self.positions[edge[0]], self.positions[edge[1]]
316319
for i in range(self.dim):
@@ -530,7 +533,7 @@ def update_figure(
530533
if cmaps is not None:
531534
self.set_cmap(cmaps[0])
532535

533-
def _nodes_trace(self, node_pos_arr: ) -> Union[go.Scatter, go.Scatter3d]:
536+
def _nodes_trace(self, node_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
534537
scatter = dict(
535538
name=_NODES_TRACE,
536539
x=node_pos_arr[0],
@@ -699,7 +702,8 @@ def set_ui(
699702

700703
if cmaps is not None:
701704
cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps]
702-
ui_menu_cmap = self._ui_menu_cmap(cmaps_plotly)
705+
cmaps_plotly_ok = [c for c in cmaps_plotly if c is not None]
706+
ui_menu_cmap = self._ui_menu_cmap(cmaps_plotly_ok)
703707

704708
if colors is not None and agg is not None and titles is not None:
705709
ui_menu_color = self._ui_menu_color(colors, titles, agg)
@@ -798,7 +802,9 @@ def _update_cmap(cmap: str) -> dict:
798802
yanchor="top",
799803
)
800804

801-
def _ui_menu_color(self, colors: NDArray[np.float64], titles: list[str], agg: Callable) -> dict:
805+
def _ui_menu_color(
806+
self, colors: NDArray[np.float64], titles: list[str], agg: Callable
807+
) -> dict:
802808
colors_arr = np.array(colors)
803809
colors_num = colors_arr.shape[1] if colors_arr.ndim == 2 else 1
804810

src/tdamapper/plot_backends/plot_pyvis.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
"""
55

66
import math
7+
from typing import Any, Callable
78

9+
import numpy as np
810
import plotly.colors as pc
911
import plotly.graph_objects as go
1012
import plotly.io as pio
13+
from numpy.typing import NDArray
1114
from pyvis.network import Network
1215

1316
from tdamapper.core import aggregate_graph
@@ -19,12 +22,18 @@
1922
_TICKS_NUM = 10
2023

2124

22-
def _fmt(x, max_len=3):
25+
def _fmt(x: Any, max_len: int = 3) -> str:
2326
fmt = f".{max_len}g"
2427
return f"{x:{fmt}}"
2528

2629

27-
def _colorbar(height, cmap, cmin, cmax, title):
30+
def _colorbar(
31+
height: int,
32+
cmap: str,
33+
cmin: float,
34+
cmax: float,
35+
title: str,
36+
) -> go.Figure:
2837
colorbar_fig = go.Figure()
2938
colorbar_fig.add_trace(
3039
go.Scatter(
@@ -68,7 +77,7 @@ def _colorbar(height, cmap, cmin, cmax, title):
6877
return colorbar_fig
6978

7079

71-
def _combine(network, colorbar):
80+
def _combine(network: Network, colorbar: go.Figure) -> str:
7281
network_html = network.generate_html()
7382
colorbar_html = pio.to_html(
7483
colorbar,
@@ -130,15 +139,28 @@ def _combine(network, colorbar):
130139

131140
def plot_pyvis(
132141
mapper_plot,
133-
output_file,
134-
colors,
135-
node_size,
136-
agg,
137-
title,
138-
width,
139-
height,
140-
cmap,
141-
):
142+
output_file: str,
143+
colors: NDArray[np.float64],
144+
node_size: float,
145+
agg: Callable,
146+
title: str,
147+
width: int,
148+
height: int,
149+
cmap: str,
150+
) -> None:
151+
"""
152+
Generates a pyvis network visualization of the Mapper graph and saves it to an HTML file.
153+
154+
:param mapper_plot: The Mapper plot object containing the graph and positions.
155+
:param output_file: The path to the output HTML file.
156+
:param colors: A 2D array of colors for the nodes.
157+
:param node_size: The size of the nodes in the graph.
158+
:param agg: A callable function to aggregate the graph.
159+
:param title: The title for the colorbar.
160+
:param width: The width of the network visualization.
161+
:param height: The height of the network visualization.
162+
:param cmap: The colormap to use for the nodes.
163+
"""
142164
net, cmin, cmax = _compute_net(
143165
mapper_plot=mapper_plot,
144166
width=width,
@@ -150,19 +172,19 @@ def plot_pyvis(
150172
)
151173
colorbar = _colorbar(height=height, cmap=cmap, cmin=cmin, cmax=cmax, title=title)
152174
combined_html = _combine(net, colorbar)
153-
with open(output_file, "w") as file:
175+
with open(output_file, "w", encoding="utf-8") as file:
154176
file.write(combined_html)
155177

156178

157179
def _compute_net(
158180
mapper_plot,
159-
colors,
160-
node_size,
161-
agg,
162-
width,
163-
height,
164-
cmap,
165-
):
181+
colors: NDArray[np.float64],
182+
node_size: float,
183+
agg: Callable,
184+
width: int,
185+
height: int,
186+
cmap: str,
187+
) -> tuple[Network, float, float]:
166188
net = Network(
167189
height=f"{height}px",
168190
width=f"{width}px",

0 commit comments

Comments
 (0)