Skip to content

Commit 28a58bd

Browse files
committed
Fixed colorbar
1 parent ef0351d commit 28a58bd

2 files changed

Lines changed: 15 additions & 15 deletions

File tree

app/nicegui_app.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23
import plotly.graph_objs as go
34
from nicegui import run, ui
45
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
@@ -388,14 +389,17 @@ def update_plot(self):
388389
iterations=iterations,
389390
seed=42,
390391
)
392+
colors = pd.concat([self.labels, self.X], axis=1)
393+
colors_arr = colors.to_numpy()
394+
color_names = colors.columns.tolist()
391395
mapper_fig = mapper_plot.plot_plotly(
392-
colors=self.labels,
396+
colors=colors_arr,
393397
cmap=["jet", "viridis", "cividis"],
394398
agg=mode,
395-
title="mode of digits",
399+
title=color_names,
396400
width=800,
397401
height=800,
398-
node_size=0.5,
402+
node_size=list(0.125 * x for x in range(17)),
399403
)
400404
mapper_fig.layout.width = None
401405
mapper_fig.layout.autosize = True

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,10 @@ def _set_colors(mapper_plot, fig: go.Figure, colors, agg):
284284
)
285285

286286

287-
def _set_title(mapper_plot, fig: go.Figure, color_name: str):
287+
def _set_title(fig: go.Figure, color_name: str):
288288
fig.update_traces(
289289
patch=dict(
290-
marker_colorbar=_colorbar(mapper_plot, color_name),
290+
marker_colorbar=_colorbar(color_name),
291291
),
292292
selector=dict(name=_NODES_TRACE),
293293
)
@@ -371,7 +371,7 @@ def _update(
371371
if height is not None:
372372
_set_height(fig, height)
373373
if titles is not None:
374-
_set_title(mapper_plot, fig, titles[0])
374+
_set_title(fig, titles[0])
375375
if node_sizes is not None:
376376
_set_node_size(mapper_plot, fig, node_sizes[len(node_sizes) // 2])
377377
if (colors is not None) and (agg is not None):
@@ -398,7 +398,7 @@ def _nodes_trace(mapper_plot, node_pos_arr):
398398
line_color=_NODE_OUTER_COLOR,
399399
line_colorscale=DEFAULT_CMAP,
400400
colorscale=DEFAULT_CMAP,
401-
colorbar=_colorbar(mapper_plot, DEFAULT_TITLE),
401+
colorbar=_colorbar(DEFAULT_TITLE),
402402
),
403403
)
404404
if mapper_plot.dim == 3:
@@ -437,9 +437,7 @@ def _edges_trace(mapper_plot, edge_pos_arr):
437437
return go.Scatter(scatter)
438438

439439

440-
def _colorbar(
441-
mapper_plot, title: str
442-
) -> Union[go.scatter3d.marker.ColorBar, go.scatter.marker.ColorBar]:
440+
def _colorbar(title: str) -> dict:
443441
cbar = dict(
444442
showticklabels=True,
445443
outlinewidth=1,
@@ -458,10 +456,7 @@ def _colorbar(
458456
)
459457
if title is not None:
460458
cbar["title"] = title
461-
if mapper_plot.dim == 3:
462-
return go.scatter3d.marker.ColorBar(cbar)
463-
elif mapper_plot.dim == 2:
464-
return go.scatter.marker.ColorBar(cbar)
459+
return cbar
465460

466461

467462
def _text(mapper_plot, colors):
@@ -635,7 +630,7 @@ def _update_colors(i: int) -> dict:
635630
arr_agg = _colors_agg(i)
636631
arr = list(arr_agg.values())
637632
scatter_text = _text(mapper_plot, arr_agg)
638-
cbar = _colorbar(mapper_plot, titles[i])
633+
cbar = _colorbar(titles[i])
639634
if mapper_plot.dim == 2:
640635
return {
641636
"text": [scatter_text],
@@ -656,6 +651,7 @@ def _update_colors(i: int) -> dict:
656651
"line.cmax": [max(arr_edge, default=None), None],
657652
"line.cmin": [min(arr_edge, default=None), None],
658653
}
654+
return {}
659655

660656
target_traces = [1] if mapper_plot.dim == 2 else [0, 1]
661657

0 commit comments

Comments
 (0)