@@ -60,7 +60,9 @@ def __init__(self):
6060 self .menu_color = None
6161 self .slider_size = None
6262
63- def set_menu_cmap (self , mapper_plot , cmaps ):
63+ def set_menu_cmap (self , mapper_plot , cmaps : Optional [List [str ]]) -> None :
64+ if cmaps is None :
65+ return
6466 cmaps_plotly = [PLOTLY_CMAPS .get (c .lower ()) for c in cmaps ]
6567 self .menu_cmap = _ui_cmap (mapper_plot , cmaps_plotly )
6668
@@ -71,8 +73,10 @@ def set_slider_size(self, mapper_plot, node_sizes):
7173 self .slider_size = _ui_node_size (mapper_plot , node_sizes )
7274
7375
74- def _to_cmaps (cmap : Union [str , List [str ]]) -> List [str ]:
76+ def _to_cmaps (cmap : Optional [ Union [str , List [str ] ]]) -> List [str ]:
7577 """Convert a single cmap or a list of cmaps to a list of cmaps."""
78+ if cmap is None :
79+ return [DEFAULT_CMAP ]
7680 if isinstance (cmap , str ):
7781 return [cmap ]
7882 elif isinstance (cmap , list ):
@@ -94,11 +98,11 @@ def _to_colors(colors: Union[np.ndarray, List[float]]) -> np.ndarray:
9498 )
9599
96100
97- def _to_titles (title , colors_num ) :
101+ def _to_titles (title : Optional [ Union [ str , List [ str ]]], colors_num : int ) -> List [ str ] :
98102 if title is None :
99- return [DEFAULT_TITLE for _ in range (colors_num )]
103+ return [f" { i } " for i in range (colors_num )]
100104 elif isinstance (title , str ):
101- return [title for _ in range (colors_num )]
105+ return [f" { title } { i } " for i in range (colors_num )]
102106 elif isinstance (title , list ) and len (title ) == colors_num :
103107 return title
104108 else :
@@ -107,7 +111,9 @@ def _to_titles(title, colors_num):
107111 )
108112
109113
110- def _to_node_sizes (node_size ):
114+ def _to_node_sizes (
115+ node_size : Optional [Union [int , float , List [Union [int , float ]]]]
116+ ) -> List [float ]:
111117 if isinstance (node_size , (int , float )):
112118 return [node_size ]
113119 elif isinstance (node_size , list ):
@@ -123,11 +129,11 @@ def plot_plotly(
123129 mapper_plot ,
124130 width : int ,
125131 height : int ,
126- node_size : Optional [ Union [int , float , List [Union [ int , float ]]]] = DEFAULT_NODE_SIZE ,
127- colors = None ,
132+ colors : Union [np . ndarray , List [float ]],
133+ node_size : Optional [ Union [ int , float , List [ Union [ int , float ]]]] = None ,
128134 title : Optional [Union [str , List [str ]]] = None ,
129135 agg = np .nanmean ,
130- cmap : Union [str , List [str ]] = DEFAULT_CMAP ,
136+ cmap : Optional [ Union [str , List [str ]]] = None ,
131137) -> go .Figure :
132138 cmaps = _to_cmaps (cmap )
133139 colors = _to_colors (colors )
@@ -187,7 +193,7 @@ def plot_plotly_update(
187193 return fig
188194
189195
190- def _node_pos_array (graph , dim , node_pos ):
196+ def _node_pos_array (graph : nx . Graph , dim : int , node_pos ):
191197 return tuple ([node_pos [n ][i ] for n in graph .nodes ()] for i in range (dim ))
192198
193199
@@ -202,7 +208,7 @@ def _edge_pos_array(graph, dim, node_pos):
202208 return edges_arr
203209
204210
205- def _marker_size (mapper_plot , node_size ) :
211+ def _marker_size (mapper_plot , node_size : float ) -> List [ float ] :
206212 attr_size = nx .get_node_attributes (mapper_plot .graph , ATTR_SIZE )
207213 max_size = max (attr_size .values (), default = 1.0 )
208214 scale = node_size * (25.0 if mapper_plot .dim == 2 else 15.0 )
@@ -212,14 +218,14 @@ def _marker_size(mapper_plot, node_size):
212218 return marker_size
213219
214220
215- def _get_cmap_rgb (cmap ):
221+ def _get_cmap_rgb (cmap : str ):
216222 """Return a colorscale in [[float, 'rgb(r,g,b)']] format."""
217223 base_scale = pc .get_colorscale (cmap )
218224 # If it's already in [float, color] format, we're good
219225 return [[pos , color ] for pos , color in base_scale ]
220226
221227
222- def _set_cmap (mapper_plot , fig , cmap ) :
228+ def _set_cmap (mapper_plot , fig : go . Figure , cmap : str ) -> None :
223229 cmap_rgb = _get_cmap_rgb (cmap )
224230 fig .update_traces (
225231 patch = dict (
@@ -244,7 +250,7 @@ def _set_cmap(mapper_plot, fig, cmap):
244250 )
245251
246252
247- def _set_colors (mapper_plot , fig , colors , agg ):
253+ def _set_colors (mapper_plot , fig : go . Figure , colors , agg ):
248254 node_col = aggregate_graph (colors , mapper_plot .graph , agg )
249255 scatter_text = _text (mapper_plot , node_col )
250256 colors_arr = list (node_col .values ())
@@ -278,7 +284,7 @@ def _set_colors(mapper_plot, fig, colors, agg):
278284 )
279285
280286
281- def _set_title (mapper_plot , fig , color_name ):
287+ def _set_title (mapper_plot , fig : go . Figure , color_name : str ):
282288 fig .update_traces (
283289 patch = dict (
284290 marker_colorbar = _colorbar (mapper_plot , color_name ),
@@ -287,7 +293,7 @@ def _set_title(mapper_plot, fig, color_name):
287293 )
288294
289295
290- def _set_node_size (mapper_plot , fig , node_size ) :
296+ def _set_node_size (mapper_plot , fig : go . Figure , node_size : float ) -> None :
291297 fig .update_traces (
292298 patch = dict (
293299 marker_size = _marker_size (mapper_plot , node_size ),
@@ -296,19 +302,28 @@ def _set_node_size(mapper_plot, fig, node_size):
296302 )
297303
298304
299- def _set_width (fig , width ) :
305+ def _set_width (fig : go . Figure , width : int ) -> None :
300306 fig .update_layout (
301307 width = width ,
302308 )
303309
304310
305- def _set_height (fig , height ) :
311+ def _set_height (fig : go . Figure , height : int ) -> None :
306312 fig .update_layout (
307313 height = height ,
308314 )
309315
310316
311- def _figure (mapper_plot , width , height , node_sizes , colors , titles , agg , cmaps ):
317+ def _figure (
318+ mapper_plot ,
319+ width : int ,
320+ height : int ,
321+ node_sizes : List [float ],
322+ colors : np .ndarray ,
323+ titles : List [str ],
324+ agg ,
325+ cmaps : List [str ],
326+ ) -> go .Figure :
312327 node_pos = mapper_plot .positions
313328 node_pos_arr = _node_pos_array (
314329 mapper_plot .graph ,
@@ -346,7 +361,7 @@ def _update(
346361 width : Optional [int ] = None ,
347362 height : Optional [int ] = None ,
348363 titles : Optional [List [str ]] = None ,
349- node_sizes : Optional [List [int ]] = None ,
364+ node_sizes : Optional [List [float ]] = None ,
350365 colors = None ,
351366 agg = None ,
352367 cmaps : Optional [List [str ]] = None ,
@@ -422,7 +437,9 @@ def _edges_trace(mapper_plot, edge_pos_arr):
422437 return go .Scatter (scatter )
423438
424439
425- def _colorbar (mapper_plot , title ):
440+ def _colorbar (
441+ mapper_plot , title : str
442+ ) -> Union [go .scatter3d .marker .ColorBar , go .scatter .marker .ColorBar ]:
426443 cbar = dict (
427444 showticklabels = True ,
428445 outlinewidth = 1 ,
@@ -463,7 +480,7 @@ def _fmt(x, max_len=3):
463480 return f"{ x :{fmt }} "
464481
465482
466- def _layout ():
483+ def _layout () -> go . Layout :
467484 line_col = "rgba(230, 230, 230, 1.0)"
468485 axis = dict (
469486 showline = False ,
@@ -506,7 +523,7 @@ def _layout():
506523 )
507524
508525
509- def _set_ui (mapper_fig , plotly_ui : PlotlyUI ):
526+ def _set_ui (mapper_fig : go . Figure , plotly_ui : PlotlyUI ) -> None :
510527 menus = []
511528 sliders = []
512529 x = 0.0
@@ -526,10 +543,10 @@ def _set_ui(mapper_fig, plotly_ui: PlotlyUI):
526543 )
527544
528545
529- def _ui_cmap (mapper_plot , cmaps ) :
546+ def _ui_cmap (mapper_plot , cmaps : List [ str ]) -> dict :
530547 target_traces = [1 ] if mapper_plot .dim == 2 else [0 , 1 ]
531548
532- def _update_cmap (cmap ) :
549+ def _update_cmap (cmap : str ) -> dict :
533550 cmap_rgb = _get_cmap_rgb (cmap )
534551 if mapper_plot .dim == 2 :
535552 return {
@@ -542,6 +559,7 @@ def _update_cmap(cmap):
542559 "marker.line.colorscale" : [None , cmap_rgb ],
543560 "line.colorscale" : [cmap_rgb , None ],
544561 }
562+ return {}
545563
546564 buttons = []
547565 if len (cmaps ) > 1 :
@@ -564,7 +582,7 @@ def _update_cmap(cmap):
564582 )
565583
566584
567- def _ui_node_size (mapper_plot , node_sizes ) :
585+ def _ui_node_size (mapper_plot , node_sizes : List [ float ]) -> dict :
568586 steps = [
569587 dict (
570588 method = "restyle" ,
@@ -589,21 +607,21 @@ def _ui_node_size(mapper_plot, node_sizes):
589607 )
590608
591609
592- def _ui_color (mapper_plot , colors , titles , agg ):
610+ def _ui_color (mapper_plot , colors , titles : List [ str ] , agg ) -> dict :
593611 colors_arr = np .array (colors )
594612 colors_num = colors_arr .shape [1 ] if colors_arr .ndim == 2 else 1
595613
596- def _colors_agg (i ) :
614+ def _colors_agg (i : int ) -> dict :
597615 if i is None :
598616 arr = colors_arr
599617 else :
600618 arr = colors_arr [:, i ] if colors_arr .ndim == 2 else colors_arr
601619 return aggregate_graph (arr , mapper_plot .graph , agg )
602620
603- def _colors (i ) :
621+ def _colors (i : int ) -> List [ float ] :
604622 return list (_colors_agg (i ).values ())
605623
606- def _edge_colors (i ) :
624+ def _edge_colors (i : int ) -> List [ float ] :
607625 colors_avg = []
608626 colors_agg = _colors_agg (i )
609627 for edge in mapper_plot .graph .edges ():
@@ -613,7 +631,7 @@ def _edge_colors(i):
613631 colors_avg .append (c1 )
614632 return colors_avg
615633
616- def _update_colors (i ) :
634+ def _update_colors (i : int ) -> dict :
617635 arr_agg = _colors_agg (i )
618636 arr = list (arr_agg .values ())
619637 scatter_text = _text (mapper_plot , arr_agg )
0 commit comments