|
| 1 | +""" anyplot.ai |
| 2 | +line-loss-training: Training Loss Curve |
| 3 | +Library: bokeh 3.9.0 | Python 3.13.13 |
| 4 | +Quality: 95/100 | Created: 2026-05-14 |
| 5 | +""" |
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | +import time |
| 10 | +from pathlib import Path |
| 11 | + |
| 12 | + |
| 13 | +# Prevent script name from shadowing bokeh module |
| 14 | +script_dir = str(Path(__file__).parent) |
| 15 | +if script_dir in sys.path: |
| 16 | + sys.path.remove(script_dir) |
| 17 | + |
| 18 | +import numpy as np # noqa: E402 |
| 19 | +import pandas as pd # noqa: E402 |
| 20 | +from bokeh.io import output_file, save # noqa: E402 |
| 21 | +from bokeh.models import ColumnDataSource, HoverTool, Label, Quad # noqa: E402 |
| 22 | +from bokeh.plotting import figure # noqa: E402 |
| 23 | +from selenium import webdriver # noqa: E402 |
| 24 | +from selenium.webdriver.chrome.options import Options # noqa: E402 |
| 25 | + |
| 26 | + |
| 27 | +# Theme tokens (see prompts/default-style-guide.md) |
| 28 | +THEME = os.getenv("ANYPLOT_THEME", "light") |
| 29 | +PAGE_BG = "#FAF8F1" if THEME == "light" else "#1A1A17" |
| 30 | +ELEVATED_BG = "#FFFDF6" if THEME == "light" else "#242420" |
| 31 | +INK = "#1A1A17" if THEME == "light" else "#F0EFE8" |
| 32 | +INK_SOFT = "#4A4A44" if THEME == "light" else "#B8B7B0" |
| 33 | + |
| 34 | +# Okabe-Ito palette |
| 35 | +TRAIN_COLOR = "#009E73" # bluish green - first series (brand) |
| 36 | +VAL_COLOR = "#D55E00" # vermillion - second series |
| 37 | + |
| 38 | +# Generate realistic neural network training data |
| 39 | +np.random.seed(42) |
| 40 | +n_epochs = 150 |
| 41 | + |
| 42 | +# Training loss: smooth exponential decay with noise |
| 43 | +epochs = np.arange(1, n_epochs + 1) |
| 44 | +train_loss_base = 2.5 * np.exp(-0.015 * (epochs - 1)) + 0.15 |
| 45 | +train_loss = train_loss_base + np.random.normal(0, 0.02, n_epochs) |
| 46 | +train_loss = np.maximum(train_loss, 0.15) # ensure positive |
| 47 | + |
| 48 | +# Validation loss: slightly noisier, higher baseline, potential overfitting |
| 49 | +val_loss_base = 2.5 * np.exp(-0.012 * (epochs - 1)) + 0.2 |
| 50 | +val_loss = val_loss_base + np.random.normal(0, 0.035, n_epochs) |
| 51 | +# Add slight overfitting effect in later epochs |
| 52 | +val_loss[80:] += np.linspace(0, 0.08, n_epochs - 80) |
| 53 | +val_loss = np.maximum(val_loss, 0.18) |
| 54 | + |
| 55 | +# Find minimum validation loss epoch (for annotation) |
| 56 | +min_val_idx = np.argmin(val_loss) |
| 57 | +min_val_epoch = epochs[min_val_idx] |
| 58 | +min_val_loss = val_loss[min_val_idx] |
| 59 | + |
| 60 | +# Create DataFrame |
| 61 | +df = pd.DataFrame({"epoch": epochs, "train_loss": train_loss, "val_loss": val_loss}) |
| 62 | + |
| 63 | +# Create Bokeh figure |
| 64 | +title_text = "line-loss-training · bokeh · anyplot.ai" |
| 65 | +p = figure( |
| 66 | + width=4800, |
| 67 | + height=2700, |
| 68 | + title=title_text, |
| 69 | + x_axis_label="Epoch", |
| 70 | + y_axis_label="Loss (Cross-Entropy)", |
| 71 | + toolbar_location="right", |
| 72 | +) |
| 73 | + |
| 74 | +# Set up data sources |
| 75 | +train_source = ColumnDataSource(df[["epoch", "train_loss"]]) |
| 76 | +val_source = ColumnDataSource(df[["epoch", "val_loss"]]) |
| 77 | + |
| 78 | +# Create a shaded region to highlight potential overfitting area (after epoch 80) |
| 79 | +overfitting_start = 80 |
| 80 | +max_loss = max(df["val_loss"].max(), df["train_loss"].max()) |
| 81 | +overfitting_quad = p.quad( |
| 82 | + left=[overfitting_start], |
| 83 | + right=[n_epochs], |
| 84 | + bottom=[0], |
| 85 | + top=[max_loss], |
| 86 | + fill_alpha=0.08, |
| 87 | + fill_color=VAL_COLOR, |
| 88 | + line_color=None, |
| 89 | + level="underlay", |
| 90 | +) |
| 91 | + |
| 92 | +# Plot lines |
| 93 | +train_line = p.line( |
| 94 | + x="epoch", |
| 95 | + y="train_loss", |
| 96 | + source=train_source, |
| 97 | + line_width=4, |
| 98 | + color=TRAIN_COLOR, |
| 99 | + legend_label="Training Loss", |
| 100 | + muted_color=TRAIN_COLOR, |
| 101 | + muted_alpha=0.15, |
| 102 | +) |
| 103 | + |
| 104 | +val_line = p.line( |
| 105 | + x="epoch", |
| 106 | + y="val_loss", |
| 107 | + source=val_source, |
| 108 | + line_width=4, |
| 109 | + color=VAL_COLOR, |
| 110 | + legend_label="Validation Loss", |
| 111 | + muted_color=VAL_COLOR, |
| 112 | + muted_alpha=0.15, |
| 113 | +) |
| 114 | + |
| 115 | +# Add circle markers at data points |
| 116 | +p.scatter( |
| 117 | + x="epoch", |
| 118 | + y="train_loss", |
| 119 | + source=train_source, |
| 120 | + size=5, |
| 121 | + color=TRAIN_COLOR, |
| 122 | + alpha=0.6, |
| 123 | + hover_color=TRAIN_COLOR, |
| 124 | + hover_alpha=1.0, |
| 125 | +) |
| 126 | + |
| 127 | +p.scatter( |
| 128 | + x="epoch", |
| 129 | + y="val_loss", |
| 130 | + source=val_source, |
| 131 | + size=5, |
| 132 | + color=VAL_COLOR, |
| 133 | + alpha=0.6, |
| 134 | + hover_color=VAL_COLOR, |
| 135 | + hover_alpha=1.0, |
| 136 | +) |
| 137 | + |
| 138 | +# Mark the epoch with minimum validation loss - larger marker for emphasis |
| 139 | +optimal_marker = p.scatter( |
| 140 | + x=[min_val_epoch], |
| 141 | + y=[min_val_loss], |
| 142 | + size=20, |
| 143 | + color=VAL_COLOR, |
| 144 | + line_color=INK, |
| 145 | + line_width=3, |
| 146 | + alpha=1.0, |
| 147 | + legend_label=f"Optimal epoch: {min_val_epoch}", |
| 148 | +) |
| 149 | + |
| 150 | +# Add annotation label at the optimal epoch |
| 151 | +label = Label( |
| 152 | + x=min_val_epoch, |
| 153 | + y=min_val_loss, |
| 154 | + text=f" Epoch {min_val_epoch}\n Loss {min_val_loss:.4f}", |
| 155 | + text_color=INK, |
| 156 | + text_font_size="14pt", |
| 157 | + text_baseline="middle", |
| 158 | + text_align="left", |
| 159 | +) |
| 160 | +p.add_layout(label) |
| 161 | + |
| 162 | +# Add detailed hover tool |
| 163 | +hover = HoverTool(tooltips=[("Epoch", "@epoch{0}"), ("Loss", "@y{0.0000}")], mode="vline") |
| 164 | +p.add_tools(hover) |
| 165 | + |
| 166 | +# Apply text sizing |
| 167 | +p.title.text_font_size = "28pt" |
| 168 | +p.xaxis.axis_label_text_font_size = "22pt" |
| 169 | +p.yaxis.axis_label_text_font_size = "22pt" |
| 170 | +p.xaxis.major_label_text_font_size = "18pt" |
| 171 | +p.yaxis.major_label_text_font_size = "18pt" |
| 172 | + |
| 173 | +# Apply theme-adaptive chrome colors |
| 174 | +p.background_fill_color = PAGE_BG |
| 175 | +p.border_fill_color = PAGE_BG |
| 176 | +p.outline_line_color = INK_SOFT |
| 177 | + |
| 178 | +p.title.text_color = INK |
| 179 | +p.xaxis.axis_label_text_color = INK |
| 180 | +p.yaxis.axis_label_text_color = INK |
| 181 | +p.xaxis.major_label_text_color = INK_SOFT |
| 182 | +p.yaxis.major_label_text_color = INK_SOFT |
| 183 | +p.xaxis.axis_line_color = INK_SOFT |
| 184 | +p.yaxis.axis_line_color = INK_SOFT |
| 185 | +p.xaxis.major_tick_line_color = INK_SOFT |
| 186 | +p.yaxis.major_tick_line_color = INK_SOFT |
| 187 | + |
| 188 | +# Y-axis grid (for line charts per style guide) |
| 189 | +p.ygrid.grid_line_color = INK |
| 190 | +p.ygrid.grid_line_alpha = 0.10 |
| 191 | +p.xgrid.grid_line_color = INK |
| 192 | +p.xgrid.grid_line_alpha = 0.05 |
| 193 | + |
| 194 | +# Configure legend |
| 195 | +p.legend.background_fill_color = ELEVATED_BG |
| 196 | +p.legend.border_line_color = INK_SOFT |
| 197 | +p.legend.label_text_color = INK_SOFT |
| 198 | +p.legend.location = "top_right" |
| 199 | +p.legend.click_policy = "mute" |
| 200 | +p.legend.label_text_font_size = "16pt" |
| 201 | + |
| 202 | +# Save HTML |
| 203 | +output_file(f"plot-{THEME}.html") |
| 204 | +save(p) |
| 205 | + |
| 206 | +# Screenshot with headless Chrome using Selenium |
| 207 | +W, H = 4800, 2700 |
| 208 | +opts = Options() |
| 209 | +for arg in ( |
| 210 | + "--headless=new", |
| 211 | + "--no-sandbox", |
| 212 | + "--disable-dev-shm-usage", |
| 213 | + "--disable-gpu", |
| 214 | + f"--window-size={W},{H}", |
| 215 | + "--hide-scrollbars", |
| 216 | +): |
| 217 | + opts.add_argument(arg) |
| 218 | + |
| 219 | +driver = webdriver.Chrome(options=opts) |
| 220 | +driver.set_window_size(W, H) |
| 221 | +driver.get(f"file://{Path(f'plot-{THEME}.html').resolve()}") |
| 222 | +time.sleep(3) # let bokeh's JS render the canvas |
| 223 | +driver.save_screenshot(f"plot-{THEME}.png") |
| 224 | +driver.quit() |
0 commit comments