Skip to content

Commit 43563ea

Browse files
feat(plotly): implement line-loss-training (#6650)
## Implementation: `line-loss-training` - python/plotly Implements the **python/plotly** version of `line-loss-training`. **File:** `plots/line-loss-training/implementations/python/plotly.py` **Parent Issue:** #2860 --- :robot: *[impl-generate workflow](https://github.com/MarkusNeusinger/anyplot/actions/runs/25843729734)* --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Neusinger <2921697+MarkusNeusinger@users.noreply.github.com>
1 parent e467e9e commit 43563ea

2 files changed

Lines changed: 217 additions & 176 deletions

File tree

Lines changed: 75 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,46 @@
1-
""" pyplots.ai
1+
""" anyplot.ai
22
line-loss-training: Training Loss Curve
3-
Library: plotly 6.5.0 | Python 3.13.11
4-
Quality: 92/100 | Created: 2025-12-31
3+
Library: plotly 6.7.0 | Python 3.13.13
4+
Quality: 91/100 | Updated: 2026-05-14
55
"""
66

7+
import os
8+
79
import numpy as np
810
import plotly.graph_objects as go
911

1012

11-
# Data - Simulated neural network training history
13+
# Theme tokens
14+
THEME = os.getenv("ANYPLOT_THEME", "light")
15+
PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17"
16+
ELEVATED_BG = "#FFFDF6" if THEME == "light" else "#242420"
17+
INK = "#1A1A17" if THEME == "light" else "#F0EFE8"
18+
INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0"
19+
GRID = "rgba(26,26,23,0.10)" if THEME == "light" else "rgba(240,239,232,0.10)"
20+
21+
# Okabe-Ito palette
22+
TRAIN_COLOR = "#009E73" # Position 1 - bluish green (brand)
23+
VAL_COLOR = "#D55E00" # Position 2 - vermillion
24+
25+
# Data - Simulated neural network training with different trajectory
1226
np.random.seed(42)
13-
epochs = np.arange(1, 101)
27+
epochs = np.arange(1, 71) # 70 epochs (differentiate from 100-epoch letsplot)
1428

15-
# Training loss: starts high, decreases with noise, eventually plateaus
16-
train_loss = 2.5 * np.exp(-0.05 * epochs) + 0.15 + np.random.normal(0, 0.02, len(epochs))
17-
train_loss = np.maximum(train_loss, 0.1) # Ensure positive
29+
# Training loss: linear-like decay with small noise, flattens near end
30+
train_base = 2.0 - 0.025 * epochs + np.random.normal(0, 0.025, len(epochs))
31+
train_loss = np.maximum(train_base, 0.1)
1832

19-
# Validation loss: follows training initially, then diverges (overfitting after epoch ~60)
20-
val_loss = 2.5 * np.exp(-0.045 * epochs) + 0.25 + np.random.normal(0, 0.03, len(epochs))
21-
# Add overfitting effect: validation loss starts increasing after epoch 60
22-
overfitting_effect = np.where(epochs > 60, 0.008 * (epochs - 60), 0)
23-
val_loss = val_loss + overfitting_effect
33+
# Validation loss: similar pattern but with larger noise and divergence after epoch ~45
34+
val_base = 2.0 - 0.020 * epochs + np.random.normal(0, 0.04, len(epochs))
35+
# Add gentle divergence effect
36+
divergence_effect = np.where(epochs > 45, 0.015 * np.sqrt(np.maximum(epochs - 45, 0)), 0)
37+
val_loss = val_base + divergence_effect
2438
val_loss = np.maximum(val_loss, 0.15)
2539

26-
# Find minimum validation loss epoch for annotation
27-
min_val_epoch = epochs[np.argmin(val_loss)]
28-
min_val_loss = np.min(val_loss)
40+
# Find minimum validation loss epoch
41+
min_val_idx = np.argmin(val_loss)
42+
min_val_epoch = epochs[min_val_idx]
43+
min_val_loss = val_loss[min_val_idx]
2944

3045
# Create figure
3146
fig = go.Figure()
@@ -37,8 +52,8 @@
3752
y=train_loss,
3853
mode="lines",
3954
name="Training Loss",
40-
line=dict(color="#306998", width=3),
41-
hovertemplate="Epoch %{x}<br>Training Loss: %{y:.4f}<extra></extra>",
55+
line=dict(color=TRAIN_COLOR, width=4),
56+
hovertemplate="Epoch %{x}<br>Training Loss: %{y:.3f}<extra></extra>",
4257
)
4358
)
4459

@@ -49,57 +64,66 @@
4964
y=val_loss,
5065
mode="lines",
5166
name="Validation Loss",
52-
line=dict(color="#FFD43B", width=3),
53-
hovertemplate="Epoch %{x}<br>Validation Loss: %{y:.4f}<extra></extra>",
67+
line=dict(color=VAL_COLOR, width=4),
68+
hovertemplate="Epoch %{x}<br>Validation Loss: %{y:.3f}<extra></extra>",
5469
)
5570
)
5671

57-
# Mark minimum validation loss point
72+
# Optimal stopping point marker
5873
fig.add_trace(
5974
go.Scatter(
6075
x=[min_val_epoch],
6176
y=[min_val_loss],
62-
mode="markers+text",
63-
name="Best Epoch",
64-
marker=dict(color="#E74C3C", size=16, symbol="star"),
65-
text=[f"Best: Epoch {min_val_epoch}"],
66-
textposition="top center",
67-
textfont=dict(size=16, color="#E74C3C"),
68-
hovertemplate="Best Epoch: %{x}<br>Min Val Loss: %{y:.4f}<extra></extra>",
77+
mode="markers",
78+
name="Optimal Epoch",
79+
marker=dict(color=VAL_COLOR, size=20, symbol="diamond", line=dict(color=INK, width=2)),
80+
hovertemplate="Optimal Epoch: %{x}<br>Min Validation Loss: %{y:.3f}<extra></extra>",
6981
)
7082
)
7183

72-
# Update layout
84+
# Add vertical line at optimal epoch using shape
85+
fig.add_shape(
86+
type="line",
87+
x0=min_val_epoch,
88+
x1=min_val_epoch,
89+
y0=0,
90+
y1=max(train_loss.max(), val_loss.max()),
91+
line=dict(color=VAL_COLOR, width=1.5, dash="dash"),
92+
opacity=0.3,
93+
)
94+
95+
# Update layout with theme-adaptive styling
7396
fig.update_layout(
74-
title=dict(text="line-loss-training · plotly · pyplots.ai", font=dict(size=28), x=0.5, xanchor="center"),
97+
title=dict(text="line-loss-training · plotly · anyplot.ai", font=dict(size=28, color=INK), x=0.5, xanchor="center"),
7598
xaxis=dict(
76-
title=dict(text="Epoch", font=dict(size=22)),
77-
tickfont=dict(size=18),
78-
gridcolor="rgba(128, 128, 128, 0.3)",
99+
title=dict(text="Epoch", font=dict(size=22, color=INK)),
100+
tickfont=dict(size=18, color=INK_SOFT),
101+
gridcolor=GRID,
79102
gridwidth=1,
80-
showgrid=True,
81-
range=[0, 105],
103+
linecolor=INK_SOFT,
104+
linewidth=1.5,
105+
zerolinecolor=INK_SOFT,
106+
zerolinewidth=0,
82107
),
83108
yaxis=dict(
84-
title=dict(text="Cross-Entropy Loss", font=dict(size=22)),
85-
tickfont=dict(size=18),
86-
gridcolor="rgba(128, 128, 128, 0.3)",
109+
title=dict(text="Cross-Entropy Loss", font=dict(size=22, color=INK)),
110+
tickfont=dict(size=18, color=INK_SOFT),
111+
gridcolor=GRID,
87112
gridwidth=1,
88-
showgrid=True,
113+
linecolor=INK_SOFT,
114+
linewidth=1.5,
115+
zerolinecolor=INK_SOFT,
116+
zerolinewidth=0,
89117
),
90118
legend=dict(
91-
font=dict(size=18),
92-
x=0.75,
93-
y=0.95,
94-
bgcolor="rgba(255, 255, 255, 0.8)",
95-
bordercolor="rgba(128, 128, 128, 0.3)",
96-
borderwidth=1,
119+
font=dict(size=18, color=INK_SOFT), bgcolor=ELEVATED_BG, bordercolor=INK_SOFT, borderwidth=1.5, x=0.72, y=0.97
97120
),
98-
template="plotly_white",
99-
margin=dict(l=100, r=80, t=100, b=100),
100-
plot_bgcolor="white",
121+
paper_bgcolor=PAGE_BG,
122+
plot_bgcolor=PAGE_BG,
123+
margin=dict(l=120, r=100, t=110, b=110),
124+
hovermode="x unified",
101125
)
102126

103-
# Save as PNG and HTML
104-
fig.write_image("plot.png", width=1600, height=900, scale=3)
105-
fig.write_html("plot.html", include_plotlyjs="cdn")
127+
# Save outputs
128+
fig.write_image(f"plot-{THEME}.png", width=1600, height=900, scale=3)
129+
fig.write_html(f"plot-{THEME}.html", include_plotlyjs="cdn")

0 commit comments

Comments
 (0)