Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 86 additions & 82 deletions plots/dendrogram-basic/implementations/plotnine.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
""" pyplots.ai
dendrogram-basic: Basic Dendrogram
Library: plotnine 0.15.2 | Python 3.13.11
Quality: 91/100 | Created: 2025-12-23
Library: plotnine 0.15.3 | Python 3.14.3
Quality: 87/100 | Updated: 2026-04-05
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The header line Quality: /100 | Updated: 2026-04-05 is missing the numeric score (or the standard placeholder like pending). This looks like an accidental formatting regression and makes the header inconsistent with other implementations and the metadata quality_score.

Suggested change
Quality: 87/100 | Updated: 2026-04-05
Quality: pending/100 | Updated: 2026-04-05

Copilot uses AI. Check for mistakes.
"""

import numpy as np
import pandas as pd
from plotnine import (
aes,
coord_cartesian,
element_blank,
element_line,
element_text,
Expand All @@ -22,112 +23,115 @@
theme_minimal,
)
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.datasets import load_iris


# Data - Iris flower measurements (4 features for 15 samples)
# Data - Real iris flower measurements (15 samples, 5 per species)
iris = load_iris()
np.random.seed(42)
species_names = ["Setosa", "Versicolor", "Virginica"]
species_counts = dict.fromkeys(species_names, 0)
sample_labels = []
indices = np.concatenate([np.random.choice(np.where(iris.target == i)[0], 5, replace=False) for i in range(3)])
for i in indices:
name = species_names[iris.target[i]]
species_counts[name] += 1
sample_labels.append(f"{name}-{species_counts[name]}")
features = iris.data[indices]

# Hierarchical clustering with Ward's method
linkage_matrix = linkage(features, method="ward")
palette = {"Setosa": "#306998", "Versicolor": "#E8833A", "Virginica": "#55A868"}

# Extract dendrogram coordinates
dend = dendrogram(linkage_matrix, labels=sample_labels, no_plot=True)

# Track species composition of each node for branch coloring
n = len(sample_labels)
leaf_species = {lbl: lbl.rsplit("-", 1)[0] for lbl in sample_labels}
node_species = {}
for i, label in enumerate(sample_labels):
node_species[i] = {leaf_species[label]}
for i, row in enumerate(linkage_matrix):
left, right = int(row[0]), int(row[1])
node_species[n + i] = node_species[left] | node_species[right]

# Color each U-shape: species color if pure, grey if mixed
merge_colors = []
for i in range(len(linkage_matrix)):
sp = node_species[n + i]
if len(sp) == 1:
merge_colors.append(palette[next(iter(sp))])
else:
merge_colors.append("#888888")

# Simulate iris-like measurements: sepal length, sepal width, petal length, petal width
# Three species with distinct characteristics
samples_per_species = 5

labels = []
data = []

# Setosa: shorter petals, wider sepals
for i in range(samples_per_species):
labels.append(f"Setosa-{i + 1}")
data.append(
[
5.0 + np.random.randn() * 0.3, # sepal length
3.4 + np.random.randn() * 0.3, # sepal width
1.5 + np.random.randn() * 0.2, # petal length
0.3 + np.random.randn() * 0.1, # petal width
]
)

# Versicolor: medium measurements
for i in range(samples_per_species):
labels.append(f"Versicolor-{i + 1}")
data.append(
[
5.9 + np.random.randn() * 0.4, # sepal length
2.8 + np.random.randn() * 0.3, # sepal width
4.3 + np.random.randn() * 0.4, # petal length
1.3 + np.random.randn() * 0.2, # petal width
]
)

# Virginica: longer petals and sepals
for i in range(samples_per_species):
labels.append(f"Virginica-{i + 1}")
data.append(
[
6.6 + np.random.randn() * 0.5, # sepal length
3.0 + np.random.randn() * 0.3, # sepal width
5.5 + np.random.randn() * 0.5, # petal length
2.0 + np.random.randn() * 0.3, # petal width
]
)

data = np.array(data)

# Compute hierarchical clustering using Ward's method
linkage_matrix = linkage(data, method="ward")

# Extract dendrogram coordinates using scipy (no_plot=True returns coordinates only)
dend = dendrogram(linkage_matrix, labels=labels, no_plot=True)
# Map dendrogram order to linkage order via merge heights
height_to_merge = {}
for i, h in enumerate(linkage_matrix[:, 2]):
height_to_merge.setdefault(round(h, 10), []).append(i)

# Convert dendrogram coordinates to segment data for plotnine
# icoord contains x coords (pairs of 4 for each merge)
# dcoord contains y coords (pairs of 4 for each merge)
# Build segment dataframe
segments = []
color_threshold = 0.7 * max(linkage_matrix[:, 2])

for xs, ys in zip(dend["icoord"], dend["dcoord"], strict=True):
# Each merge has 4 points forming a U-shape: [x1, x2, x3, x4], [y1, y2, y3, y4]
# We need 3 segments: left vertical, horizontal, right vertical

# Determine color based on height (merge distance)
merge_height = max(ys)
if merge_height > color_threshold:
color = "#306998" # Python Blue for high-level merges
h = round(max(ys), 10)
if h in height_to_merge and height_to_merge[h]:
merge_idx = height_to_merge[h].pop(0)
color = merge_colors[merge_idx]
else:
color = "#FFD43B" # Python Yellow for low-level merges

color = "#888888"
segments.append({"x": xs[0], "xend": xs[1], "y": ys[0], "yend": ys[1], "color": color})
segments.append({"x": xs[1], "xend": xs[2], "y": ys[1], "yend": ys[2], "color": color})
segments.append({"x": xs[2], "xend": xs[3], "y": ys[2], "yend": ys[3], "color": color})

segments_df = pd.DataFrame(segments)

# Create label data using the actual leaf positions from dendrogram
# dend['leaves'] gives the order, and x positions are at 5, 15, 25, ... (spacing of 10)
leaf_positions = [(i + 1) * 10 - 5 for i in range(len(dend["ivl"]))]
ivl = dend["ivl"] # Reordered labels from dendrogram
label_df = pd.DataFrame({"x": leaf_positions, "label": ivl, "y": [-0.8] * len(ivl)})
# Leaf labels with species-based coloring
n_leaves = len(dend["ivl"])
leaf_positions = [(i + 1) * 10 - 5 for i in range(n_leaves)]
leaf_labels = dend["ivl"]
leaf_colors = [palette[leaf_species[lbl]] for lbl in leaf_labels]
label_df = pd.DataFrame({"x": leaf_positions, "label": leaf_labels, "y": [0.0] * n_leaves, "color": leaf_colors})

# Plot using plotnine's native geom_segment
# Unique colors for scale
unique_colors = sorted(set(segments_df["color"].tolist() + leaf_colors))
color_identity = {c: c for c in unique_colors}

# Plot
y_max = max(linkage_matrix[:, 2]) * 1.05
plot = (
ggplot()
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="color"), data=segments_df, size=1.8)
+ geom_text(aes(x="x", y="y", label="label"), data=label_df, angle=45, ha="right", va="top", size=9)
+ scale_color_manual(values={"#306998": "#306998", "#FFD43B": "#FFD43B"}, guide=None)
+ scale_x_continuous(breaks=[], expand=(0.12, 0.05))
+ scale_y_continuous(expand=(0.25, 0.02))
+ labs(x="Sample", y="Distance (Ward)", title="dendrogram-basic · plotnine · pyplots.ai")
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="color"), data=segments_df, size=1.6)
+ geom_text(
aes(x="x", y="y", label="label", color="color"),
data=label_df,
angle=45,
ha="right",
va="top",
size=9,
nudge_y=-0.3,
)
+ scale_color_manual(values=color_identity, guide=None)
+ scale_x_continuous(breaks=[], expand=(0.08, 0))
+ scale_y_continuous(breaks=np.arange(0, y_max, 2).tolist(), expand=(0.12, 0))
+ coord_cartesian(ylim=(-2.5, y_max))
+ labs(x="", y="Ward Linkage Distance", title="Iris Species Clustering · dendrogram-basic · plotnine · pyplots.ai")
+ theme_minimal()
+ theme(
figure_size=(16, 9),
text=element_text(size=14),
axis_title=element_text(size=20),
axis_title_x=element_blank(),
axis_title_y=element_text(size=20),
axis_text=element_text(size=16),
axis_text_x=element_blank(),
axis_ticks_major_x=element_blank(),
plot_title=element_text(size=24),
panel_grid_major_x=element_blank(),
panel_grid_minor_x=element_blank(),
panel_grid_major_y=element_line(alpha=0.3, linetype="dashed"),
panel_grid_minor_y=element_blank(),
panel_grid_major_y=element_line(alpha=0.15, size=0.4),
)
)

plot.save("plot.png", dpi=300)
# Save with tight layout
fig = plot.draw()
fig.savefig("plot.png", dpi=300, bbox_inches="tight")
Loading
Loading