|
| 1 | +""" pyplots.ai |
| 2 | +calibration-curve: Calibration Curve |
| 3 | +Library: bokeh 3.8.1 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from bokeh.io import export_png, save |
| 9 | +from bokeh.models import ColumnDataSource, Label |
| 10 | +from bokeh.plotting import figure |
| 11 | +from bokeh.resources import CDN |
| 12 | + |
| 13 | + |
# Data - simulate binary-classifier scores whose true outcome rate deviates
# systematically from the predicted probability.
np.random.seed(42)
n_samples = 5000

# Predicted probabilities, drawn uniformly between 0 and 1
y_prob = np.random.uniform(0, 1, n_samples)

# The true outcome rate is sigmoid(calibration_factor * logit(y_prob)).
# A factor > 1 pushes the true rate further away from 0.5 than the prediction,
# i.e. the simulated model is *underconfident* (S-shaped reliability curve);
# a factor < 1 would simulate an overconfident model instead.
calibration_factor = 1.3
# The 1e-10 term keeps the denominator of the odds away from zero.
log_odds = np.log(y_prob / (1 - y_prob + 1e-10))
adjusted_prob = 1 / (1 + np.exp(-calibration_factor * log_odds))
adjusted_prob = np.clip(adjusted_prob, 0.01, 0.99)
y_true = (np.random.uniform(0, 1, n_samples) < adjusted_prob).astype(int)

# Calibration curve - equal-width bins over [0, 1]
n_bins = 10
bin_edges = np.linspace(0, 1, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

fraction_of_positives = []
mean_predicted_value = []
bin_counts = []

for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
    # Bins are half-open [lower, upper); only the last bin also owns its right edge
    if i == n_bins - 1:
        mask = (y_prob >= lower) & (y_prob <= upper)
    else:
        mask = (y_prob >= lower) & (y_prob < upper)

    count = int(mask.sum())
    bin_counts.append(count)
    if count > 0:
        fraction_of_positives.append(y_true[mask].mean())
        mean_predicted_value.append(y_prob[mask].mean())
    else:
        # Empty bin: NaN placeholder for the fraction, bin centre for x
        fraction_of_positives.append(np.nan)
        mean_predicted_value.append(bin_centers[i])

# Keep only populated bins for plotting and for the ECE
counts_all = np.array(bin_counts)
valid_mask = counts_all > 0
mean_pred_valid = np.array(mean_predicted_value)[valid_mask]
frac_pos_valid = np.array(fraction_of_positives)[valid_mask]
counts_valid = counts_all[valid_mask]

# Brier score: mean squared difference between predicted probability and outcome
brier_score = np.mean((y_prob - y_true) ** 2)

# Expected Calibration Error: |observed - predicted| per populated bin,
# weighted by the share of samples that fall into that bin
total_samples = counts_valid.sum()
ece = float(np.sum(counts_valid / total_samples * np.abs(frac_pos_valid - mean_pred_valid)))
| 66 | + |
# Plot - 4800x2700 canvas; both axes span the unit interval with a 0.02 margin
p = figure(
    title="calibration-curve · bokeh · pyplots.ai",
    x_axis_label="Mean Predicted Probability",
    y_axis_label="Fraction of Positives",
    x_range=(-0.02, 1.02),
    y_range=(-0.02, 1.02),
    width=4800,
    height=2700,
)

# Dashed grey diagonal y = x: where a perfectly calibrated model would lie
p.line(
    x=[0, 1],
    y=[0, 1],
    legend_label="Perfect Calibration",
    line_color="#888888",
    line_dash="dashed",
    line_width=4,
)

# One row per populated bin; the line and the markers share this source
source = ColumnDataSource(data=dict(x=mean_pred_valid, y=frac_pos_valid, count=counts_valid))

# Observed calibration curve: connecting line plus a marker on every bin
p.line(x="x", y="y", source=source, legend_label="Classifier", line_color="#306998", line_width=5)
p.scatter(
    x="x",
    y="y",
    source=source,
    size=25,
    color="#306998",
    fill_alpha=0.9,
    line_color="#1a3d5c",
    line_width=3,
)
| 87 | + |
# Metrics box - Brier score and ECE (three decimals each), positioned in data
# coordinates near the top-left of the plot on a mostly opaque white background
metrics_text = f"Brier Score: {brier_score:.3f}\nECE: {ece:.3f}"
metrics_label = Label(
    text=metrics_text,
    x=0.05,
    y=0.88,
    x_units="data",
    y_units="data",
    text_color="#333333",
    text_font_size="28pt",
    background_fill_alpha=0.9,
    background_fill_color="white",
)
p.add_layout(metrics_label)
| 102 | + |
# Title
p.title.text_font_size = "36pt"
p.title.text_font_style = "bold"

# Axes - `p.axis` addresses the x and y axes together
p.axis.axis_label_text_font_size = "28pt"
p.axis.major_label_text_font_size = "22pt"
p.axis.axis_line_width = 2
p.axis.major_tick_line_width = 2

# Grid - `p.grid` addresses both grids: faint dashed lines
p.grid.grid_line_alpha = 0.3
p.grid.grid_line_dash = "dashed"

# Legend - bottom-right corner
p.legend.location = "bottom_right"
p.legend.label_text_font_size = "24pt"
p.legend.background_fill_alpha = 0.9
p.legend.border_line_width = 2
p.legend.padding = 15
p.legend.spacing = 10

# Canvas colours: off-white plotting area inside a white border
p.background_fill_color = "#fafafa"
p.border_fill_color = "#ffffff"
| 134 | + |
# Outputs - a static PNG render first, then a standalone HTML page that
# loads BokehJS from the CDN
png_path = "plot.png"
html_path = "plot.html"
export_png(p, filename=png_path)
save(p, filename=html_path, resources=CDN, title="Calibration Curve")
0 commit comments