Skip to content

Commit 6887d24

Browse files
committed
Added type hints. Improved handling of missing arguments
1 parent 81ec45d commit 6887d24

1 file changed

Lines changed: 49 additions & 31 deletions

File tree

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def __init__(self):
6060
self.menu_color = None
6161
self.slider_size = None
6262

63-
def set_menu_cmap(self, mapper_plot, cmaps):
63+
def set_menu_cmap(self, mapper_plot, cmaps: Optional[List[str]]) -> None:
64+
if cmaps is None:
65+
return
6466
cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps]
6567
self.menu_cmap = _ui_cmap(mapper_plot, cmaps_plotly)
6668

@@ -71,8 +73,10 @@ def set_slider_size(self, mapper_plot, node_sizes):
7173
self.slider_size = _ui_node_size(mapper_plot, node_sizes)
7274

7375

74-
def _to_cmaps(cmap: Union[str, List[str]]) -> List[str]:
76+
def _to_cmaps(cmap: Optional[Union[str, List[str]]]) -> List[str]:
7577
"""Convert a single cmap or a list of cmaps to a list of cmaps."""
78+
if cmap is None:
79+
return [DEFAULT_CMAP]
7680
if isinstance(cmap, str):
7781
return [cmap]
7882
elif isinstance(cmap, list):
@@ -94,11 +98,11 @@ def _to_colors(colors: Union[np.ndarray, List[float]]) -> np.ndarray:
9498
)
9599

96100

97-
def _to_titles(title, colors_num):
101+
def _to_titles(title: Optional[Union[str, List[str]]], colors_num: int) -> List[str]:
98102
if title is None:
99-
return [DEFAULT_TITLE for _ in range(colors_num)]
103+
return [f"{i}" for i in range(colors_num)]
100104
elif isinstance(title, str):
101-
return [title for _ in range(colors_num)]
105+
return [f"{title} {i}" for i in range(colors_num)]
102106
elif isinstance(title, list) and len(title) == colors_num:
103107
return title
104108
else:
@@ -107,7 +111,9 @@ def _to_titles(title, colors_num):
107111
)
108112

109113

110-
def _to_node_sizes(node_size):
114+
def _to_node_sizes(
115+
node_size: Optional[Union[int, float, List[Union[int, float]]]]
116+
) -> List[float]:
111117
if isinstance(node_size, (int, float)):
112118
return [node_size]
113119
elif isinstance(node_size, list):
@@ -123,11 +129,11 @@ def plot_plotly(
123129
mapper_plot,
124130
width: int,
125131
height: int,
126-
node_size: Optional[Union[int, float, List[Union[int, float]]]] = DEFAULT_NODE_SIZE,
127-
colors=None,
132+
colors: Union[np.ndarray, List[float]],
133+
node_size: Optional[Union[int, float, List[Union[int, float]]]] = None,
128134
title: Optional[Union[str, List[str]]] = None,
129135
agg=np.nanmean,
130-
cmap: Union[str, List[str]] = DEFAULT_CMAP,
136+
cmap: Optional[Union[str, List[str]]] = None,
131137
) -> go.Figure:
132138
cmaps = _to_cmaps(cmap)
133139
colors = _to_colors(colors)
@@ -187,7 +193,7 @@ def plot_plotly_update(
187193
return fig
188194

189195

190-
def _node_pos_array(graph, dim, node_pos):
196+
def _node_pos_array(graph: nx.Graph, dim: int, node_pos):
191197
return tuple([node_pos[n][i] for n in graph.nodes()] for i in range(dim))
192198

193199

@@ -202,7 +208,7 @@ def _edge_pos_array(graph, dim, node_pos):
202208
return edges_arr
203209

204210

205-
def _marker_size(mapper_plot, node_size):
211+
def _marker_size(mapper_plot, node_size: float) -> List[float]:
206212
attr_size = nx.get_node_attributes(mapper_plot.graph, ATTR_SIZE)
207213
max_size = max(attr_size.values(), default=1.0)
208214
scale = node_size * (25.0 if mapper_plot.dim == 2 else 15.0)
@@ -212,14 +218,14 @@ def _marker_size(mapper_plot, node_size):
212218
return marker_size
213219

214220

215-
def _get_cmap_rgb(cmap):
221+
def _get_cmap_rgb(cmap: str):
216222
"""Return a colorscale in [[float, 'rgb(r,g,b)']] format."""
217223
base_scale = pc.get_colorscale(cmap)
218224
# If it's already in [float, color] format, we're good
219225
return [[pos, color] for pos, color in base_scale]
220226

221227

222-
def _set_cmap(mapper_plot, fig, cmap):
228+
def _set_cmap(mapper_plot, fig: go.Figure, cmap: str) -> None:
223229
cmap_rgb = _get_cmap_rgb(cmap)
224230
fig.update_traces(
225231
patch=dict(
@@ -244,7 +250,7 @@ def _set_cmap(mapper_plot, fig, cmap):
244250
)
245251

246252

247-
def _set_colors(mapper_plot, fig, colors, agg):
253+
def _set_colors(mapper_plot, fig: go.Figure, colors, agg):
248254
node_col = aggregate_graph(colors, mapper_plot.graph, agg)
249255
scatter_text = _text(mapper_plot, node_col)
250256
colors_arr = list(node_col.values())
@@ -278,7 +284,7 @@ def _set_colors(mapper_plot, fig, colors, agg):
278284
)
279285

280286

281-
def _set_title(mapper_plot, fig, color_name):
287+
def _set_title(mapper_plot, fig: go.Figure, color_name: str):
282288
fig.update_traces(
283289
patch=dict(
284290
marker_colorbar=_colorbar(mapper_plot, color_name),
@@ -287,7 +293,7 @@ def _set_title(mapper_plot, fig, color_name):
287293
)
288294

289295

290-
def _set_node_size(mapper_plot, fig, node_size):
296+
def _set_node_size(mapper_plot, fig: go.Figure, node_size: float) -> None:
291297
fig.update_traces(
292298
patch=dict(
293299
marker_size=_marker_size(mapper_plot, node_size),
@@ -296,19 +302,28 @@ def _set_node_size(mapper_plot, fig, node_size):
296302
)
297303

298304

299-
def _set_width(fig, width):
305+
def _set_width(fig: go.Figure, width: int) -> None:
300306
fig.update_layout(
301307
width=width,
302308
)
303309

304310

305-
def _set_height(fig, height):
311+
def _set_height(fig: go.Figure, height: int) -> None:
306312
fig.update_layout(
307313
height=height,
308314
)
309315

310316

311-
def _figure(mapper_plot, width, height, node_sizes, colors, titles, agg, cmaps):
317+
def _figure(
318+
mapper_plot,
319+
width: int,
320+
height: int,
321+
node_sizes: List[float],
322+
colors: np.ndarray,
323+
titles: List[str],
324+
agg,
325+
cmaps: List[str],
326+
) -> go.Figure:
312327
node_pos = mapper_plot.positions
313328
node_pos_arr = _node_pos_array(
314329
mapper_plot.graph,
@@ -346,7 +361,7 @@ def _update(
346361
width: Optional[int] = None,
347362
height: Optional[int] = None,
348363
titles: Optional[List[str]] = None,
349-
node_sizes: Optional[List[int]] = None,
364+
node_sizes: Optional[List[float]] = None,
350365
colors=None,
351366
agg=None,
352367
cmaps: Optional[List[str]] = None,
@@ -422,7 +437,9 @@ def _edges_trace(mapper_plot, edge_pos_arr):
422437
return go.Scatter(scatter)
423438

424439

425-
def _colorbar(mapper_plot, title):
440+
def _colorbar(
441+
mapper_plot, title: str
442+
) -> Union[go.scatter3d.marker.ColorBar, go.scatter.marker.ColorBar]:
426443
cbar = dict(
427444
showticklabels=True,
428445
outlinewidth=1,
@@ -463,7 +480,7 @@ def _fmt(x, max_len=3):
463480
return f"{x:{fmt}}"
464481

465482

466-
def _layout():
483+
def _layout() -> go.Layout:
467484
line_col = "rgba(230, 230, 230, 1.0)"
468485
axis = dict(
469486
showline=False,
@@ -506,7 +523,7 @@ def _layout():
506523
)
507524

508525

509-
def _set_ui(mapper_fig, plotly_ui: PlotlyUI):
526+
def _set_ui(mapper_fig: go.Figure, plotly_ui: PlotlyUI) -> None:
510527
menus = []
511528
sliders = []
512529
x = 0.0
@@ -526,10 +543,10 @@ def _set_ui(mapper_fig, plotly_ui: PlotlyUI):
526543
)
527544

528545

529-
def _ui_cmap(mapper_plot, cmaps):
546+
def _ui_cmap(mapper_plot, cmaps: List[str]) -> dict:
530547
target_traces = [1] if mapper_plot.dim == 2 else [0, 1]
531548

532-
def _update_cmap(cmap):
549+
def _update_cmap(cmap: str) -> dict:
533550
cmap_rgb = _get_cmap_rgb(cmap)
534551
if mapper_plot.dim == 2:
535552
return {
@@ -542,6 +559,7 @@ def _update_cmap(cmap):
542559
"marker.line.colorscale": [None, cmap_rgb],
543560
"line.colorscale": [cmap_rgb, None],
544561
}
562+
return {}
545563

546564
buttons = []
547565
if len(cmaps) > 1:
@@ -564,7 +582,7 @@ def _update_cmap(cmap):
564582
)
565583

566584

567-
def _ui_node_size(mapper_plot, node_sizes):
585+
def _ui_node_size(mapper_plot, node_sizes: List[float]) -> dict:
568586
steps = [
569587
dict(
570588
method="restyle",
@@ -589,21 +607,21 @@ def _ui_node_size(mapper_plot, node_sizes):
589607
)
590608

591609

592-
def _ui_color(mapper_plot, colors, titles, agg):
610+
def _ui_color(mapper_plot, colors, titles: List[str], agg) -> dict:
593611
colors_arr = np.array(colors)
594612
colors_num = colors_arr.shape[1] if colors_arr.ndim == 2 else 1
595613

596-
def _colors_agg(i):
614+
def _colors_agg(i: int) -> dict:
597615
if i is None:
598616
arr = colors_arr
599617
else:
600618
arr = colors_arr[:, i] if colors_arr.ndim == 2 else colors_arr
601619
return aggregate_graph(arr, mapper_plot.graph, agg)
602620

603-
def _colors(i):
621+
def _colors(i: int) -> List[float]:
604622
return list(_colors_agg(i).values())
605623

606-
def _edge_colors(i):
624+
def _edge_colors(i: int) -> List[float]:
607625
colors_avg = []
608626
colors_agg = _colors_agg(i)
609627
for edge in mapper_plot.graph.edges():
@@ -613,7 +631,7 @@ def _edge_colors(i):
613631
colors_avg.append(c1)
614632
return colors_avg
615633

616-
def _update_colors(i):
634+
def _update_colors(i: int) -> dict:
617635
arr_agg = _colors_agg(i)
618636
arr = list(arr_agg.values())
619637
scatter_text = _text(mapper_plot, arr_agg)

0 commit comments

Comments
 (0)