|
| 1 | +""" pyplots.ai |
| 2 | +precision-recall: Precision-Recall Curve |
| 3 | +Library: altair 6.0.0 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import altair as alt |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | + |
| 11 | + |
| 12 | +# Data - Simulate precision-recall curves for two classifiers |
| 13 | +np.random.seed(42) |
| 14 | + |
| 15 | +# Generate recall values (from 1 to 0, as thresholds increase) |
| 16 | +n_points = 100 |
| 17 | +recall_vals = np.linspace(1, 0, n_points) |
| 18 | + |
| 19 | +# Simulate Logistic Regression PR curve |
| 20 | +# Good classifier: precision increases as recall decreases |
| 21 | +lr_precision = 0.3 + 0.65 * (1 - recall_vals) + np.random.normal(0, 0.02, n_points) |
| 22 | +lr_precision = np.clip(lr_precision, 0, 1) |
| 23 | +# Ensure monotonic-ish behavior with step-like pattern |
| 24 | +lr_precision = np.maximum.accumulate(lr_precision) |
| 25 | +lr_ap = np.trapezoid(lr_precision, recall_vals[::-1]) # Average Precision |
| 26 | + |
| 27 | +# Simulate Random Forest PR curve (better classifier) |
| 28 | +rf_precision = 0.4 + 0.58 * (1 - recall_vals) ** 0.7 + np.random.normal(0, 0.015, n_points) |
| 29 | +rf_precision = np.clip(rf_precision, 0, 1) |
| 30 | +rf_precision = np.maximum.accumulate(rf_precision) |
| 31 | +rf_ap = np.trapezoid(rf_precision, recall_vals[::-1]) |
| 32 | + |
| 33 | +# Baseline (positive class ratio - simulating ~30% positive class) |
| 34 | +baseline = 0.30 |
| 35 | + |
| 36 | +# Create DataFrames for Altair |
| 37 | +lr_df = pd.DataFrame( |
| 38 | + {"Recall": recall_vals, "Precision": lr_precision, "Model": f"Logistic Regression (AP = {lr_ap:.3f})"} |
| 39 | +) |
| 40 | + |
| 41 | +rf_df = pd.DataFrame({"Recall": recall_vals, "Precision": rf_precision, "Model": f"Random Forest (AP = {rf_ap:.3f})"}) |
| 42 | + |
| 43 | +# Combine classifier data |
| 44 | +curve_df = pd.concat([lr_df, rf_df], ignore_index=True) |
| 45 | + |
| 46 | +# Baseline data for reference line |
| 47 | +baseline_df = pd.DataFrame( |
| 48 | + {"Recall": [0.0, 1.0], "Precision": [baseline, baseline], "Model": f"Random Classifier (baseline = {baseline:.2f})"} |
| 49 | +) |
| 50 | + |
| 51 | +# Create precision-recall curves with stepped interpolation |
| 52 | +pr_curves = ( |
| 53 | + alt.Chart(curve_df) |
| 54 | + .mark_line(strokeWidth=4, interpolate="step-after") |
| 55 | + .encode( |
| 56 | + x=alt.X("Recall:Q", title="Recall", scale=alt.Scale(domain=[0, 1])), |
| 57 | + y=alt.Y("Precision:Q", title="Precision", scale=alt.Scale(domain=[0, 1])), |
| 58 | + color=alt.Color( |
| 59 | + "Model:N", |
| 60 | + scale=alt.Scale( |
| 61 | + domain=[ |
| 62 | + f"Logistic Regression (AP = {lr_ap:.3f})", |
| 63 | + f"Random Forest (AP = {rf_ap:.3f})", |
| 64 | + f"Random Classifier (baseline = {baseline:.2f})", |
| 65 | + ], |
| 66 | + range=["#306998", "#FFD43B", "#888888"], |
| 67 | + ), |
| 68 | + legend=alt.Legend( |
| 69 | + title="Model", |
| 70 | + titleFontSize=20, |
| 71 | + labelFontSize=16, |
| 72 | + labelLimit=400, |
| 73 | + orient="bottom-right", |
| 74 | + direction="vertical", |
| 75 | + offset=10, |
| 76 | + symbolStrokeWidth=4, |
| 77 | + symbolSize=300, |
| 78 | + ), |
| 79 | + ), |
| 80 | + strokeDash=alt.StrokeDash( |
| 81 | + "Model:N", |
| 82 | + scale=alt.Scale( |
| 83 | + domain=[ |
| 84 | + f"Logistic Regression (AP = {lr_ap:.3f})", |
| 85 | + f"Random Forest (AP = {rf_ap:.3f})", |
| 86 | + f"Random Classifier (baseline = {baseline:.2f})", |
| 87 | + ], |
| 88 | + range=[[0], [0], [8, 4]], # Solid for models, dashed for baseline |
| 89 | + ), |
| 90 | + legend=None, |
| 91 | + ), |
| 92 | + ) |
| 93 | +) |
| 94 | + |
| 95 | +# Baseline reference line |
| 96 | +baseline_line = ( |
| 97 | + alt.Chart(baseline_df) |
| 98 | + .mark_line(strokeWidth=3, strokeDash=[8, 4]) |
| 99 | + .encode(x=alt.X("Recall:Q"), y=alt.Y("Precision:Q"), color=alt.Color("Model:N", legend=None)) |
| 100 | +) |
| 101 | + |
| 102 | +# Combine layers |
| 103 | +chart = ( |
| 104 | + alt.layer(pr_curves, baseline_line) |
| 105 | + .properties( |
| 106 | + width=1600, height=900, title=alt.Title("precision-recall · altair · pyplots.ai", fontSize=28, anchor="middle") |
| 107 | + ) |
| 108 | + .configure_axis(labelFontSize=18, titleFontSize=22, gridColor="#CCCCCC", gridOpacity=0.3) |
| 109 | + .configure_view(strokeWidth=0) |
| 110 | +) |
| 111 | + |
| 112 | +# Save as PNG and HTML |
| 113 | +chart.save("plot.png", scale_factor=3.0) |
| 114 | +chart.save("plot.html") |
0 commit comments