|
""" pyplots.ai
precision-recall: Precision-Recall Curve
Library: plotly 6.5.0 | Python 3.13.11
Quality: 92/100 | Created: 2025-12-26
"""

import numpy as np
import plotly.graph_objects as go
from sklearn.metrics import average_precision_score, precision_recall_curve


# Data - Simulate a binary classification scenario (fraud detection)
np.random.seed(42)  # Fixed seed so the generated plot is reproducible
n_samples = 1000

# Imbalanced dataset: 10% positive class (fraud cases)
y_true = np.zeros(n_samples, dtype=int)
y_true[:100] = 1
np.random.shuffle(y_true)

# Generate prediction scores - good classifier with some noise.
# Both beta samples are drawn for every row; np.where then picks the
# distribution that matches each row's label.
y_scores = np.where(
    y_true == 1,
    np.random.beta(5, 2, n_samples),  # Higher scores for positive class
    np.random.beta(2, 5, n_samples),  # Lower scores for negative class
)

# Calculate precision-recall curve (thresholds are returned but not plotted)
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
average_precision = average_precision_score(y_true, y_scores)

# Calculate baseline (random classifier performance): the fraction of
# positive labels, drawn below as a horizontal reference line
positive_class_ratio = np.mean(y_true)

# Create figure
fig = go.Figure()

# Add precision-recall curve (stepped style for accuracy)
fig.add_trace(
    go.Scatter(
        x=recall,
        y=precision,
        mode="lines",
        name=f"Classifier (AP = {average_precision:.3f})",
        line={"color": "#306998", "width": 4, "shape": "hv"},
        fill="tozeroy",
        fillcolor="rgba(48, 105, 152, 0.15)",
    )
)

# Add baseline reference line (random classifier)
fig.add_trace(
    go.Scatter(
        x=[0, 1],
        y=[positive_class_ratio, positive_class_ratio],
        mode="lines",
        name=f"Random Baseline ({positive_class_ratio:.2f})",
        line={"color": "#FFD43B", "width": 3, "dash": "dash"},
    )
)

# Add iso-F1 curves
f1_values = [0.2, 0.4, 0.6, 0.8]
for curve_index, f1 in enumerate(f1_values):
    # Iso-F1: precision = f1 * recall / (2 * recall - f1).
    # Precision equals exactly 1 at recall = f1 / (2 - f1), so starting the
    # recall grid there keeps every sample inside the plot and lets the curve
    # reach the top edge. It also keeps the denominator strictly positive,
    # so no masking or division guard is needed.
    recall_at_full_precision = f1 / (2 - f1)
    x_iso = np.linspace(recall_at_full_precision, 1, 100)
    y_iso = f1 * x_iso / (2 * x_iso - f1)
    fig.add_trace(
        go.Scatter(
            x=x_iso,
            y=y_iso,
            mode="lines",
            name=f"F1 = {f1}",
            line={"color": "gray", "width": 1.5, "dash": "dot"},
            opacity=0.5,
            # Only the first curve gets a legend entry; the shared
            # legendgroup ties all iso-F1 curves to that single entry.
            showlegend=curve_index == 0,
            legendgroup="iso-f1",
        )
    )

# Update layout for 4800x2700 px
fig.update_layout(
    title={"text": "precision-recall · plotly · pyplots.ai", "font": {"size": 32}, "x": 0.5, "xanchor": "center"},
    xaxis={
        "title": {"text": "Recall (Sensitivity)", "font": {"size": 24}},
        "tickfont": {"size": 18},
        "range": [0, 1.02],  # Small headroom so the curve is not clipped at recall = 1
        "showgrid": True,
        "gridcolor": "rgba(0, 0, 0, 0.1)",
        "gridwidth": 1,
        "zeroline": False,
    },
    yaxis={
        "title": {"text": "Precision (Positive Predictive Value)", "font": {"size": 24}},
        "tickfont": {"size": 18},
        "range": [0, 1.05],  # Small headroom so the curve is not clipped at precision = 1
        "showgrid": True,
        "gridcolor": "rgba(0, 0, 0, 0.1)",
        "gridwidth": 1,
        "zeroline": False,
    },
    legend={
        "font": {"size": 18},
        "x": 0.02,
        "y": 0.02,
        "xanchor": "left",
        "yanchor": "bottom",
        "bgcolor": "rgba(255, 255, 255, 0.9)",
        "bordercolor": "rgba(0, 0, 0, 0.3)",
        "borderwidth": 1,
    },
    template="plotly_white",
    margin={"l": 100, "r": 60, "t": 100, "b": 100},
)

# Add annotation for iso-F1 curves
fig.add_annotation(
    x=0.92, y=0.92, text="Iso-F1 curves", font={"size": 16, "color": "gray"}, showarrow=False, xanchor="right"
)

# Save outputs: 1600x900 at scale 3 renders the 4800x2700 px PNG
fig.write_image("plot.png", width=1600, height=900, scale=3)
fig.write_html("plot.html", include_plotlyjs="cdn")
0 commit comments