44"""
55
66import math
7+ from typing import Any , Callable
78
9+ import numpy as np
810import plotly .colors as pc
911import plotly .graph_objects as go
1012import plotly .io as pio
13+ from numpy .typing import NDArray
1114from pyvis .network import Network
1215
1316from tdamapper .core import aggregate_graph
1922_TICKS_NUM = 10
2023
2124
22- def _fmt (x , max_len = 3 ) :
25+ def _fmt (x : Any , max_len : int = 3 ) -> str :
2326 fmt = f".{ max_len } g"
2427 return f"{ x :{fmt }} "
2528
2629
27- def _colorbar (height , cmap , cmin , cmax , title ):
30+ def _colorbar (
31+ height : int ,
32+ cmap : str ,
33+ cmin : float ,
34+ cmax : float ,
35+ title : str ,
36+ ) -> go .Figure :
2837 colorbar_fig = go .Figure ()
2938 colorbar_fig .add_trace (
3039 go .Scatter (
@@ -68,7 +77,7 @@ def _colorbar(height, cmap, cmin, cmax, title):
6877 return colorbar_fig
6978
7079
71- def _combine (network , colorbar ) :
80+ def _combine (network : Network , colorbar : go . Figure ) -> str :
7281 network_html = network .generate_html ()
7382 colorbar_html = pio .to_html (
7483 colorbar ,
@@ -130,15 +139,28 @@ def _combine(network, colorbar):
130139
131140def plot_pyvis (
132141 mapper_plot ,
133- output_file ,
134- colors ,
135- node_size ,
136- agg ,
137- title ,
138- width ,
139- height ,
140- cmap ,
141- ):
142+ output_file : str ,
143+ colors : NDArray [np .float64 ],
144+ node_size : float ,
145+ agg : Callable ,
146+ title : str ,
147+ width : int ,
148+ height : int ,
149+ cmap : str ,
150+ ) -> None :
151+ """
152+ Generates a pyvis network visualization of the Mapper graph and saves it to an HTML file.
153+
154+ :param mapper_plot: The Mapper plot object containing the graph and positions.
155+ :param output_file: The path to the output HTML file.
156+ :param colors: A 2D array of colors for the nodes.
157+ :param node_size: The size of the nodes in the graph.
158+ :param agg: A callable function to aggregate the graph.
159+ :param title: The title for the colorbar.
160+ :param width: The width of the network visualization.
161+ :param height: The height of the network visualization.
162+ :param cmap: The colormap to use for the nodes.
163+ """
142164 net , cmin , cmax = _compute_net (
143165 mapper_plot = mapper_plot ,
144166 width = width ,
@@ -150,19 +172,19 @@ def plot_pyvis(
150172 )
151173 colorbar = _colorbar (height = height , cmap = cmap , cmin = cmin , cmax = cmax , title = title )
152174 combined_html = _combine (net , colorbar )
153- with open (output_file , "w" ) as file :
175+ with open (output_file , "w" , encoding = "utf-8" ) as file :
154176 file .write (combined_html )
155177
156178
157179def _compute_net (
158180 mapper_plot ,
159- colors ,
160- node_size ,
161- agg ,
162- width ,
163- height ,
164- cmap ,
165- ):
181+ colors : NDArray [ np . float64 ] ,
182+ node_size : float ,
183+ agg : Callable ,
184+ width : int ,
185+ height : int ,
186+ cmap : str ,
187+ ) -> tuple [ Network , float , float ] :
166188 net = Network (
167189 height = f"{ height } px" ,
168190 width = f"{ width } px" ,
0 commit comments