Skip to content

Commit fa90e0c

Browse files
committed
Added docstrings. Fixed types
1 parent 6059a0a commit fa90e0c

2 files changed

Lines changed: 45 additions & 16 deletions

File tree

src/tdamapper/plot_backends/plot_matplotlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,6 @@ def _plot_edges(mapper_plot: MapperPlotType, ax: Axes) -> None:
133133

134134

135135
def _node_pos_array(
136-
graph: nx.Graph, dim: int, positions: NDArray[np.float_]
137-
) -> tuple[list[NDArray[np.float_]], ...]:
136+
graph: nx.Graph, dim: int, positions: dict[int, tuple[float, ...]]
137+
) -> tuple[list[float], ...]:
138138
return tuple([positions[n][i] for n in graph.nodes()] for i in range(dim))

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def plot(
315315
316316
:return: A Plotly Figure object containing the Mapper graph.
317317
"""
318-
self.set_figure(
318+
self._set_figure(
319319
node_sizes=node_sizes,
320320
colors=colors,
321321
titles=titles,
@@ -357,7 +357,7 @@ def _marker_size(self) -> list[float]:
357357
marker_size = [factor * attr_size[n] / max_size for n in self.graph.nodes()]
358358
return marker_size
359359

360-
def set_cmap(self, cmap: str) -> None:
360+
def _set_cmap(self, cmap: str) -> None:
361361
if self.fig is None:
362362
return
363363
cmap_rgb = _get_cmap_rgb(cmap)
@@ -426,7 +426,7 @@ def _set_colors(self, colors: NDArray[np.float_], agg: Callable[..., Any]) -> No
426426
selector=dict(name=_EDGES_TRACE),
427427
)
428428

429-
def set_title(self, color_name: str) -> None:
429+
def _set_title(self, color_name: str) -> None:
430430
if self.fig is None:
431431
return
432432
self.fig.update_traces(
@@ -436,7 +436,7 @@ def set_title(self, color_name: str) -> None:
436436
selector=dict(name=_NODES_TRACE),
437437
)
438438

439-
def set_node_size(self, node_size: float) -> None:
439+
def _set_node_size(self, node_size: float) -> None:
440440
if self.fig is None:
441441
return
442442
self.fig.update_traces(
@@ -450,21 +450,21 @@ def set_node_size(self, node_size: float) -> None:
450450
selector=dict(name=_NODES_TRACE),
451451
)
452452

453-
def set_width(self, width: int) -> None:
453+
def _set_width(self, width: int) -> None:
454454
if self.fig is None:
455455
return
456456
self.fig.update_layout(
457457
width=width,
458458
)
459459

460-
def set_height(self, height: int) -> None:
460+
def _set_height(self, height: int) -> None:
461461
if self.fig is None:
462462
return
463463
self.fig.update_layout(
464464
height=height,
465465
)
466466

467-
def set_figure(
467+
def _set_figure(
468468
self,
469469
node_sizes: list[float],
470470
colors: NDArray[np.float_],
@@ -501,18 +501,34 @@ def update_figure(
501501
width: Optional[int] = None,
502502
height: Optional[int] = None,
503503
) -> None:
504+
"""
505+
Update the Plotly figure with new parameters.
506+
507+
:param titles: A list of titles for the colormap.
508+
:param node_sizes: A list of scaling factors for node size.
509+
:param colors: An array of values that determine the color of each
510+
node in the graph, useful for highlighting different features of
511+
the data.
512+
:param agg: A function used to aggregate the `colors` array over the
513+
points within a single node. The final color of each node is
514+
obtained by mapping the aggregated value with the colormap `cmap`.
515+
:param cmaps: A list of colormap names used to map `colors` data values,
516+
aggregated by `agg`, to actual RGBA colors.
517+
:param width: The desired width of the figure in pixels.
518+
:param height: The desired height of the figure in pixels.
519+
"""
504520
if width is not None:
505-
self.set_width(width)
521+
self._set_width(width)
506522
if height is not None:
507-
self.set_height(height)
523+
self._set_height(height)
508524
if titles is not None:
509-
self.set_title(titles[0])
525+
self._set_title(titles[0])
510526
if node_sizes is not None:
511-
self.set_node_size(node_sizes[len(node_sizes) // 2])
527+
self._set_node_size(node_sizes[len(node_sizes) // 2])
512528
if (colors is not None) and (agg is not None):
513529
self._set_colors(colors[:, 0], agg)
514530
if cmaps is not None:
515-
self.set_cmap(cmaps[0])
531+
self._set_cmap(cmaps[0])
516532

517533
def _nodes_trace(
518534
self, node_pos_arr: tuple[list[float], ...]
@@ -655,6 +671,19 @@ def set_ui(
655671
agg: Optional[Callable[..., Any]],
656672
node_sizes: Optional[list[float]],
657673
) -> None:
674+
"""
675+
Set the UI elements for the Plotly figure.
676+
677+
:param cmaps: A list of colormap names to be used in the UI.
678+
:param colors: An array of values that determine the color of each
679+
node in the graph, useful for highlighting different features of
680+
the data.
681+
:param titles: A list of titles for the colormap.
682+
:param agg: A function used to aggregate the `colors` array over the
683+
points within a single node. The final color of each node is
684+
obtained by mapping the aggregated value with the colormap `cmap`.
685+
:param node_sizes: A list of scaling factors for node size.
686+
"""
658687
if self.fig is None:
659688
return
660689

@@ -673,7 +702,7 @@ def set_ui(
673702
if s.name == SLIDER_NODE_SIZE_NAME:
674703
ui_slider_size = s
675704

676-
ui_menu_dark_mode = self.ui_menu_dark_mode()
705+
ui_menu_dark_mode = self._ui_menu_dark_mode()
677706

678707
if cmaps is not None:
679708
cmaps_plotly = [PLOTLY_CMAPS.get(c.lower()) for c in cmaps]
@@ -697,7 +726,7 @@ def set_ui(
697726
sliders=sliders,
698727
)
699728

700-
def ui_menu_dark_mode(self) -> dict[str, Any]:
729+
def _ui_menu_dark_mode(self) -> dict[str, Any]:
701730
buttons = [
702731
dict(
703732
label="Light",

0 commit comments

Comments
 (0)