Skip to content

Commit 50794f2

Browse files
update(dendrogram-basic): plotnine — comprehensive quality review (#5205)
## Summary

Updated **plotnine** implementation for **dendrogram-basic**.

**Changes:** Comprehensive review improving code quality, data choice, visual design, spec compliance, and library feature usage.

## Test Plan

- [x] Preview images uploaded to GCS staging
- [x] Implementation file passes ruff format/check
- [x] Metadata YAML updated with current versions
- [ ] Automated review triggered

---

Generated with [Claude Code](https://claude.com/claude-code) `/update` command

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d591a9c commit 50794f2

File tree

2 files changed

+313
-209
lines changed

2 files changed

+313
-209
lines changed
Lines changed: 167 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
""" pyplots.ai
22
dendrogram-basic: Basic Dendrogram
3-
Library: plotnine 0.15.2 | Python 3.13.11
4-
Quality: 91/100 | Created: 2025-12-23
3+
Library: plotnine 0.15.3 | Python 3.14.3
4+
Quality: 89/100 | Updated: 2026-04-05
55
"""
66

77
import numpy as np
88
import pandas as pd
99
from plotnine import (
1010
aes,
11+
annotate,
12+
coord_cartesian,
1113
element_blank,
1214
element_line,
15+
element_rect,
1316
element_text,
17+
geom_hline,
18+
geom_point,
1419
geom_segment,
1520
geom_text,
1621
ggplot,
22+
guide_legend,
23+
guides,
1724
labs,
1825
scale_color_manual,
1926
scale_x_continuous,
@@ -22,112 +29,184 @@
2229
theme_minimal,
2330
)
2431
from scipy.cluster.hierarchy import dendrogram, linkage
32+
from sklearn.datasets import load_iris
2533

2634

27-
# Data - Iris flower measurements (4 features for 15 samples)
35+
# Data - Real iris flower measurements (15 samples, 5 per species)
36+
iris = load_iris()
2837
np.random.seed(42)
38+
species_names = ["Setosa", "Versicolor", "Virginica"]
39+
species_counts = dict.fromkeys(species_names, 0)
40+
sample_labels = []
41+
indices = np.concatenate([np.random.choice(np.where(iris.target == i)[0], 5, replace=False) for i in range(3)])
42+
for i in indices:
43+
name = species_names[iris.target[i]]
44+
species_counts[name] += 1
45+
sample_labels.append(f"{name}-{species_counts[name]}")
46+
features = iris.data[indices]
47+
48+
# Hierarchical clustering with Ward's method
49+
linkage_matrix = linkage(features, method="ward")
50+
palette = {"Setosa": "#306998", "Versicolor": "#E8833A", "Virginica": "#55A868"}
51+
52+
# Extract dendrogram coordinates
53+
dend = dendrogram(linkage_matrix, labels=sample_labels, no_plot=True)
54+
55+
# Track species composition of each node for branch coloring
56+
n = len(sample_labels)
57+
leaf_species = {lbl: lbl.rsplit("-", 1)[0] for lbl in sample_labels}
58+
node_species = {}
59+
for i, label in enumerate(sample_labels):
60+
node_species[i] = {leaf_species[label]}
61+
for i, row in enumerate(linkage_matrix):
62+
left, right = int(row[0]), int(row[1])
63+
node_species[n + i] = node_species[left] | node_species[right]
64+
65+
# Branch type label for each merge: species name if pure, "Mixed" if mixed
66+
branch_type_labels = {"Setosa": "Setosa (pure)", "Versicolor": "Versicolor (pure)", "Virginica": "Virginica (pure)"}
67+
merge_branch_types = []
68+
for i in range(len(linkage_matrix)):
69+
sp = node_species[n + i]
70+
if len(sp) == 1:
71+
merge_branch_types.append(branch_type_labels[next(iter(sp))])
72+
else:
73+
merge_branch_types.append("Mixed species")
2974

30-
# Simulate iris-like measurements: sepal length, sepal width, petal length, petal width
31-
# Three species with distinct characteristics
32-
samples_per_species = 5
33-
34-
labels = []
35-
data = []
36-
37-
# Setosa: shorter petals, wider sepals
38-
for i in range(samples_per_species):
39-
labels.append(f"Setosa-{i + 1}")
40-
data.append(
41-
[
42-
5.0 + np.random.randn() * 0.3, # sepal length
43-
3.4 + np.random.randn() * 0.3, # sepal width
44-
1.5 + np.random.randn() * 0.2, # petal length
45-
0.3 + np.random.randn() * 0.1, # petal width
46-
]
47-
)
48-
49-
# Versicolor: medium measurements
50-
for i in range(samples_per_species):
51-
labels.append(f"Versicolor-{i + 1}")
52-
data.append(
53-
[
54-
5.9 + np.random.randn() * 0.4, # sepal length
55-
2.8 + np.random.randn() * 0.3, # sepal width
56-
4.3 + np.random.randn() * 0.4, # petal length
57-
1.3 + np.random.randn() * 0.2, # petal width
58-
]
59-
)
60-
61-
# Virginica: longer petals and sepals
62-
for i in range(samples_per_species):
63-
labels.append(f"Virginica-{i + 1}")
64-
data.append(
65-
[
66-
6.6 + np.random.randn() * 0.5, # sepal length
67-
3.0 + np.random.randn() * 0.3, # sepal width
68-
5.5 + np.random.randn() * 0.5, # petal length
69-
2.0 + np.random.randn() * 0.3, # petal width
70-
]
71-
)
72-
73-
data = np.array(data)
74-
75-
# Compute hierarchical clustering using Ward's method
76-
linkage_matrix = linkage(data, method="ward")
77-
78-
# Extract dendrogram coordinates using scipy (no_plot=True returns coordinates only)
79-
dend = dendrogram(linkage_matrix, labels=labels, no_plot=True)
75+
# Map dendrogram order to linkage order via merge heights
76+
height_to_merge = {}
77+
for i, h in enumerate(linkage_matrix[:, 2]):
78+
height_to_merge.setdefault(round(h, 10), []).append(i)
8079

81-
# Convert dendrogram coordinates to segment data for plotnine
82-
# icoord contains x coords (pairs of 4 for each merge)
83-
# dcoord contains y coords (pairs of 4 for each merge)
80+
# Build segment dataframe
8481
segments = []
85-
color_threshold = 0.7 * max(linkage_matrix[:, 2])
86-
8782
for xs, ys in zip(dend["icoord"], dend["dcoord"], strict=True):
88-
# Each merge has 4 points forming a U-shape: [x1, x2, x3, x4], [y1, y2, y3, y4]
89-
# We need 3 segments: left vertical, horizontal, right vertical
90-
91-
# Determine color based on height (merge distance)
92-
merge_height = max(ys)
93-
if merge_height > color_threshold:
94-
color = "#306998" # Python Blue for high-level merges
83+
h = round(max(ys), 10)
84+
if h in height_to_merge and height_to_merge[h]:
85+
merge_idx = height_to_merge[h].pop(0)
86+
btype = merge_branch_types[merge_idx]
9587
else:
96-
color = "#FFD43B" # Python Yellow for low-level merges
97-
98-
segments.append({"x": xs[0], "xend": xs[1], "y": ys[0], "yend": ys[1], "color": color})
99-
segments.append({"x": xs[1], "xend": xs[2], "y": ys[1], "yend": ys[2], "color": color})
100-
segments.append({"x": xs[2], "xend": xs[3], "y": ys[2], "yend": ys[3], "color": color})
88+
btype = "Mixed species"
89+
segments.append({"x": xs[0], "xend": xs[1], "y": ys[0], "yend": ys[1], "branch_type": btype})
90+
segments.append({"x": xs[1], "xend": xs[2], "y": ys[1], "yend": ys[2], "branch_type": btype})
91+
segments.append({"x": xs[2], "xend": xs[3], "y": ys[2], "yend": ys[3], "branch_type": btype})
10192

10293
segments_df = pd.DataFrame(segments)
10394

104-
# Create label data using the actual leaf positions from dendrogram
105-
# dend['leaves'] gives the order, and x positions are at 5, 15, 25, ... (spacing of 10)
106-
leaf_positions = [(i + 1) * 10 - 5 for i in range(len(dend["ivl"]))]
107-
ivl = dend["ivl"] # Reordered labels from dendrogram
108-
label_df = pd.DataFrame({"x": leaf_positions, "label": ivl, "y": [-0.8] * len(ivl)})
95+
# Leaf labels with species-based coloring
96+
n_leaves = len(dend["ivl"])
97+
leaf_positions = [(i + 1) * 10 - 5 for i in range(n_leaves)]
98+
leaf_labels = dend["ivl"]
99+
leaf_btypes = [branch_type_labels[leaf_species[lbl]] for lbl in leaf_labels]
100+
label_df = pd.DataFrame({"x": leaf_positions, "label": leaf_labels, "y": [0.0] * n_leaves, "branch_type": leaf_btypes})
101+
102+
# Ordered category for consistent legend
103+
category_order = ["Setosa (pure)", "Versicolor (pure)", "Virginica (pure)", "Mixed species"]
104+
color_map = {
105+
"Setosa (pure)": palette["Setosa"],
106+
"Versicolor (pure)": palette["Versicolor"],
107+
"Virginica (pure)": palette["Virginica"],
108+
"Mixed species": "#888888",
109+
}
110+
segments_df["branch_type"] = pd.Categorical(segments_df["branch_type"], categories=category_order, ordered=True)
111+
label_df["branch_type"] = pd.Categorical(label_df["branch_type"], categories=category_order, ordered=True)
112+
113+
# Merge node points - highlight where clusters join (plotnine geom_point layer).
# dend["icoord"] / dend["dcoord"] follow scipy's dendrogram traversal order,
# whereas merge_branch_types is indexed by linkage-matrix row, so zipping them
# directly can pair a node with the wrong branch type. `segments` already holds
# the height-matched type for every merge (three consecutive entries per merge,
# appended in dendrogram order), so reuse it to keep each marker the same color
# as the branch it sits on.
merge_nodes = [
    {
        "x": (xs[1] + xs[2]) / 2,  # midpoint of the horizontal bar of the U-shape
        "y": max(ys),  # merge height
        "branch_type": segments[3 * k]["branch_type"],
    }
    for k, (xs, ys) in enumerate(zip(dend["icoord"], dend["dcoord"], strict=True))
]
merge_df = pd.DataFrame(merge_nodes)
merge_df["branch_type"] = pd.Categorical(merge_df["branch_type"], categories=category_order, ordered=True)
121+
122+
# Key merge threshold: where Setosa separates from the rest
123+
setosa_sep_height = linkage_matrix[-2, 2]
124+
threshold_df = pd.DataFrame({"yintercept": [setosa_sep_height]})
125+
126+
# Plot
127+
y_max = max(linkage_matrix[:, 2]) * 1.08
128+
x_min = min(segments_df["x"].min(), segments_df["xend"].min())
129+
x_max = max(segments_df["x"].max(), segments_df["xend"].max())
130+
x_pad = (x_max - x_min) * 0.06
109131

110-
# Plot using plotnine's native geom_segment
111132
plot = (
112133
ggplot()
113-
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="color"), data=segments_df, size=1.8)
114-
+ geom_text(aes(x="x", y="y", label="label"), data=label_df, angle=45, ha="right", va="top", size=9)
115-
+ scale_color_manual(values={"#306998": "#306998", "#FFD43B": "#FFD43B"}, guide=None)
116-
+ scale_x_continuous(breaks=[], expand=(0.12, 0.05))
117-
+ scale_y_continuous(expand=(0.25, 0.02))
118-
+ labs(x="Sample", y="Distance (Ward)", title="dendrogram-basic · plotnine · pyplots.ai")
134+
# Dendrogram branches - thicker for HD visibility
135+
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="branch_type"), data=segments_df, size=2.2)
136+
# Threshold line using idiomatic geom_hline
137+
+ geom_hline(aes(yintercept="yintercept"), data=threshold_df, linetype="dashed", color="#AAAAAA", size=0.8)
138+
# Threshold annotation using plotnine annotate
139+
+ annotate(
140+
"text",
141+
x=x_max - x_pad,
142+
y=setosa_sep_height + 0.35,
143+
label="Setosa separates",
144+
size=13,
145+
color="#555555",
146+
fontstyle="italic",
147+
ha="right",
148+
)
149+
# Intermixing annotation - data storytelling for Versicolor/Virginica
150+
+ annotate(
151+
"text",
152+
x=x_max - x_pad,
153+
y=linkage_matrix[-1, 2] * 0.55,
154+
label="Versicolor & Virginica intermixed",
155+
size=12,
156+
color="#888888",
157+
fontstyle="italic",
158+
ha="right",
159+
)
160+
# Leaf labels - larger for readability
161+
+ geom_text(
162+
aes(x="x", y="y", label="label", color="branch_type"),
163+
data=label_df,
164+
angle=45,
165+
ha="right",
166+
va="top",
167+
size=13,
168+
nudge_y=-0.3,
169+
show_legend=False,
170+
)
171+
# Merge node markers - emphasize join points
172+
+ geom_point(aes(x="x", y="y", color="branch_type"), data=merge_df, size=3.5, show_legend=False)
173+
+ scale_color_manual(values=color_map, name="Branch Type")
174+
+ guides(color=guide_legend(override_aes={"size": 4, "alpha": 1}))
175+
+ scale_x_continuous(breaks=[], expand=(0.04, 0))
176+
+ scale_y_continuous(breaks=np.arange(0, y_max, 2).tolist(), expand=(0.10, 0))
177+
+ coord_cartesian(xlim=(x_min - x_pad, x_max + x_pad), ylim=(-2.5, y_max))
178+
+ labs(
179+
x="",
180+
y="Ward Linkage Distance",
181+
title="Iris Species Clustering · dendrogram-basic · plotnine · pyplots.ai",
182+
subtitle="Hierarchical clustering of 15 iris samples using Ward's minimum variance method",
183+
)
119184
+ theme_minimal()
120185
+ theme(
121186
figure_size=(16, 9),
122-
text=element_text(size=14),
123-
axis_title=element_text(size=20),
124-
axis_text=element_text(size=16),
187+
text=element_text(size=14, family="sans-serif"),
188+
axis_title_x=element_blank(),
189+
axis_title_y=element_text(size=20, margin={"r": 12}),
190+
axis_text=element_text(size=16, color="#444444"),
125191
axis_text_x=element_blank(),
126-
plot_title=element_text(size=24),
192+
axis_ticks_major_x=element_blank(),
193+
plot_title=element_text(size=24, weight="bold", margin={"b": 4}),
194+
plot_subtitle=element_text(size=15, color="#666666", margin={"b": 12}),
195+
plot_background=element_rect(fill="#FAFAFA", color="none"),
196+
panel_background=element_rect(fill="#FAFAFA", color="none"),
127197
panel_grid_major_x=element_blank(),
128198
panel_grid_minor_x=element_blank(),
129-
panel_grid_major_y=element_line(alpha=0.3, linetype="dashed"),
199+
panel_grid_minor_y=element_blank(),
200+
panel_grid_major_y=element_line(alpha=0.2, size=0.5, color="#CCCCCC"),
201+
legend_title=element_text(size=16, weight="bold"),
202+
legend_text=element_text(size=14),
203+
legend_position="right",
204+
legend_background=element_rect(fill="#FAFAFA", color="#DDDDDD", size=0.5),
205+
legend_key=element_rect(fill="none", color="none"),
206+
plot_margin=0.02,
130207
)
131208
)
132209

133-
plot.save("plot.png", dpi=300)
210+
# Save with tight layout
211+
fig = plot.draw()
212+
fig.savefig("plot.png", dpi=300, bbox_inches="tight")

0 commit comments

Comments
 (0)