""" pyplots.ai
pdp-basic: Partial Dependence Plot
Library: bokeh 3.8.1 | Python 3.13.11
Quality: 91/100 | Created: 2025-12-31
"""

import numpy as np
from bokeh.io import export_png
from bokeh.models import Band, ColumnDataSource, Span
from bokeh.plotting import figure
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.inspection import partial_dependence


# Data - Train a model and compute partial dependence
np.random.seed(42)

# Friedman #1 is a synthetic regression problem with known non-linear structure
X, y = make_friedman1(n_samples=500, n_features=5, noise=0.5, random_state=42)

# Train a gradient boosting model
model = GradientBoostingRegressor(n_estimators=100, max_depth=4, random_state=42)
model.fit(X, y)

# Feature 0 enters the Friedman #1 target through a sine term (jointly with feature 1)
feature_idx = 0
grid_resolution = 100

# kind="both" returns the averaged curve and the per-sample ICE curves in one call
pdp_results = partial_dependence(model, X, features=[feature_idx], kind="both", grid_resolution=grid_resolution)

# Extract values (index 0 selects the single model output)
avg_predictions = pdp_results["average"][0]
individual_predictions = pdp_results["individual"][0]  # ICE lines, one row per sample
grid_values = pdp_results["grid_values"][0]

# Spread of the ICE curves at each grid point. This is a percentile range across
# samples (heterogeneity of the feature effect), not a statistical confidence
# interval on the average curve, so it is labelled as an ICE band in the legend.
lower_pct, upper_pct = 10, 90
lower_bound = np.percentile(individual_predictions, lower_pct, axis=0)
upper_bound = np.percentile(individual_predictions, upper_pct, axis=0)

# Center at zero for easier interpretation: the same offset (grid mean of the
# average curve) is removed from the curve and both band edges so they move together
center_val = avg_predictions.mean()
avg_centered = avg_predictions - center_val
lower_centered = lower_bound - center_val
upper_centered = upper_bound - center_val

# Training values of the plotted feature, shown as a rug along the bottom
rug_x = X[:, feature_idx]

# Create data source for main line and band
source = ColumnDataSource(data={"x": grid_values, "y": avg_centered, "lower": lower_centered, "upper": upper_centered})

# Rug baseline: 1.5 data units below the lowest band value, with the tick marks
# drawn 0.3 above it so they stay clear of the curve and the band
y_min = lower_centered.min() - 1.5
rug_source = ColumnDataSource(data={"x": rug_x, "y": np.full_like(rug_x, y_min + 0.3)})

# Create figure with proper sizing
p = figure(
    width=4800,
    height=2700,
    title="pdp-basic · bokeh · pyplots.ai",
    x_axis_label="Feature X₀ Value",
    y_axis_label="Partial Dependence (centered)",
)

# Shaded ICE percentile band around the average curve
band = Band(
    base="x",
    lower="lower",
    upper="upper",
    source=source,
    fill_color="#306998",
    fill_alpha=0.25,
    line_color="#306998",
    line_alpha=0.4,
)
p.add_layout(band)

# Add horizontal line at y=0 for reference
zero_line = Span(location=0, dimension="width", line_color="#555555", line_width=3, line_dash="dashed", line_alpha=0.6)
p.add_layout(zero_line)

# The Band is added as a layout annotation and gets no legend entry of its own,
# so an empty patch with matching styling supplies the entry. The label is built
# from the percentiles above to keep it in sync with what is actually drawn.
p.patch(
    [],
    [],
    fill_color="#306998",
    fill_alpha=0.25,
    line_color="#306998",
    line_alpha=0.4,
    legend_label=f"{upper_pct - lower_pct}% ICE band",
)

# Add main PDP line
p.line("x", "y", source=source, line_width=5, line_color="#306998", legend_label="Average PD")

# Rug plot for data distribution: "dash" markers rotated a quarter turn so each
# observation appears as a short vertical tick
p.scatter(
    "x",
    "y",
    source=rug_source,
    size=25,
    color="#FFD43B",
    alpha=0.6,
    line_width=3,
    angle=np.pi / 2,
    marker="dash",
    legend_label="Data Distribution",
)

# Text styling - scaled for 4800x2700 px canvas
p.title.text_font_size = "56pt"
p.title.text_font_style = "bold"
p.xaxis.axis_label_text_font_size = "42pt"
p.yaxis.axis_label_text_font_size = "42pt"
p.xaxis.major_label_text_font_size = "32pt"
p.yaxis.major_label_text_font_size = "32pt"

# Axis styling
p.xaxis.axis_line_width = 3
p.yaxis.axis_line_width = 3
p.xaxis.major_tick_line_width = 3
p.yaxis.major_tick_line_width = 3
p.xaxis.minor_tick_line_width = 2
p.yaxis.minor_tick_line_width = 2

# Grid styling
p.xgrid.grid_line_alpha = 0.3
p.ygrid.grid_line_alpha = 0.3
p.xgrid.grid_line_dash = "dashed"
p.ygrid.grid_line_dash = "dashed"

# Legend styling
p.legend.location = "bottom_right"
p.legend.label_text_font_size = "32pt"
p.legend.background_fill_alpha = 0.9
p.legend.border_line_alpha = 0.5
p.legend.border_line_width = 2
p.legend.glyph_height = 50
p.legend.glyph_width = 50
p.legend.spacing = 20
p.legend.padding = 25
p.legend.margin = 40

# Background
p.background_fill_color = "#fafafa"
p.border_fill_color = "white"

# Save (export_png renders through a headless browser; needs selenium and a webdriver)
export_png(p, filename="plot.png")
0 commit comments