Skip to content

Commit 7062d53

Browse files
feat(seaborn): implement sankey-basic (#5604)
## Implementation: `sankey-basic` - python/seaborn Implements the **python/seaborn** version of `sankey-basic`. **File:** `plots/sankey-basic/implementations/python/seaborn.py` **Parent Issue:** #810 --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/anyplot/actions/runs/25156180870)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Markus Neusinger <2921697+MarkusNeusinger@users.noreply.github.com>
1 parent 8eba8c2 commit 7062d53

2 files changed

Lines changed: 319 additions & 328 deletions

File tree

Lines changed: 144 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -1,215 +1,173 @@
1-
""" pyplots.ai
1+
""" anyplot.ai
22
sankey-basic: Basic Sankey Diagram
3-
Library: seaborn 0.13.2 | Python 3.13.11
4-
Quality: 78/100 | Created: 2025-12-23
3+
Library: seaborn 0.13.2 | Python 3.13.13
4+
Quality: 88/100 | Updated: 2026-04-30
55
"""
66

7-
import matplotlib.patches as patches
7+
import os
8+
9+
import matplotlib.patches as mpatches
810
import matplotlib.pyplot as plt
911
import numpy as np
1012
import pandas as pd
1113
import seaborn as sns
1214

1315

14-
# Set seed for reproducibility
15-
np.random.seed(42)
16-
17-
# Apply seaborn styling
18-
sns.set_theme(style="white", context="talk", font_scale=1.2)
19-
20-
# Data - Energy flow from sources to sectors (in TWh)
21-
flows_data = {
22-
"source": ["Coal", "Coal", "Coal", "Gas", "Gas", "Gas", "Nuclear", "Nuclear", "Nuclear"],
23-
"target": [
24-
"Residential",
25-
"Commercial",
26-
"Industrial",
27-
"Residential",
28-
"Commercial",
29-
"Industrial",
30-
"Residential",
31-
"Commercial",
32-
"Industrial",
33-
],
34-
"value": [15, 12, 33, 20, 18, 22, 15, 15, 15],
35-
}
36-
df = pd.DataFrame(flows_data)
37-
38-
# Create figure with seaborn styling
39-
fig, ax = plt.subplots(figsize=(16, 9))
40-
41-
# Use seaborn color palettes - distinct colors for sources and targets
42-
source_names = df["source"].unique()
43-
target_names = df["target"].unique()
44-
source_palette = sns.color_palette("husl", n_colors=len(source_names))
45-
target_palette = sns.color_palette("Set2", n_colors=len(target_names))
46-
source_colors = dict(zip(source_names, source_palette, strict=True))
47-
target_colors = dict(zip(target_names, target_palette, strict=True))
48-
49-
# Calculate node totals
50-
sources = df.groupby("source")["value"].sum().sort_values(ascending=False)
51-
targets = df.groupby("target")["value"].sum().sort_values(ascending=False)
52-
53-
# Node dimensions and positions
54-
node_width = 0.06
55-
x_source = 0.12
56-
x_target = 0.88
57-
gap = 0.025
58-
total_height = 0.60 # Reduced to leave room for legends at bottom
59-
60-
# Calculate source node positions (left side)
61-
total_source = sources.sum()
62-
source_positions = {}
63-
y_pos = 0.92
64-
for source, value in sources.items():
65-
height = (value / total_source) * total_height
66-
source_positions[source] = {"y": y_pos - height, "height": height}
67-
y_pos -= height + gap
68-
69-
# Calculate target node positions (right side)
70-
total_target = targets.sum()
71-
target_positions = {}
72-
y_pos = 0.92
73-
for target, value in targets.items():
74-
height = (value / total_target) * total_height
75-
target_positions[target] = {"y": y_pos - height, "height": height}
76-
y_pos -= height + gap
77-
78-
# Track current position for stacking flows at each node
79-
source_current_y = {s: source_positions[s]["y"] + source_positions[s]["height"] for s in sources.index}
80-
target_current_y = {t: target_positions[t]["y"] + target_positions[t]["height"] for t in targets.index}
81-
82-
# Bezier curve parameters
83-
n_points = 100
84-
t = np.linspace(0, 1, n_points)
85-
86-
# Sort flows by source then by value for consistent stacking
87-
df_sorted = df.sort_values(["source", "value"], ascending=[True, False])
88-
89-
# Draw flows with widths proportional to values
16+
# Theme tokens
17+
THEME = os.getenv("ANYPLOT_THEME", "light")
18+
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
19+
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
20+
21+
OKABE_ITO = ["#009E73", "#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9"]
22+
23+
sns.set_theme(style="white", rc={"figure.facecolor": PAGE_BG, "axes.facecolor": PAGE_BG, "text.color": INK})
24+
25+
# Data — energy flows in TWh (varied magnitudes for clear proportional scaling)
26+
source_names = ["Gas", "Coal", "Nuclear"]
27+
target_names = ["Residential", "Industrial", "Commercial"]
28+
flows = [
29+
("Gas", "Residential", 50),
30+
("Gas", "Industrial", 30),
31+
("Gas", "Commercial", 40),
32+
("Coal", "Industrial", 45),
33+
("Coal", "Residential", 20),
34+
("Coal", "Commercial", 15),
35+
("Nuclear", "Residential", 25),
36+
("Nuclear", "Industrial", 10),
37+
("Nuclear", "Commercial", 10),
38+
]
39+
df = pd.DataFrame(flows, columns=["source", "target", "value"])
40+
41+
source_colors = dict(zip(source_names, OKABE_ITO[:3], strict=True))
42+
target_colors = dict(zip(target_names, OKABE_ITO[3:6], strict=True))
43+
44+
sources = df.groupby("source")["value"].sum().loc[source_names]
45+
targets = df.groupby("target")["value"].sum().loc[target_names]
46+
47+
# Layout
48+
NODE_W = 0.055
49+
X_LEFT, X_RIGHT = 0.13, 0.87
50+
GAP = 0.022
51+
TOTAL_H = 0.72
52+
Y_START = 0.85
53+
54+
source_pos = {}
55+
y = Y_START
56+
for name in source_names:
57+
h = (sources[name] / sources.sum()) * TOTAL_H
58+
source_pos[name] = {"y": y - h, "h": h}
59+
y -= h + GAP
60+
61+
target_pos = {}
62+
y = Y_START
63+
for name in target_names:
64+
h = (targets[name] / targets.sum()) * TOTAL_H
65+
target_pos[name] = {"y": y - h, "h": h}
66+
y -= h + GAP
67+
68+
src_y = {n: source_pos[n]["y"] + source_pos[n]["h"] for n in source_names}
69+
tgt_y = {n: target_pos[n]["y"] + target_pos[n]["h"] for n in target_names}
70+
71+
# Figure
72+
fig, ax = plt.subplots(figsize=(16, 9), facecolor=PAGE_BG)
73+
ax.set_facecolor(PAGE_BG)
74+
75+
t = np.linspace(0, 1, 120)
76+
s = t**2 * (3 - 2 * t) # smoothstep: zero tangents at both endpoints
77+
78+
# Sort flows by source order then target order to minimise crossings
79+
src_ord = {n: i for i, n in enumerate(source_names)}
80+
tgt_ord = {n: i for i, n in enumerate(target_names)}
81+
df["_si"] = df["source"].map(src_ord)
82+
df["_ti"] = df["target"].map(tgt_ord)
83+
df_sorted = df.sort_values(["_si", "_ti"])
84+
85+
# Draw flows
9086
for _, row in df_sorted.iterrows():
91-
source = row["source"]
92-
target = row["target"]
93-
value = row["value"]
94-
color = source_colors[source]
95-
96-
# Calculate band height proportional to flow value
97-
source_band_height = (value / sources[source]) * source_positions[source]["height"]
98-
target_band_height = (value / targets[target]) * target_positions[target]["height"]
99-
100-
# Source side coordinates
101-
y0_top = source_current_y[source]
102-
y0_bot = y0_top - source_band_height
103-
source_current_y[source] = y0_bot
104-
105-
# Target side coordinates
106-
y1_top = target_current_y[target]
107-
y1_bot = y1_top - target_band_height
108-
target_current_y[target] = y1_bot
109-
110-
# Draw the flow band using cubic bezier curves
111-
x0 = x_source + node_width
112-
x1 = x_target
113-
cx0 = x0 + (x1 - x0) * 0.35
114-
cx1 = x0 + (x1 - x0) * 0.65
115-
116-
# Generate bezier curve points for top and bottom edges
117-
top_x = (1 - t) ** 3 * x0 + 3 * (1 - t) ** 2 * t * cx0 + 3 * (1 - t) * t**2 * cx1 + t**3 * x1
118-
top_y = (1 - t) ** 3 * y0_top + 3 * (1 - t) ** 2 * t * y0_top + 3 * (1 - t) * t**2 * y1_top + t**3 * y1_top
119-
bot_y = (1 - t) ** 3 * y0_bot + 3 * (1 - t) ** 2 * t * y0_bot + 3 * (1 - t) * t**2 * y1_bot + t**3 * y1_bot
120-
121-
# Draw flow band
122-
ax.fill_between(top_x, bot_y, top_y, color=color, alpha=0.65, linewidth=0, edgecolor="none")
123-
124-
# Draw source nodes (left) with seaborn colors
125-
for source in sources.index:
126-
pos = source_positions[source]
127-
rect = patches.FancyBboxPatch(
128-
(x_source, pos["y"]),
129-
node_width,
130-
pos["height"],
131-
boxstyle="round,pad=0.005,rounding_size=0.015",
132-
facecolor=source_colors[source],
133-
edgecolor="white",
134-
linewidth=2.5,
87+
src, tgt, val = row["source"], row["target"], row["value"]
88+
bh_src = (val / sources[src]) * source_pos[src]["h"]
89+
bh_tgt = (val / targets[tgt]) * target_pos[tgt]["h"]
90+
91+
y0t, y0b = src_y[src], src_y[src] - bh_src
92+
src_y[src] = y0b
93+
y1t, y1b = tgt_y[tgt], tgt_y[tgt] - bh_tgt
94+
tgt_y[tgt] = y1b
95+
96+
x0, x1 = X_LEFT + NODE_W, X_RIGHT
97+
cx0, cx1 = x0 + (x1 - x0) * 0.35, x0 + (x1 - x0) * 0.65
98+
xs = (1 - t) ** 3 * x0 + 3 * (1 - t) ** 2 * t * cx0 + 3 * (1 - t) * t**2 * cx1 + t**3 * x1
99+
100+
# Gas (dominant source) rendered with heavier alpha for visual emphasis
101+
flow_alpha = 0.68 if src == "Gas" else 0.44
102+
ax.fill_between(
103+
xs, y0b + (y1b - y0b) * s, y0t + (y1t - y0t) * s, color=source_colors[src], alpha=flow_alpha, linewidth=0
104+
)
105+
106+
# Draw source nodes and labels
107+
for name in source_names:
108+
pos = source_pos[name]
109+
ax.add_patch(
110+
mpatches.FancyBboxPatch(
111+
(X_LEFT, pos["y"]),
112+
NODE_W,
113+
pos["h"],
114+
boxstyle="round,pad=0.005,rounding_size=0.015",
115+
facecolor=source_colors[name],
116+
edgecolor=PAGE_BG,
117+
linewidth=2,
118+
)
135119
)
136-
ax.add_patch(rect)
137120
ax.text(
138-
x_source - 0.015,
139-
pos["y"] + pos["height"] / 2,
140-
f"{source}\n{sources[source]:.0f} TWh",
121+
X_LEFT - 0.015,
122+
pos["y"] + pos["h"] / 2,
123+
f"{name}\n{sources[name]:.0f} TWh",
141124
ha="right",
142125
va="center",
143-
fontsize=18,
126+
fontsize=20,
144127
fontweight="bold",
145-
color="#2d2d2d",
128+
color=INK,
146129
)
147130

148-
# Draw target nodes (right) with distinct colors from Set2 palette
149-
for target in targets.index:
150-
pos = target_positions[target]
151-
rect = patches.FancyBboxPatch(
152-
(x_target, pos["y"]),
153-
node_width,
154-
pos["height"],
155-
boxstyle="round,pad=0.005,rounding_size=0.015",
156-
facecolor=target_colors[target],
157-
edgecolor="white",
158-
linewidth=2.5,
131+
# Draw target nodes and labels
132+
for name in target_names:
133+
pos = target_pos[name]
134+
ax.add_patch(
135+
mpatches.FancyBboxPatch(
136+
(X_RIGHT, pos["y"]),
137+
NODE_W,
138+
pos["h"],
139+
boxstyle="round,pad=0.005,rounding_size=0.015",
140+
facecolor=target_colors[name],
141+
edgecolor=PAGE_BG,
142+
linewidth=2,
143+
)
159144
)
160-
ax.add_patch(rect)
161145
ax.text(
162-
x_target + node_width + 0.015,
163-
pos["y"] + pos["height"] / 2,
164-
f"{target}\n{targets[target]:.0f} TWh",
146+
X_RIGHT + NODE_W + 0.015,
147+
pos["y"] + pos["h"] / 2,
148+
f"{name}\n{targets[name]:.0f} TWh",
165149
ha="left",
166150
va="center",
167-
fontsize=18,
151+
fontsize=20,
168152
fontweight="bold",
169-
color="#2d2d2d",
153+
color=INK,
170154
)
171155

172-
# Create legend using simple patches for sources and targets
173-
source_handles = [
174-
patches.Patch(facecolor=source_colors[s], edgecolor="white", linewidth=1.5, label=s) for s in source_names
175-
]
176-
target_handles = [
177-
patches.Patch(facecolor=target_colors[t], edgecolor="white", linewidth=1.5, label=t) for t in target_names
178-
]
179-
180-
# Add source legend on the left
181-
source_legend = ax.legend(
182-
handles=source_handles,
183-
title="Energy Sources",
184-
loc="lower left",
185-
bbox_to_anchor=(0.02, 0.02),
186-
fontsize=14,
187-
title_fontsize=16,
188-
frameon=True,
189-
fancybox=True,
190-
edgecolor="#cccccc",
191-
)
192-
193-
# Add target legend on the right
194-
ax.add_artist(source_legend)
195-
ax.legend(
196-
handles=target_handles,
197-
title="Sectors",
198-
loc="lower right",
199-
bbox_to_anchor=(0.98, 0.02),
200-
fontsize=14,
201-
title_fontsize=16,
202-
frameon=True,
203-
fancybox=True,
204-
edgecolor="#cccccc",
156+
ax.set_title("sankey-basic · seaborn · anyplot.ai", fontsize=24, fontweight="medium", color=INK, pad=20)
157+
# Subtitle highlighting key insight: Gas is the dominant source (49% of total)
158+
ax.text(
159+
0.5,
160+
0.93,
161+
"Gas supplies 49 % of total energy — the dominant source",
162+
ha="center",
163+
va="center",
164+
fontsize=16,
165+
color=source_colors["Gas"],
166+
fontstyle="italic",
205167
)
206-
207-
# Set title using the required format
208-
ax.set_title("sankey-basic · seaborn · pyplots.ai", fontsize=26, fontweight="bold", pad=25)
209-
210-
# Set axis limits and remove decorations
211168
ax.set_xlim(0, 1)
212169
ax.set_ylim(0, 1)
213170
ax.axis("off")
214171

215-
plt.savefig("plot.png", dpi=300, bbox_inches="tight", facecolor="white")
172+
plt.tight_layout()
173+
plt.savefig(f"plot-{THEME}.png", dpi=300, bbox_inches="tight", facecolor=PAGE_BG)

0 commit comments

Comments
 (0)