|
| 1 | +""" pyplots.ai |
| 2 | +roc-curve: ROC Curve with AUC |
| 3 | +Library: altair 6.0.0 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import altair as alt |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | + |
| 11 | + |
| 12 | +# Data - Generate synthetic classification scores and compute ROC curve |
| 13 | +np.random.seed(42) |
| 14 | +n_samples = 500 |
| 15 | +n_thresholds = 200 |
| 16 | + |
| 17 | +# Simulate two models with different performance levels |
| 18 | +# Model 1 (Good): higher separation between classes |
| 19 | +y_true = np.concatenate([np.zeros(n_samples // 2), np.ones(n_samples // 2)]) |
| 20 | +scores_model1 = np.where( |
| 21 | + y_true == 1, |
| 22 | + np.random.beta(5, 2, n_samples), # Positive class scores shifted higher |
| 23 | + np.random.beta(2, 5, n_samples), # Negative class scores shifted lower |
| 24 | +) |
| 25 | + |
| 26 | +# Model 2 (Moderate): less separation |
| 27 | +scores_model2 = np.where(y_true == 1, np.random.beta(3, 2, n_samples), np.random.beta(2, 3, n_samples)) |
| 28 | + |
| 29 | +# Compute ROC curve for Model 1 |
| 30 | +thresholds = np.linspace(0, 1, n_thresholds) |
| 31 | +tpr1_list, fpr1_list = [], [] |
| 32 | +for thresh in thresholds: |
| 33 | + y_pred = (scores_model1 >= thresh).astype(int) |
| 34 | + tp = np.sum((y_pred == 1) & (y_true == 1)) |
| 35 | + fp = np.sum((y_pred == 1) & (y_true == 0)) |
| 36 | + fn = np.sum((y_pred == 0) & (y_true == 1)) |
| 37 | + tn = np.sum((y_pred == 0) & (y_true == 0)) |
| 38 | + tpr1_list.append(tp / (tp + fn) if (tp + fn) > 0 else 0) |
| 39 | + fpr1_list.append(fp / (fp + tn) if (fp + tn) > 0 else 0) |
| 40 | +fpr1 = np.array(fpr1_list) |
| 41 | +tpr1 = np.array(tpr1_list) |
| 42 | + |
| 43 | +# Compute ROC curve for Model 2 |
| 44 | +tpr2_list, fpr2_list = [], [] |
| 45 | +for thresh in thresholds: |
| 46 | + y_pred = (scores_model2 >= thresh).astype(int) |
| 47 | + tp = np.sum((y_pred == 1) & (y_true == 1)) |
| 48 | + fp = np.sum((y_pred == 1) & (y_true == 0)) |
| 49 | + fn = np.sum((y_pred == 0) & (y_true == 1)) |
| 50 | + tn = np.sum((y_pred == 0) & (y_true == 0)) |
| 51 | + tpr2_list.append(tp / (tp + fn) if (tp + fn) > 0 else 0) |
| 52 | + fpr2_list.append(fp / (fp + tn) if (fp + tn) > 0 else 0) |
| 53 | +fpr2 = np.array(fpr2_list) |
| 54 | +tpr2 = np.array(tpr2_list) |
| 55 | + |
| 56 | +# Compute AUC using trapezoidal rule |
| 57 | +auc1 = -np.trapezoid(tpr1, fpr1) |
| 58 | +auc2 = -np.trapezoid(tpr2, fpr2) |
| 59 | + |
| 60 | +# Create labels for legend |
| 61 | +label1 = f"Good Model (AUC = {auc1:.2f})" |
| 62 | +label2 = f"Moderate Model (AUC = {auc2:.2f})" |
| 63 | +label_random = "Random (AUC = 0.50)" |
| 64 | + |
| 65 | +# Create DataFrames for Altair |
| 66 | +df_model1 = pd.DataFrame({"fpr": fpr1, "tpr": tpr1, "Model": label1}) |
| 67 | +df_model2 = pd.DataFrame({"fpr": fpr2, "tpr": tpr2, "Model": label2}) |
| 68 | +df_roc = pd.concat([df_model1, df_model2], ignore_index=True) |
| 69 | + |
| 70 | +# Diagonal reference line (random classifier) |
| 71 | +df_diagonal = pd.DataFrame({"fpr": [0, 1], "tpr": [0, 1], "Model": label_random}) |
| 72 | + |
| 73 | +# Create ROC curves |
| 74 | +roc_lines = ( |
| 75 | + alt.Chart(df_roc) |
| 76 | + .mark_line(strokeWidth=4) |
| 77 | + .encode( |
| 78 | + x=alt.X("fpr:Q", title="False Positive Rate", scale=alt.Scale(domain=[0, 1])), |
| 79 | + y=alt.Y("tpr:Q", title="True Positive Rate", scale=alt.Scale(domain=[0, 1])), |
| 80 | + color=alt.Color( |
| 81 | + "Model:N", |
| 82 | + scale=alt.Scale(domain=[label1, label2, label_random], range=["#306998", "#FFD43B", "#888888"]), |
| 83 | + legend=alt.Legend( |
| 84 | + orient="none", |
| 85 | + legendX=930, |
| 86 | + legendY=1150, |
| 87 | + direction="vertical", |
| 88 | + titleFontSize=20, |
| 89 | + labelFontSize=18, |
| 90 | + symbolStrokeWidth=4, |
| 91 | + symbolSize=400, |
| 92 | + labelLimit=400, |
| 93 | + ), |
| 94 | + ), |
| 95 | + ) |
| 96 | +) |
| 97 | + |
| 98 | +# Diagonal reference line |
| 99 | +diagonal_line = ( |
| 100 | + alt.Chart(df_diagonal) |
| 101 | + .mark_line(strokeWidth=3, strokeDash=[8, 6]) |
| 102 | + .encode( |
| 103 | + x="fpr:Q", |
| 104 | + y="tpr:Q", |
| 105 | + color=alt.Color( |
| 106 | + "Model:N", scale=alt.Scale(domain=[label1, label2, label_random], range=["#306998", "#FFD43B", "#888888"]) |
| 107 | + ), |
| 108 | + ) |
| 109 | +) |
| 110 | + |
| 111 | +# Combine charts |
| 112 | +chart = ( |
| 113 | + (roc_lines + diagonal_line) |
| 114 | + .properties(width=1400, height=1400, title="roc-curve · altair · pyplots.ai") |
| 115 | + .configure_title(fontSize=32, anchor="middle", fontWeight="bold") |
| 116 | + .configure_axis( |
| 117 | + labelFontSize=18, titleFontSize=22, titlePadding=15, labelPadding=10, gridOpacity=0.3, gridDash=[4, 4] |
| 118 | + ) |
| 119 | + .configure_view(strokeWidth=0) |
| 120 | +) |
| 121 | + |
| 122 | +# Save outputs |
| 123 | +chart.save("plot.png", scale_factor=2.5) |
| 124 | +chart.save("plot.html") |
0 commit comments