|
1 | 1 | """ pyplots.ai |
2 | 2 | dendrogram-basic: Basic Dendrogram |
3 | | -Library: altair 6.0.0 | Python 3.13.11 |
4 | | -Quality: 93/100 | Created: 2025-12-23 |
| 3 | +Library: altair 6.0.0 | Python 3.14.3 |
| 4 | +Quality: 88/100 | Updated: 2026-04-05 |
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import altair as alt |
8 | | -import numpy as np |
9 | 8 | import pandas as pd |
10 | | -from scipy.cluster.hierarchy import dendrogram, linkage |
11 | | - |
12 | | - |
13 | | -# Data - Iris flower measurements (4 features for 15 samples) |
14 | | -np.random.seed(42) |
15 | | - |
16 | | -# Simulate iris-like measurements: sepal length, sepal width, petal length, petal width |
17 | | -# Three species with distinct characteristics |
18 | | -samples_per_species = 5 |
19 | | -labels = [] |
20 | | -data = [] |
21 | | - |
22 | | -# Setosa: shorter petals, wider sepals |
23 | | -for i in range(samples_per_species): |
24 | | - labels.append(f"Setosa-{i + 1}") |
25 | | - data.append( |
26 | | - [ |
27 | | - 5.0 + np.random.randn() * 0.3, # sepal length |
28 | | - 3.4 + np.random.randn() * 0.3, # sepal width |
29 | | - 1.5 + np.random.randn() * 0.2, # petal length |
30 | | - 0.3 + np.random.randn() * 0.1, # petal width |
31 | | - ] |
32 | | - ) |
| 9 | +from scipy.cluster.hierarchy import dendrogram, fcluster, linkage |
| 10 | +from sklearn.datasets import load_iris |
33 | 11 |
|
34 | | -# Versicolor: medium measurements |
35 | | -for i in range(samples_per_species): |
36 | | - labels.append(f"Versicolor-{i + 1}") |
37 | | - data.append( |
38 | | - [ |
39 | | - 5.9 + np.random.randn() * 0.4, |
40 | | - 2.8 + np.random.randn() * 0.3, |
41 | | - 4.3 + np.random.randn() * 0.4, |
42 | | - 1.3 + np.random.randn() * 0.2, |
43 | | - ] |
44 | | - ) |
45 | 12 |
|
46 | | -# Virginica: longer petals and sepals |
47 | | -for i in range(samples_per_species): |
48 | | - labels.append(f"Virginica-{i + 1}") |
49 | | - data.append( |
50 | | - [ |
51 | | - 6.6 + np.random.randn() * 0.5, |
52 | | - 3.0 + np.random.randn() * 0.3, |
53 | | - 5.5 + np.random.randn() * 0.5, |
54 | | - 2.0 + np.random.randn() * 0.3, |
55 | | - ] |
# Data - Iris flower measurements (15 samples, 3 species); every tenth row of the dataset is kept
iris = load_iris()
indices = list(range(0, 150, 10))
features = iris.data[indices]
species_names = ["Setosa", "Versicolor", "Virginica"]
# Each label is "<species>-<row index in the full dataset>"
labels = ["{}-{}".format(species_names[iris.target[row]], row) for row in indices]

# Hierarchical clustering with Ward's method; no_plot=True only returns the layout data
Z = linkage(features, method="ward")
dendro = dendrogram(Z, labels=labels, no_plot=True)

# Flat clusters from cutting the tree at a fixed distance, plus one color per cluster id
distance_threshold = 5.0
cluster_ids = fcluster(Z, t=distance_threshold, criterion="distance")
cluster_colors = dict(zip((1, 2, 3), ("#306998", "#D4A017", "#7B68AE"), strict=True))
| 28 | + |
# Map every tree node to a color. Node ids follow the linkage-matrix convention:
# 0 .. n_leaves - 1 are the original samples, and n_leaves + i is the cluster
# created by row i of Z.
mixed_color = "#888888"  # neutral grey for anything that is not a single flat cluster
n_leaves = len(labels)
node_colors = {leaf: cluster_colors.get(cluster_ids[leaf], mixed_color) for leaf in dendro["leaves"]}

# A merged node keeps a cluster color only when both of its children share it.
# Rows of Z only reference nodes created by earlier rows, so one forward pass is enough.
for i, row in enumerate(Z):
    left_c = node_colors.get(int(row[0]), mixed_color)
    right_c = node_colors.get(int(row[1]), mixed_color)
    node_colors[n_leaves + i] = left_c if left_c == right_c else mixed_color

# scipy appends each dendrogram link after it has laid out the link's left and
# right subtrees (post-order), so dendro["icoord"][k] does NOT describe Z[k].
# Rebuild that traversal to recover the linkage row behind every drawn link.
# This relies on dendrogram() being called with its default count_sort /
# distance_sort settings, which keep Z[i, 0] on the left and Z[i, 1] on the right.
link_rows = []
pending = [(n_leaves + len(Z) - 1, False)]  # start at the root, i.e. the node made by the last row
while pending:
    node, children_done = pending.pop()
    if node < n_leaves:
        continue  # leaves do not produce a link
    if children_done:
        link_rows.append(node - n_leaves)
        continue
    pending.append((node, True))
    # The stack is LIFO: push the right child first so the left subtree is visited first
    pending.append((int(Z[node - n_leaves, 1]), False))
    pending.append((int(Z[node - n_leaves, 0]), False))

# Extract line segments with cluster-based coloring
segments = []
for row_idx, xpts, ypts in zip(link_rows, dendro["icoord"], dendro["dcoord"], strict=True):
    merge_height = float(Z[row_idx, 2])
    # Guard the ordering assumption above: the bar of the link must sit at the row's distance
    if abs(ypts[1] - merge_height) > 1e-9:
        raise ValueError("dendrogram link order does not match the linkage traversal")

    distance = round(merge_height, 2)
    piece_colors = (
        node_colors[int(Z[row_idx, 0])],  # left vertical: color of the left child
        node_colors[n_leaves + row_idx],  # horizontal bar: color of the merged node
        node_colors[int(Z[row_idx, 1])],  # right vertical: color of the right child
    )
    # The four (x, y) points of a link form a U shape; consecutive points give its three pieces
    for start, piece_color in enumerate(piece_colors):
        segments.append(
            {
                "x": xpts[start],
                "y": ypts[start],
                "x2": xpts[start + 1],
                "y2": ypts[start + 1],
                "color": piece_color,
                "distance": distance,
            }
        )
57 | 64 |
|
58 | | -data = np.array(data) |
59 | | -n_samples = len(labels) |
segments_df = pd.DataFrame(segments)

# Leaf label positions from scipy (positioned at 5, 15, 25, ...)
leaf_labels = dendro["ivl"]
leaf_count = len(leaf_labels)
leaf_columns = {
    "x": [10 * slot + 5 for slot in range(leaf_count)],
    "y_base": [0.0 for _ in range(leaf_count)],
    "label": leaf_labels,
    # The species is everything before the last "-" of the label
    "species": [name.rsplit("-", maxsplit=1)[0] for name in leaf_labels],
}
leaf_df = pd.DataFrame(leaf_columns)
60 | 77 |
|
61 | | -# Compute hierarchical clustering using Ward's method |
62 | | -Z = linkage(data, method="ward") |
# Species color palette; the first entry is Python Blue
species_palette = dict(Setosa="#306998", Versicolor="#D4A017", Virginica="#7B68AE")
63 | 80 |
|
64 | | -# Use scipy's dendrogram function to get proper leaf ordering and coordinates |
65 | | -dendro = dendrogram(Z, labels=labels, no_plot=True) |
# Axis domain: 8 units of horizontal padding per side, 15% headroom above the tallest merge
x_padding = 8
x_min = min(seg["x"] for seg in segments) - x_padding
x_max = max(seg["x2"] for seg in segments) + x_padding
y_max = Z[:, 2].max() * 1.15
66 | 85 |
|
67 | | -# Extract coordinates from scipy's dendrogram output (icoord, dcoord) |
68 | | -# Each cluster merge has 4 x-coords and 4 y-coords forming a U-shape |
69 | | -lines_data = [] |
70 | | -color_threshold = 0.7 * Z[:, 2].max() |
71 | | - |
72 | | -for xpts, ypts in zip(dendro["icoord"], dendro["dcoord"], strict=True): |
73 | | - # Each U-shape has 3 segments: left vertical, horizontal, right vertical |
74 | | - # Points: (x0,y0) - (x1,y1) - (x2,y2) - (x3,y3) |
75 | | - max_height = max(ypts) |
76 | | - color = "#306998" if max_height > color_threshold else "#FFD43B" |
77 | | - |
78 | | - # Left vertical segment |
79 | | - lines_data.append({"x": xpts[0], "y": ypts[0], "x2": xpts[1], "y2": ypts[1], "color": color}) |
80 | | - # Horizontal segment |
81 | | - lines_data.append({"x": xpts[1], "y": ypts[1], "x2": xpts[2], "y2": ypts[2], "color": color}) |
82 | | - # Right vertical segment |
83 | | - lines_data.append({"x": xpts[2], "y": ypts[2], "x2": xpts[3], "y2": ypts[3], "color": color}) |
84 | | - |
85 | | -lines_df = pd.DataFrame(lines_data) |
86 | | - |
87 | | -# Create label data for x-axis using scipy's leaf positions |
88 | | -ivl = dendro["ivl"] # Ordered labels from dendrogram |
89 | | -# Labels are positioned at 5, 15, 25, ... (5 + 10*i) |
90 | | -label_positions = [5 + 10 * i for i in range(len(ivl))] |
91 | | -label_data = pd.DataFrame({"x": label_positions, "label": ivl}) |
92 | | - |
93 | | -# Get x-axis domain from the dendrogram coordinates |
94 | | -x_min = min(min(xpts) for xpts in dendro["icoord"]) - 5 |
95 | | -x_max = max(max(xpts) for xpts in dendro["icoord"]) + 5 |
96 | | - |
97 | | -# Create the dendrogram lines chart |
98 | | -dendrogram_lines = ( |
99 | | - alt.Chart(lines_df) |
# Annotation for the final merge (top of tree) — key storytelling element.
# The last icoord entry is used as the top link; the annotation sits at the middle of its bar.
top_link_xs = dendro["icoord"][-1]
top_merge_y = Z[-1, 2]
top_merge_x = (top_link_xs[1] + top_link_xs[2]) / 2
annotation_df = pd.DataFrame(
    {
        "x": [top_merge_x],
        "y": [top_merge_y],
        "text": ["Setosa diverges\nfrom Versicolor + Virginica"],
    }
)

# Interactive selection: click legend to highlight a species
species_selection = alt.selection_point(fields=["species"], bind="legend")
| 95 | + |
# Dendrogram branches: one rule per segment, drawn in the segment's own color
# (scale=None makes Altair use the hex values as-is) with the merge distance as tooltip
branch_x = alt.X("x:Q", scale=alt.Scale(domain=[x_min, x_max]), axis=None)
branch_y = alt.Y("y:Q", title="Distance (Ward's method)", scale=alt.Scale(domain=[0, y_max]))
branch_tooltip = [alt.Tooltip("distance:Q", title="Merge Distance", format=".2f")]
branches = (
    alt.Chart(segments_df)
    .mark_rule(strokeWidth=3)
    .encode(
        x=branch_x,
        x2="x2:Q",
        y=branch_y,
        y2="y2:Q",
        color=alt.Color("color:N", scale=None),
        tooltip=branch_tooltip,
    )
)
| 109 | + |
# Leaf markers at the base of the dendrogram, colored by species.
# This layer owns the species legend and registers the legend-bound selection.
species_legend = alt.Legend(
    title="Species",
    titleFontSize=18,
    titleFontWeight="bold",
    labelFontSize=16,
    symbolSize=220,
    orient="right",
    offset=10,
    titleColor="#333333",
    labelColor="#444444",
)
dot_color = alt.Color(
    "species:N",
    scale=alt.Scale(domain=list(species_palette), range=list(species_palette.values())),
    legend=species_legend,
)
dot_tooltip = [alt.Tooltip("label:N", title="Sample"), alt.Tooltip("species:N", title="Species")]
leaf_dots = (
    alt.Chart(leaf_df)
    .mark_point(size=180, filled=True, strokeWidth=1.5, stroke="white")
    .encode(
        x=alt.X("x:Q", scale=alt.Scale(domain=[x_min, x_max]), axis=None),
        y=alt.Y("y_base:Q", scale=alt.Scale(domain=[0, y_max])),
        color=dot_color,
        tooltip=dot_tooltip,
        # Species not picked in the legend are faded out
        opacity=alt.condition(species_selection, alt.value(1.0), alt.value(0.15)),
    )
    .add_params(species_selection)
)
| 137 | + |
# Leaf labels colored by species; the legend is suppressed on this layer
label_color = alt.Color(
    "species:N",
    scale=alt.Scale(domain=list(species_palette), range=list(species_palette.values())),
    legend=None,
)
leaf_text = (
    alt.Chart(leaf_df)
    .mark_text(angle=315, align="right", baseline="top", fontSize=16, fontWeight="bold", dx=-4, dy=4)
    .encode(
        x=alt.X("x:Q", scale=alt.Scale(domain=[x_min, x_max]), axis=None),
        y=alt.value(870),  # fixed pixel offset from the top of the view, not a data value
        text="label:N",
        color=label_color,
        # Same fade-out rule as the leaf markers
        opacity=alt.condition(species_selection, alt.value(1.0), alt.value(0.15)),
    )
)
109 | 154 |
|
110 | | -# Create x-axis labels at bottom |
111 | | -x_labels = ( |
112 | | - alt.Chart(label_data) |
113 | | - .mark_text(angle=315, align="right", baseline="top", fontSize=14) |
114 | | - .encode(x=alt.X("x:Q", axis=None, scale=alt.Scale(domain=[x_min, x_max])), y=alt.value(850), text="label:N") |
# Distance threshold reference line: a dashed horizontal rule at the cut height
threshold_df = pd.DataFrame({"y": [distance_threshold]})
threshold_rule_style = {"strokeDash": [8, 6], "strokeWidth": 1.8, "color": "#CC4444", "opacity": 0.7}
threshold_line = alt.Chart(threshold_df).mark_rule(**threshold_rule_style).encode(y="y:Q")
116 | 160 |
|
117 | | -# Combine charts |
# Caption for the threshold line, pinned 10 px from the left edge of the view.
# The number shown is read from distance_threshold so the caption cannot drift
# from the cut that was actually applied.
threshold_label = (
    alt.Chart(threshold_df)
    .mark_text(align="left", baseline="bottom", fontSize=14, color="#CC4444", fontStyle="italic", dx=5, dy=-5)
    .encode(
        x=alt.value(10),
        y="y:Q",
        text=alt.value(f"cluster threshold (d = {distance_threshold})"),
    )
)
| 166 | + |
# Annotation at the top merge point: a small left-pointing marker plus a two-line caption.
# Both layers read the single row of annotation_df.
annotation_base = alt.Chart(annotation_df)

top_annotation = annotation_base.mark_text(
    align="left",
    baseline="middle",
    fontSize=14,
    fontWeight="bold",
    color="#555555",
    lineBreak="\n",
    dx=12,
).encode(x="x:Q", y="y:Q", text="text:N")

top_arrow = annotation_base.mark_point(
    shape="triangle-left",
    size=80,
    filled=True,
    color="#888888",
).encode(x="x:Q", y="y:Q")
| 179 | + |
# Combine layers; later entries are drawn on top of earlier ones
chart_layers = [branches, threshold_line, threshold_label, leaf_dots, leaf_text, top_arrow, top_annotation]
chart_title = alt.Title(
    "dendrogram-basic · altair · pyplots.ai",
    subtitle="Ward's linkage on Iris measurements — Setosa separates clearly from Versicolor / Virginica",
    fontSize=28,
    subtitleFontSize=18,
    subtitleColor="#666666",
    anchor="start",
    offset=20,
)
axis_style = {
    "labelFontSize": 18,
    "titleFontSize": 22,
    "titleColor": "#333333",
    "labelColor": "#555555",
    "gridOpacity": 0.12,
    "gridDash": [3, 5],
    "gridColor": "#cccccc",
    "domainColor": "#aaaaaa",
    "domainWidth": 1.5,
    "tickColor": "#bbbbbb",
    "tickSize": 6,
}
chart = (
    alt.layer(*chart_layers)
    .properties(width=1600, height=900, title=chart_title)
    .configure_axis(**axis_style)
    .configure_view(strokeWidth=0, fill="#FAFBFC")
    .configure_legend(padding=20, cornerRadius=6, strokeColor="#dddddd", fillColor="#FAFBFC")
    .configure_title(subtitlePadding=8)
)
124 | 213 |
|
125 | 214 | # Save |
|
0 commit comments