Skip to content

Commit fa67e35

Browse files
fix(plotnine): address review feedback for heatmap-basic
Attempt 3/3 - fixes based on AI review
1 parent ced8294 commit fa67e35

1 file changed

Lines changed: 30 additions & 33 deletions

File tree

plots/heatmap-basic/implementations/plotnine.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
""" pyplots.ai
1+
"""pyplots.ai
22
heatmap-basic: Basic Heatmap
33
Library: plotnine 0.15.3 | Python 3.14.3
4-
Quality: 86/100 | Updated: 2026-02-15
54
"""
65

76
import numpy as np
@@ -29,43 +28,41 @@
2928
departments = ["Engineering", "Marketing", "Sales", "Finance", "Operations", "HR", "Research", "Support"]
3029
quarters = ["Q1 '23", "Q2 '23", "Q3 '23", "Q4 '23", "Q1 '24", "Q2 '24", "Q3 '24", "Q4 '24"]
3130

32-
# Growth rates with a recovery trend and departmental variation
33-
# Include a wider range with a few extreme outliers for feature coverage
31+
# Vectorized growth rates with recovery trend and departmental variation
3432
base_trend = np.linspace(-18, 22, 8)
3533
dept_offsets = np.array([-6, 10, 14, -3, 4, -10, 8, -5])
36-
values = np.zeros((8, 8))
37-
for i in range(8):
38-
for j in range(8):
39-
values[i, j] = round(base_trend[j] + dept_offsets[i] + np.random.normal(0, 4), 1)
34+
values = np.round(base_trend[np.newaxis, :] + dept_offsets[:, np.newaxis] + np.random.normal(0, 4, (8, 8)), 1)
4035

41-
# Inject a few distinctive extreme values for storytelling
36+
# Inject distinctive extreme values for storytelling focal points
4237
values[5, 0] = -32.5 # HR deep crisis in Q1 '23
4338
values[2, 7] = 38.2 # Sales strong recovery in Q4 '24
4439
values[6, 6] = 33.7 # Research surge in Q3 '24
4540

46-
# Long-form DataFrame
47-
records = []
48-
for i, dept in enumerate(departments):
49-
for j, qtr in enumerate(quarters):
50-
records.append({"Department": dept, "Quarter": qtr, "Growth (%)": values[i, j]})
51-
52-
df = pd.DataFrame(records)
53-
df["Quarter"] = pd.Categorical(df["Quarter"], categories=quarters, ordered=True)
54-
df["Department"] = pd.Categorical(df["Department"], categories=departments[::-1], ordered=True)
41+
# Build long-form DataFrame via meshgrid indexing
42+
dept_idx, qtr_idx = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
43+
df = pd.DataFrame(
44+
{
45+
"Department": pd.Categorical(
46+
[departments[i] for i in dept_idx.ravel()], categories=departments[::-1], ordered=True
47+
),
48+
"Quarter": pd.Categorical([quarters[j] for j in qtr_idx.ravel()], categories=quarters, ordered=True),
49+
"Growth (%)": values.ravel(),
50+
}
51+
)
5552

56-
# Conditional text color for optimal contrast on all cell backgrounds
57-
df["text_color"] = df["Growth (%)"].apply(lambda v: "white" if v < -12 else ("#444444" if v < 18 else "#5a3e00"))
53+
# Conditional text color: white on dark blue, dark gray on mid, dark brown on gold
54+
df["text_color"] = np.where(df["Growth (%)"] < -12, "white", np.where(df["Growth (%)"] < 18, "#3a3a3a", "#4a2e00"))
5855

59-
# Format labels with sign
60-
df["label"] = df["Growth (%)"].apply(lambda v: f"{v:+.1f}")
56+
# Signed annotation labels
57+
df["label"] = [f"{v:+.1f}" for v in df["Growth (%)"]]
6158

6259
# Plot
6360
plot = (
6461
ggplot(df, aes(x="Quarter", y="Department"))
65-
+ geom_tile(aes(fill="Growth (%)"), color="white", size=1.0)
62+
+ geom_tile(aes(fill="Growth (%)"), color="white", size=1.2)
6663
+ geom_text(aes(label="label", color="text_color"), size=10, fontweight="bold", show_legend=False)
6764
+ scale_fill_gradient2(
68-
low="#1a4971", mid="#f0eeeb", high="#cc8400", midpoint=0, name="Growth (%)", limits=(-35, 40)
65+
low="#14405e", mid="#ede8e3", high="#c47d00", midpoint=0, name="Growth (%)", limits=(-35, 40)
6966
)
7067
+ scale_color_identity()
7168
+ scale_x_discrete(expand=(0, 0.5))
@@ -80,20 +77,20 @@
8077
+ theme(
8178
figure_size=(16, 9),
8279
text=element_text(family="sans-serif"),
83-
plot_title=element_text(size=22, ha="center", weight="bold", margin={"b": 4}),
84-
plot_subtitle=element_text(size=15, ha="center", color="#666666", margin={"b": 12}),
85-
axis_title_x=element_text(size=18, margin={"t": 10}),
86-
axis_title_y=element_text(size=18, margin={"r": 8}),
87-
axis_text_x=element_text(size=15, rotation=45, ha="right", margin={"t": 5}),
88-
axis_text_y=element_text(size=15, ha="right", margin={"r": 5}),
89-
legend_title=element_text(size=15, weight="bold"),
90-
legend_text=element_text(size=13),
80+
plot_title=element_text(size=24, ha="center", weight="bold", margin={"b": 2}),
81+
plot_subtitle=element_text(size=16, ha="center", color="#555555", margin={"b": 8}),
82+
axis_title_x=element_text(size=20, margin={"t": 10}),
83+
axis_title_y=element_text(size=20, margin={"r": 8}),
84+
axis_text_x=element_text(size=16, rotation=45, ha="right", margin={"t": 4}),
85+
axis_text_y=element_text(size=16, ha="right", margin={"r": 4}),
86+
legend_title=element_text(size=16, weight="bold"),
87+
legend_text=element_text(size=14),
9188
legend_position="right",
9289
legend_key_height=40,
9390
panel_grid_major=element_blank(),
9491
panel_grid_minor=element_blank(),
9592
panel_background=element_rect(fill="white", color="none"),
96-
plot_background=element_rect(fill="#fafafa", color="none"),
93+
plot_background=element_rect(fill="#f7f7f7", color="none"),
9794
plot_margin=0.02,
9895
)
9996
)

0 commit comments

Comments
 (0)