|
1 | 1 | """ pyplots.ai |
2 | 2 | dendrogram-basic: Basic Dendrogram |
3 | | -Library: bokeh 3.8.1 | Python 3.13.11 |
4 | | -Quality: 91/100 | Created: 2025-12-23 |
| 3 | +Library: bokeh 3.8.2 | Python 3.14.3 |
| 4 | +Quality: 90/100 | Updated: 2026-04-05 |
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | from bokeh.io import export_png |
9 | | -from bokeh.models import Label |
| 9 | +from bokeh.models import ColumnDataSource, FixedTicker, HoverTool, Label, Span |
10 | 10 | from bokeh.plotting import figure, output_file, save |
11 | 11 | from scipy.cluster.hierarchy import leaves_list, linkage |
12 | 12 |
|
13 | 13 |
|
14 | 14 | # Data - Iris flower measurements (4 features for 15 samples) |
15 | 15 | np.random.seed(42) |
16 | 16 |
|
17 | | -# Simulate iris-like measurements: sepal length, sepal width, petal length, petal width |
18 | | -# Three species with distinct characteristics |
19 | 17 | samples_per_species = 5 |
20 | 18 |
|
21 | 19 | labels = [] |
|
26 | 24 | labels.append(f"Setosa-{i + 1}") |
27 | 25 | data.append( |
28 | 26 | [ |
29 | | - 5.0 + np.random.randn() * 0.3, # sepal length |
30 | | - 3.4 + np.random.randn() * 0.3, # sepal width |
31 | | - 1.5 + np.random.randn() * 0.2, # petal length |
32 | | - 0.3 + np.random.randn() * 0.1, # petal width |
| 27 | + 5.0 + np.random.randn() * 0.3, |
| 28 | + 3.4 + np.random.randn() * 0.3, |
| 29 | + 1.5 + np.random.randn() * 0.2, |
| 30 | + 0.3 + np.random.randn() * 0.1, |
33 | 31 | ] |
34 | 32 | ) |
35 | 33 |
|
|
38 | 36 | labels.append(f"Versicolor-{i + 1}") |
39 | 37 | data.append( |
40 | 38 | [ |
41 | | - 5.9 + np.random.randn() * 0.4, # sepal length |
42 | | - 2.8 + np.random.randn() * 0.3, # sepal width |
43 | | - 4.3 + np.random.randn() * 0.4, # petal length |
44 | | - 1.3 + np.random.randn() * 0.2, # petal width |
| 39 | + 5.9 + np.random.randn() * 0.4, |
| 40 | + 2.8 + np.random.randn() * 0.3, |
| 41 | + 4.3 + np.random.randn() * 0.4, |
| 42 | + 1.3 + np.random.randn() * 0.2, |
45 | 43 | ] |
46 | 44 | ) |
47 | 45 |
|
|
50 | 48 | labels.append(f"Virginica-{i + 1}") |
51 | 49 | data.append( |
52 | 50 | [ |
53 | | - 6.6 + np.random.randn() * 0.5, # sepal length |
54 | | - 3.0 + np.random.randn() * 0.3, # sepal width |
55 | | - 5.5 + np.random.randn() * 0.5, # petal length |
56 | | - 2.0 + np.random.randn() * 0.3, # petal width |
| 51 | + 6.6 + np.random.randn() * 0.5, |
| 52 | + 3.0 + np.random.randn() * 0.3, |
| 53 | + 5.5 + np.random.randn() * 0.5, |
| 54 | + 2.0 + np.random.randn() * 0.3, |
57 | 55 | ] |
58 | 56 | ) |
59 | 57 |
|
|
68 | 66 | ordered_labels = [labels[i] for i in leaf_order] |
69 | 67 |
|
70 | 68 | # Build dendrogram structure manually |
71 | | -# Position of each node (leaf nodes get integer positions) |
72 | 69 | node_positions = {} |
73 | 70 | for idx, leaf_idx in enumerate(leaf_order): |
74 | 71 | node_positions[leaf_idx] = idx |
75 | 72 |
|
| 73 | +# Track cluster members for hover info |
| 74 | +cluster_members = {} |
| 75 | +for i in range(n_samples): |
| 76 | + cluster_members[i] = [labels[i]] |
| 77 | + |
76 | 78 | # Color threshold for distinguishing clusters |
77 | 79 | max_dist = linkage_matrix[:, 2].max() |
78 | 80 | color_threshold = 0.7 * max_dist |
79 | 81 |
|
80 | | -# Collect line segments for drawing |
81 | | -line_xs = [] |
82 | | -line_ys = [] |
83 | | -line_colors = [] |
| 82 | +# Colorblind-safe palette |
| 83 | +colors_within = "#0F7B6C" # teal for within-cluster |
| 84 | +colors_between = "#C0392B" # warm red for between-cluster (cross-species merges) |
| 85 | + |
| 86 | +# Collect line segments with hover metadata |
| 87 | +all_xs, all_ys = [], [] |
| 88 | +all_colors = [] |
| 89 | +all_distances = [] |
| 90 | +all_left_items = [] |
| 91 | +all_right_items = [] |
| 92 | +all_cluster_sizes = [] |
84 | 93 |
|
85 | | -# Process each merge in the linkage matrix |
86 | | -for i, (left, right, dist, _) in enumerate(linkage_matrix): |
| 94 | +for i, (left, right, dist, count) in enumerate(linkage_matrix): |
87 | 95 | left, right = int(left), int(right) |
88 | 96 | new_node = n_samples + i |
89 | 97 |
|
90 | | - # Get x positions of children |
91 | 98 | left_x = node_positions[left] |
92 | 99 | right_x = node_positions[right] |
| 100 | + left_y = 0 if left < n_samples else linkage_matrix[left - n_samples, 2] |
| 101 | + right_y = 0 if right < n_samples else linkage_matrix[right - n_samples, 2] |
93 | 102 |
|
94 | | - # Get y positions (heights) of children |
95 | | - if left < n_samples: |
96 | | - left_y = 0 |
97 | | - else: |
98 | | - left_y = linkage_matrix[left - n_samples, 2] |
99 | | - |
100 | | - if right < n_samples: |
101 | | - right_y = 0 |
102 | | - else: |
103 | | - right_y = linkage_matrix[right - n_samples, 2] |
104 | | - |
105 | | - # New node position is midpoint of children |
106 | 103 | new_x = (left_x + right_x) / 2 |
107 | 104 | node_positions[new_node] = new_x |
108 | 105 |
|
109 | | - # Determine color based on threshold |
110 | | - color = "#306998" if dist > color_threshold else "#FFD43B" |
| 106 | + # Track members |
| 107 | + left_members = cluster_members[left] |
| 108 | + right_members = cluster_members[right] |
| 109 | + cluster_members[new_node] = left_members + right_members |
| 110 | + |
| 111 | + # U-shaped connector: left vertical, horizontal, right vertical |
| 112 | + xs = [left_x, left_x, right_x, right_x] |
| 113 | + ys = [left_y, dist, dist, right_y] |
111 | 114 |
|
112 | | - # Draw left vertical line |
113 | | - line_xs.append([left_x, left_x]) |
114 | | - line_ys.append([left_y, dist]) |
115 | | - line_colors.append(color) |
| 115 | + color = colors_between if dist > color_threshold else colors_within |
116 | 116 |
|
117 | | - # Draw right vertical line |
118 | | - line_xs.append([right_x, right_x]) |
119 | | - line_ys.append([right_y, dist]) |
120 | | - line_colors.append(color) |
| 117 | + all_xs.append(xs) |
| 118 | + all_ys.append(ys) |
| 119 | + all_colors.append(color) |
| 120 | + all_distances.append(f"{dist:.2f}") |
| 121 | + all_left_items.append(", ".join(left_members[:3]) + ("..." if len(left_members) > 3 else "")) |
| 122 | + all_right_items.append(", ".join(right_members[:3]) + ("..." if len(right_members) > 3 else "")) |
| 123 | + all_cluster_sizes.append(str(int(count))) |
121 | 124 |
|
122 | | - # Draw horizontal line connecting the two |
123 | | - line_xs.append([left_x, right_x]) |
124 | | - line_ys.append([dist, dist]) |
125 | | - line_colors.append(color) |
| 125 | +# Apply sqrt scaling to y-axis for better visibility of lower merges |
| 126 | +sqrt_max = np.sqrt(max_dist) |
126 | 127 |
|
127 | | -# Create figure with extra space at bottom for labels |
| 128 | +all_ys_scaled = [] |
| 129 | +for ys in all_ys: |
| 130 | + all_ys_scaled.append([np.sqrt(y) for y in ys]) |
| 131 | + |
| 132 | +# Plot |
128 | 133 | p = figure( |
129 | 134 | width=4800, |
130 | 135 | height=2700, |
131 | | - title="dendrogram-basic · bokeh · pyplots.ai", |
132 | | - x_axis_label="Sample", |
133 | | - y_axis_label="Distance (Ward)", |
134 | | - x_range=(-0.5, n_samples - 0.5), |
135 | | - y_range=(-max_dist * 0.18, max_dist * 1.1), |
| 136 | + title="dendrogram-basic \u00b7 bokeh \u00b7 pyplots.ai", |
| 137 | + x_axis_label="Iris Sample", |
| 138 | + y_axis_label="Distance (Ward\u2019s Method, \u221a scale)", |
| 139 | + x_range=(-0.8, n_samples - 0.2), |
| 140 | + y_range=(-sqrt_max * 0.02, sqrt_max * 1.12), |
136 | 141 | toolbar_location=None, |
| 142 | + min_border_bottom=220, |
137 | 143 | ) |
138 | 144 |
|
139 | | -# Draw dendrogram lines with thicker lines for visibility |
140 | | -for xs, ys, color in zip(line_xs, line_ys, line_colors, strict=True): |
141 | | - p.line(xs, ys, line_width=4, line_color=color) |
142 | | - |
143 | | -# Add leaf labels with larger font |
144 | | -for idx, label in enumerate(ordered_labels): |
145 | | - label_obj = Label( |
146 | | - x=idx, |
147 | | - y=-max_dist * 0.02, |
148 | | - text=label, |
149 | | - text_font_size="20pt", |
150 | | - text_align="right", |
151 | | - angle=0.785, # 45 degrees in radians |
152 | | - angle_units="rad", |
153 | | - y_offset=-15, |
154 | | - ) |
155 | | - p.add_layout(label_obj) |
| 145 | +# Draw dendrogram branches using multi_line with ColumnDataSource and hover data |
| 146 | +source = ColumnDataSource( |
| 147 | + data={ |
| 148 | + "xs": all_xs, |
| 149 | + "ys": all_ys_scaled, |
| 150 | + "color": all_colors, |
| 151 | + "distance": all_distances, |
| 152 | + "left_cluster": all_left_items, |
| 153 | + "right_cluster": all_right_items, |
| 154 | + "cluster_size": all_cluster_sizes, |
| 155 | + } |
| 156 | +) |
| 157 | + |
| 158 | +branch_renderer = p.multi_line( |
| 159 | + xs="xs", |
| 160 | + ys="ys", |
| 161 | + source=source, |
| 162 | + line_width=4, |
| 163 | + line_color="color", |
| 164 | + line_alpha=0.85, |
| 165 | + hover_line_width=7, |
| 166 | + hover_line_alpha=1.0, |
| 167 | + hover_line_color="#E74C3C", |
| 168 | +) |
156 | 169 |
|
157 | | -# Style - larger fonts for 4800x2700 canvas |
158 | | -p.title.text_font_size = "32pt" |
| 170 | +# Add HoverTool for interactive branch inspection |
| 171 | +hover = HoverTool( |
| 172 | + renderers=[branch_renderer], |
| 173 | + tooltips=[ |
| 174 | + ("Merge Distance", "@distance"), |
| 175 | + ("Cluster Size", "@cluster_size items"), |
| 176 | + ("Left", "@left_cluster"), |
| 177 | + ("Right", "@right_cluster"), |
| 178 | + ], |
| 179 | + line_policy="interp", |
| 180 | +) |
| 181 | +p.add_tools(hover) |
| 182 | + |
| 183 | +# Cluster threshold line for visual storytelling |
| 184 | +threshold_y_scaled = np.sqrt(color_threshold) |
| 185 | +threshold_line = Span( |
| 186 | + location=threshold_y_scaled, |
| 187 | + dimension="width", |
| 188 | + line_color="#999999", |
| 189 | + line_dash="dashed", |
| 190 | + line_width=2, |
| 191 | + line_alpha=0.5, |
| 192 | +) |
| 193 | +p.add_layout(threshold_line) |
| 194 | + |
| 195 | +threshold_label = Label( |
| 196 | + x=n_samples - 1.2, |
| 197 | + y=threshold_y_scaled, |
| 198 | + text="cluster threshold", |
| 199 | + text_font_size="16pt", |
| 200 | + text_color="#888888", |
| 201 | + text_font_style="italic", |
| 202 | + y_offset=8, |
| 203 | + text_align="right", |
| 204 | +) |
| 205 | +p.add_layout(threshold_label) |
| 206 | + |
| 207 | +# Legend entries via off-screen line glyphs for colored swatches |
| 208 | +p.line([-99, -98], [-99, -99], line_color=colors_within, line_width=6, legend_label="Within-cluster") |
| 209 | +p.line([-99, -98], [-99, -99], line_color=colors_between, line_width=6, legend_label="Between-cluster") |
| 210 | + |
| 211 | +# Leaf labels as x-axis tick labels (renders outside plot frame, no clipping) |
| 212 | +p.xaxis.ticker = FixedTicker(ticks=list(range(n_samples))) |
| 213 | +p.xaxis.major_label_overrides = {i: ordered_labels[i] for i in range(n_samples)} |
| 214 | +p.xaxis.major_label_orientation = 0.785 # 45 degrees in radians |
| 215 | + |
| 216 | +# Style |
| 217 | +p.title.text_font_size = "30pt" |
| 218 | +p.title.text_font_style = "normal" |
| 219 | +p.title.text_color = "#333333" |
159 | 220 | p.xaxis.axis_label_text_font_size = "24pt" |
160 | 221 | p.yaxis.axis_label_text_font_size = "24pt" |
161 | | -p.xaxis.major_label_text_font_size = "0pt" # Hide default x-axis labels |
| 222 | +p.xaxis.axis_label_text_color = "#555555" |
| 223 | +p.yaxis.axis_label_text_color = "#555555" |
| 224 | +p.xaxis.major_label_text_font_size = "18pt" |
| 225 | +p.xaxis.major_label_text_color = "#444444" |
162 | 226 | p.yaxis.major_label_text_font_size = "20pt" |
| 227 | +p.yaxis.major_label_text_color = "#666666" |
163 | 228 |
|
164 | | -# Grid styling |
| 229 | +p.background_fill_color = "#FAFAFA" |
| 230 | +p.border_fill_color = "white" |
165 | 231 | p.xgrid.visible = False |
166 | | -p.ygrid.grid_line_alpha = 0.3 |
167 | | -p.ygrid.grid_line_dash = "dashed" |
| 232 | +p.ygrid.grid_line_alpha = 0.12 |
| 233 | +p.ygrid.grid_line_dash = [4, 4] |
| 234 | +p.ygrid.grid_line_color = "#AAAAAA" |
168 | 235 |
|
169 | | -# Remove tick marks on x-axis |
| 236 | +p.xaxis.axis_line_color = "#CCCCCC" |
| 237 | +p.yaxis.axis_line_color = "#CCCCCC" |
170 | 238 | p.xaxis.major_tick_line_color = None |
171 | 239 | p.xaxis.minor_tick_line_color = None |
172 | | - |
173 | | -# Clean outline |
| 240 | +p.yaxis.major_tick_line_color = "#CCCCCC" |
| 241 | +p.yaxis.minor_tick_line_color = None |
174 | 242 | p.outline_line_color = None |
175 | 243 |
|
176 | | -# Save outputs |
| 244 | +# Legend |
| 245 | +p.legend.location = "top_left" |
| 246 | +p.legend.label_text_font_size = "22pt" |
| 247 | +p.legend.label_text_color = "#333333" |
| 248 | +p.legend.glyph_width = 50 |
| 249 | +p.legend.glyph_height = 8 |
| 250 | +p.legend.spacing = 12 |
| 251 | +p.legend.padding = 20 |
| 252 | +p.legend.margin = 15 |
| 253 | +p.legend.background_fill_alpha = 0.92 |
| 254 | +p.legend.background_fill_color = "#FAFAFA" |
| 255 | +p.legend.border_line_color = "#CCCCCC" |
| 256 | +p.legend.border_line_alpha = 0.6 |
| 257 | + |
| 258 | +# Save |
177 | 259 | export_png(p, filename="plot.png") |
178 | 260 | output_file("plot.html") |
179 | 261 | save(p) |
0 commit comments