Skip to content

Commit b603e60

Browse files
feat(plotnine): implement sankey-basic (#5609)
## Implementation: `sankey-basic` - python/plotnine Implements the **python/plotnine** version of `sankey-basic`. **File:** `plots/sankey-basic/implementations/python/plotnine.py` **Parent Issue:** #810 --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/anyplot/actions/runs/25156526692)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Neusinger <2921697+MarkusNeusinger@users.noreply.github.com>
1 parent 92946f4 commit b603e60

2 files changed

Lines changed: 199 additions & 185 deletions

File tree

plots/sankey-basic/implementations/python/plotnine.py

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
""" pyplots.ai
1+
""" anyplot.ai
22
sankey-basic: Basic Sankey Diagram
3-
Library: plotnine 0.15.2 | Python 3.13.11
4-
Quality: 91/100 | Created: 2025-12-23
3+
Library: plotnine 0.15.3 | Python 3.13.13
4+
Quality: 84/100 | Updated: 2026-04-30
55
"""
66

7+
import os
78
import sys
89

910

10-
# Prevent current directory from shadowing the plotnine package
11-
sys.path = [p for p in sys.path if not p.endswith("implementations")]
11+
sys.path = [p for p in sys.path if os.path.abspath(p) != os.path.dirname(os.path.abspath(__file__))]
1212

13-
import numpy as np # noqa: E402
14-
import pandas as pd # noqa: E402
15-
from plotnine import ( # noqa: E402
13+
import numpy as np
14+
import pandas as pd
15+
from plotnine import (
1616
aes,
1717
annotate,
1818
coord_cartesian,
1919
element_blank,
20+
element_rect,
2021
element_text,
2122
geom_polygon,
2223
geom_rect,
@@ -29,6 +30,13 @@
2930
)
3031

3132

33+
THEME = os.getenv("ANYPLOT_THEME", "light")
34+
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
35+
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
36+
INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0"
37+
38+
OKABE_ITO = ["#009E73", "#D55E00", "#0072B2", "#CC79A7"]
39+
3240
# Data - Energy flow from sources to sectors
3341
flows = pd.DataFrame(
3442
{
@@ -91,11 +99,11 @@
9199
}
92100
current_y = current_y - height - node_gap
93101

94-
# Color map for nodes - sources get their own colors, targets get distinct colors
95-
source_colors = {"Coal": "#306998", "Gas": "#FFD43B", "Nuclear": "#4ECDC4", "Renewables": "#2ECC71"}
96-
target_colors = {"Industrial": "#E74C3C", "Commercial": "#9B59B6", "Residential": "#F39C12"}
102+
# Okabe-Ito colors for sources; theme-adaptive neutral for targets
103+
source_colors_map = {"Coal": OKABE_ITO[0], "Gas": OKABE_ITO[1], "Nuclear": OKABE_ITO[2], "Renewables": OKABE_ITO[3]}
104+
target_colors_map = {"Industrial": INK_SOFT, "Commercial": INK_SOFT, "Residential": INK_SOFT}
97105

98-
# Build node rectangles dataframe - each node gets its own color
106+
# Build node rectangles dataframe
99107
node_data = []
100108
for src in sources:
101109
pos = source_positions[src]
@@ -131,57 +139,48 @@
131139

132140
nodes_df = pd.DataFrame(node_data)
133141

134-
135142
# Build flow polygons (curved paths between nodes)
136143
flow_polygons = []
137-
flow_labels = [] # Track center positions for flow value labels
144+
flow_labels = []
138145
for _, row in flows.iterrows():
139146
src = row["source"]
140147
tgt = row["target"]
141148
val = row["value"]
142149

143-
# Calculate flow thickness
144150
flow_height = val / total_flow * 0.8
145151

146-
# Source connection point
147152
src_pos = source_positions[src]
148153
src_y_top = src_pos["y_top"] - src_pos["flow_offset"]
149154
src_y_bottom = src_y_top - flow_height
150155
src_pos["flow_offset"] += flow_height
151156

152-
# Target connection point
153157
tgt_pos = target_positions[tgt]
154158
tgt_y_top = tgt_pos["y_top"] - tgt_pos["flow_offset"]
155159
tgt_y_bottom = tgt_y_top - flow_height
156160
tgt_pos["flow_offset"] += flow_height
157161

158-
# Create curved flow polygon using smooth interpolation
162+
# Smooth cubic Hermite interpolation for flow curves
159163
flow_x_left = x_left + node_width
160164
flow_x_right = x_right - node_width
161165
n_points = 50
162166

163-
# Top edge (left to right)
164167
t = np.linspace(0, 1, n_points)
165168
x_top = flow_x_left + (flow_x_right - flow_x_left) * t
166169
y_top = src_y_top + (tgt_y_top - src_y_top) * (3 * t**2 - 2 * t**3)
167170

168-
# Bottom edge (right to left)
169171
x_bottom = flow_x_right + (flow_x_left - flow_x_right) * t
170172
y_bottom = tgt_y_bottom + (src_y_bottom - tgt_y_bottom) * (3 * t**2 - 2 * t**3)
171173

172-
# Combine into polygon
173174
x_polygon = np.concatenate([x_top, x_bottom])
174175
y_polygon = np.concatenate([y_top, y_bottom])
175176

176177
for i in range(len(x_polygon)):
177178
flow_polygons.append({"x": x_polygon[i], "y": y_polygon[i], "flow_id": f"{src}_{tgt}", "source": src})
178179

179-
# Store center position for flow label (slightly offset from center based on source index)
180-
# This helps avoid label overlap when flows cross
181180
mid_idx = n_points // 2
182181
flow_center_y = (y_top[mid_idx] + y_bottom[n_points - 1 - mid_idx]) / 2
183182
src_idx = sources.index(src)
184-
label_x_offset = 0.35 + src_idx * 0.1 # Stagger labels across flow width
183+
label_x_offset = 0.35 + src_idx * 0.1
185184
flow_labels.append({"x": label_x_offset, "y": flow_center_y, "value": str(val), "flow_height": flow_height})
186185

187186
flows_df = pd.DataFrame(flow_polygons)
@@ -192,7 +191,7 @@
192191
ggplot()
193192
# Flow polygons with transparency
194193
+ geom_polygon(flows_df, aes(x="x", y="y", group="flow_id", fill="source"), alpha=0.5)
195-
# Node rectangles with individual colors
194+
# Node rectangles
196195
+ geom_rect(
197196
nodes_df, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax", fill="node_color"), color="white", size=0.5
198197
)
@@ -203,56 +202,43 @@
203202
ha="center",
204203
va="center",
205204
size=11,
206-
color="#333333",
205+
color=INK,
207206
fontweight="bold",
208207
)
209-
# Source labels (right-aligned) - increased font size
208+
# Source labels (right-aligned)
210209
+ geom_text(
211210
nodes_df[nodes_df["side"] == "source"],
212211
aes(x="label_x", y="label_y", label="name"),
213212
ha="right",
214213
size=16,
215-
color="#333333",
214+
color=INK,
216215
fontweight="bold",
217216
)
218-
# Target labels (left-aligned) - increased font size
217+
# Target labels (left-aligned)
219218
+ geom_text(
220219
nodes_df[nodes_df["side"] == "target"],
221220
aes(x="label_x", y="label_y", label="name"),
222221
ha="left",
223222
size=16,
224-
color="#333333",
223+
color=INK,
225224
fontweight="bold",
226225
)
227-
# Color scales - sources for flows, all nodes for rectangles
228-
+ scale_fill_manual(
229-
values={
230-
# Source colors (for flows and source nodes)
231-
"Coal": "#306998",
232-
"Gas": "#FFD43B",
233-
"Nuclear": "#4ECDC4",
234-
"Renewables": "#2ECC71",
235-
# Target colors (for target nodes)
236-
"Industrial": "#E74C3C",
237-
"Commercial": "#9B59B6",
238-
"Residential": "#F39C12",
239-
}
240-
)
241-
+ labs(title="Energy Flow · sankey-basic · plotnine · pyplots.ai", x="", y="")
226+
+ scale_fill_manual(values={**source_colors_map, **target_colors_map})
227+
+ labs(title="Energy Flow · sankey-basic · plotnine · anyplot.ai", x="", y="")
242228
+ coord_cartesian(xlim=(-0.05, 1.1))
243229
+ theme_minimal()
244230
+ theme(
245231
figure_size=(16, 9),
246-
plot_title=element_text(size=24, ha="center", weight="bold"),
232+
plot_background=element_rect(fill=PAGE_BG, color=PAGE_BG),
233+
panel_background=element_rect(fill=PAGE_BG),
234+
plot_title=element_text(size=24, ha="center", weight="bold", color=INK),
247235
axis_text=element_blank(),
248236
axis_ticks=element_blank(),
249237
panel_grid=element_blank(),
250238
legend_position="none",
251239
)
252-
+ annotate("text", x=x_left + node_width / 2, y=-0.05, label="Sources", size=16, color="#555555", fontweight="bold")
253-
+ annotate(
254-
"text", x=x_right - node_width / 2, y=-0.05, label="Sectors", size=16, color="#555555", fontweight="bold"
255-
)
240+
+ annotate("text", x=x_left + node_width / 2, y=-0.05, label="Sources", size=16, color=INK_SOFT, fontweight="bold")
241+
+ annotate("text", x=x_right - node_width / 2, y=-0.05, label="Sectors", size=16, color=INK_SOFT, fontweight="bold")
256242
)
257243

258-
plot.save("plot.png", dpi=300, verbose=False)
244+
plot.save(f"plot-{THEME}.png", dpi=300, verbose=False)

0 commit comments

Comments
 (0)