Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 42 additions & 59 deletions plots/sankey-basic/implementations/python/letsplot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
""" pyplots.ai
""" anyplot.ai
sankey-basic: Basic Sankey Diagram
Library: letsplot 4.8.2 | Python 3.13.11
Quality: 92/100 | Created: 2025-12-23
Library: letsplot 4.9.0 | Python 3.13.13
Quality: 85/100 | Updated: 2026-04-30
"""

import os

import pandas as pd
from lets_plot import (
LetsPlot,
aes,
element_blank,
element_rect,
element_text,
geom_polygon,
geom_rect,
Expand All @@ -27,6 +30,16 @@

LetsPlot.setup_html()

# Theme tokens
THEME = os.getenv("ANYPLOT_THEME", "light")
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
ELEVATED_BG = "#FFFDF6" if THEME == "light" else "#242420"
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0"

# Okabe-Ito palette for source categories (canonical order, first = #009E73)
OKABE_ITO = ["#009E73", "#D55E00", "#0072B2", "#CC79A7"]

# Energy flow data: sources -> sectors (realistic energy distribution)
flows = [
("Coal", "Industrial", 28),
Expand All @@ -41,9 +54,9 @@
("Renewable", "Industrial", 6),
]

# Define node ordering
sources = ["Coal", "Natural Gas", "Nuclear", "Renewable"]
targets = ["Industrial", "Residential", "Commercial"]
source_color_map = dict(zip(sources, OKABE_ITO, strict=True))

# Calculate totals for each node
source_totals = {}
Expand All @@ -54,7 +67,7 @@
for _, tgt, val in flows:
target_totals[tgt] = target_totals.get(tgt, 0) + val

# Normalize positions
# Layout parameters
total_flow = sum(v for _, _, v in flows)
node_gap = 0.04
x_left = 0.18
Expand All @@ -80,43 +93,33 @@
source_offsets = dict.fromkeys(sources, 0)
target_offsets = dict.fromkeys(targets, 0)

# Build flow polygons with smooth bezier curves
# Build flow polygons with smooth cubic bezier curves
flow_data = []

for src, tgt, val in flows:
flow_height = val / total_flow * 0.85

# Source connection points
src_y0 = source_positions[src]["y0"] + source_offsets[src]
src_y1 = src_y0 + flow_height
source_offsets[src] += flow_height

# Target connection points
tgt_y0 = target_positions[tgt]["y0"] + target_offsets[tgt]
tgt_y1 = tgt_y0 + flow_height
target_offsets[tgt] += flow_height

# Create smooth bezier polygon for flow
n_points = 40
x_vals_top = []
y_vals_top = []
x_vals_bottom = []
y_vals_bottom = []
x_vals_top, y_vals_top = [], []
x_vals_bottom, y_vals_bottom = [], []

for i in range(n_points + 1):
t = i / n_points
x = x_left + t * (x_right - x_left)
# Smooth cubic bezier easing
ease = t * t * (3 - 2 * t)
y_top = src_y1 + ease * (tgt_y1 - src_y1)
y_bottom = src_y0 + ease * (tgt_y0 - src_y0)

x_vals_top.append(x)
y_vals_top.append(y_top)
y_vals_top.append(src_y1 + ease * (tgt_y1 - src_y1))
x_vals_bottom.append(x)
y_vals_bottom.append(y_bottom)
y_vals_bottom.append(src_y0 + ease * (tgt_y0 - src_y0))

# Combine into closed polygon
x_polygon = x_vals_top + x_vals_bottom[::-1]
y_polygon = y_vals_top + y_vals_bottom[::-1]

Expand All @@ -132,32 +135,18 @@
for src in sources:
pos = source_positions[src]
node_rects.append(
{
"xmin": pos["x"] - node_width / 2,
"xmax": pos["x"] + node_width / 2,
"ymin": pos["y0"],
"ymax": pos["y1"],
"label": src,
"side": "source",
}
{"xmin": pos["x"] - node_width / 2, "xmax": pos["x"] + node_width / 2, "ymin": pos["y0"], "ymax": pos["y1"]}
)

for tgt in targets:
pos = target_positions[tgt]
node_rects.append(
{
"xmin": pos["x"] - node_width / 2,
"xmax": pos["x"] + node_width / 2,
"ymin": pos["y0"],
"ymax": pos["y1"],
"label": tgt,
"side": "target",
}
{"xmin": pos["x"] - node_width / 2, "xmax": pos["x"] + node_width / 2, "ymin": pos["y0"], "ymax": pos["y1"]}
)

df_nodes = pd.DataFrame(node_rects)

# Build labels with flow values
# Build labels with flow totals
labels = []
for src in sources:
pos = source_positions[src]
Expand All @@ -183,56 +172,50 @@

df_labels = pd.DataFrame(labels)

# Colors for each energy source
source_colors = {"Coal": "#4A4A4A", "Natural Gas": "#306998", "Nuclear": "#9B59B6", "Renewable": "#27AE60"}

# Create the plot
# Plot
plot = (
ggplot()
+ geom_polygon(
aes(x="x", y="y", group="flow_id", fill="source"), data=df_flows, alpha=0.65, color="white", size=0.2
)
+ geom_rect(
aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"),
data=df_nodes,
fill="#2C3E50",
color="#1A252F",
size=1.5,
aes(x="x", y="y", group="flow_id", fill="source"), data=df_flows, alpha=0.65, color=PAGE_BG, size=0.2
)
+ geom_rect(aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"), data=df_nodes, fill=INK, color=INK, size=1.5)
+ geom_text(
aes(x="x", y="y", label="label"),
data=df_labels[df_labels["side"] == "left"],
size=14,
hjust=1,
color=INK_SOFT,
family="sans-serif",
)
+ geom_text(
aes(x="x", y="y", label="label"),
data=df_labels[df_labels["side"] == "right"],
size=14,
hjust=0,
color=INK_SOFT,
family="sans-serif",
)
+ scale_fill_manual(values=[source_colors[s] for s in sources], name="Energy Source")
+ labs(title="Energy Flow · sankey-basic · letsplot · pyplots.ai")
+ scale_fill_manual(values=[source_color_map[s] for s in sources], name="Energy Source")
+ labs(title="Energy Flow · sankey-basic · letsplot · anyplot.ai")
+ theme_minimal()
+ theme(
plot_title=element_text(size=30, face="bold"),
plot_background=element_rect(fill=PAGE_BG, color=PAGE_BG),
panel_background=element_rect(fill=PAGE_BG),
plot_title=element_text(size=30, face="bold", color=INK),
axis_title=element_blank(),
axis_text=element_blank(),
axis_ticks=element_blank(),
panel_grid=element_blank(),
legend_text=element_text(size=18),
legend_title=element_text(size=20, face="bold"),
legend_text=element_text(size=18, color=INK_SOFT),
legend_title=element_text(size=20, face="bold", color=INK),
legend_position="bottom",
legend_background=element_rect(fill=ELEVATED_BG, color=INK_SOFT),
)
+ scale_x_continuous(limits=[-0.02, 1.02])
+ scale_y_continuous(limits=[-0.02, 1.02])
+ ggsize(1600, 900)
)

# Save as PNG (scale 3x for 4800 × 2700 px)
ggsave(plot, "plot.png", path=".", scale=3)

# Save as HTML for interactivity
ggsave(plot, "plot.html", path=".")
# Save PNG (scale 3x for 4800 × 2700 px) and HTML
ggsave(plot, f"plot-{THEME}.png", path=".", scale=3)
ggsave(plot, f"plot-{THEME}.html", path=".")
Loading
Loading