|
| 1 | +""" pyplots.ai |
| 2 | +shap-summary: SHAP Summary Plot |
| 3 | +Library: plotly 6.5.0 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-31 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import plotly.graph_objects as go |
| 9 | + |
| 10 | + |
| 11 | +# Data - Generate synthetic SHAP values for ML model interpretability demo |
| 12 | +np.random.seed(42) |
| 13 | + |
| 14 | +# Simulated feature data (like from a gradient boosting model on tabular data) |
| 15 | +n_samples = 200 |
| 16 | +feature_names = [ |
| 17 | + "mean radius", |
| 18 | + "mean texture", |
| 19 | + "mean perimeter", |
| 20 | + "mean area", |
| 21 | + "mean smoothness", |
| 22 | + "mean compactness", |
| 23 | + "mean concavity", |
| 24 | + "mean concave points", |
| 25 | + "mean symmetry", |
| 26 | + "mean fractal dimension", |
| 27 | + "radius error", |
| 28 | + "texture error", |
| 29 | + "perimeter error", |
| 30 | + "area error", |
| 31 | + "smoothness error", |
| 32 | +] |
| 33 | +n_features = len(feature_names) |
| 34 | + |
| 35 | +# Generate realistic feature values (simulating measurement data) |
| 36 | +X = np.zeros((n_samples, n_features)) |
| 37 | +X[:, 0] = np.random.normal(14, 3.5, n_samples) # mean radius |
| 38 | +X[:, 1] = np.random.normal(19, 4, n_samples) # mean texture |
| 39 | +X[:, 2] = np.random.normal(92, 24, n_samples) # mean perimeter |
| 40 | +X[:, 3] = np.random.normal(655, 350, n_samples) # mean area |
| 41 | +X[:, 4] = np.random.normal(0.096, 0.014, n_samples) # mean smoothness |
| 42 | +X[:, 5] = np.random.normal(0.104, 0.053, n_samples) # mean compactness |
| 43 | +X[:, 6] = np.random.normal(0.089, 0.08, n_samples) # mean concavity |
| 44 | +X[:, 7] = np.random.normal(0.049, 0.039, n_samples) # mean concave points |
| 45 | +X[:, 8] = np.random.normal(0.181, 0.027, n_samples) # mean symmetry |
| 46 | +X[:, 9] = np.random.normal(0.063, 0.007, n_samples) # mean fractal dimension |
| 47 | +X[:, 10] = np.random.normal(0.41, 0.28, n_samples) # radius error |
| 48 | +X[:, 11] = np.random.normal(1.22, 0.55, n_samples) # texture error |
| 49 | +X[:, 12] = np.random.normal(2.87, 2.02, n_samples) # perimeter error |
| 50 | +X[:, 13] = np.random.normal(40, 45, n_samples) # area error |
| 51 | +X[:, 14] = np.random.normal(0.007, 0.003, n_samples) # smoothness error |
| 52 | + |
| 53 | +# Simulated feature importances (as from a tree-based model) |
| 54 | +importances = np.array([0.25, 0.08, 0.12, 0.18, 0.03, 0.06, 0.10, 0.09, 0.02, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01]) |
| 55 | + |
| 56 | +# Generate SHAP values that correlate with feature values (simulating real SHAP behavior) |
| 57 | +shap_values = np.zeros((n_samples, n_features)) |
| 58 | +for i in range(n_features): |
| 59 | + feat_min, feat_max = X[:, i].min(), X[:, i].max() |
| 60 | + feat_normalized = (X[:, i] - feat_min) / (feat_max - feat_min + 1e-10) |
| 61 | + |
| 62 | + # SHAP values correlate with feature values, scaled by importance |
| 63 | + base_effect = (feat_normalized - 0.5) * importances[i] * 2 |
| 64 | + noise = np.random.randn(n_samples) * importances[i] * 0.3 |
| 65 | + shap_values[:, i] = base_effect + noise |
| 66 | + |
| 67 | +# Sort features by mean absolute SHAP value (most important first) |
| 68 | +mean_abs_shap = np.mean(np.abs(shap_values), axis=0) |
| 69 | +sorted_idx = np.argsort(mean_abs_shap)[::-1] |
| 70 | + |
| 71 | +# Show top 15 features for clarity |
| 72 | +top_n = 15 |
| 73 | +sorted_idx = sorted_idx[:top_n] |
| 74 | + |
| 75 | +# Create figure |
| 76 | +fig = go.Figure() |
| 77 | + |
| 78 | +# Add traces for each feature (from bottom to top for proper y-axis ordering) |
| 79 | +for rank, feat_idx in enumerate(reversed(sorted_idx)): |
| 80 | + feat_shap = shap_values[:, feat_idx] |
| 81 | + feat_vals = X[:, feat_idx] |
| 82 | + |
| 83 | + # Normalize feature values for coloring (0 to 1) |
| 84 | + feat_min, feat_max = feat_vals.min(), feat_vals.max() |
| 85 | + feat_normalized = (feat_vals - feat_min) / (feat_max - feat_min + 1e-10) |
| 86 | + |
| 87 | + # Add jitter to y-position |
| 88 | + y_base = rank |
| 89 | + jitter = np.random.uniform(-0.3, 0.3, n_samples) |
| 90 | + y_positions = y_base + jitter |
| 91 | + |
| 92 | + # Create color array based on feature values (blue=low, red=high) |
| 93 | + colors = feat_normalized |
| 94 | + |
| 95 | + fig.add_trace( |
| 96 | + go.Scatter( |
| 97 | + x=feat_shap, |
| 98 | + y=y_positions, |
| 99 | + mode="markers", |
| 100 | + marker={ |
| 101 | + "size": 8, |
| 102 | + "color": colors, |
| 103 | + "colorscale": "RdBu_r", |
| 104 | + "cmin": 0, |
| 105 | + "cmax": 1, |
| 106 | + "opacity": 0.7, |
| 107 | + "line": {"width": 0}, |
| 108 | + }, |
| 109 | + name=feature_names[feat_idx][:25], |
| 110 | + hovertemplate=( |
| 111 | + f"<b>{feature_names[feat_idx]}</b><br>" |
| 112 | + "SHAP: %{x:.3f}<br>" |
| 113 | + "Feature value: %{marker.color:.2f}<extra></extra>" |
| 114 | + ), |
| 115 | + showlegend=False, |
| 116 | + ) |
| 117 | + ) |
| 118 | + |
| 119 | +# Add vertical line at x=0 |
| 120 | +fig.add_vline(x=0, line_width=2, line_color="#333333", line_dash="solid") |
| 121 | + |
| 122 | +# Create y-axis labels (feature names in order from bottom to top) |
| 123 | +y_labels = [feature_names[idx][:25] for idx in reversed(sorted_idx)] |
| 124 | + |
| 125 | +# Add colorbar as a separate trace |
| 126 | +colorbar_trace = go.Scatter( |
| 127 | + x=[None], |
| 128 | + y=[None], |
| 129 | + mode="markers", |
| 130 | + marker={ |
| 131 | + "size": 0.1, |
| 132 | + "color": [0, 1], |
| 133 | + "colorscale": "RdBu_r", |
| 134 | + "cmin": 0, |
| 135 | + "cmax": 1, |
| 136 | + "colorbar": { |
| 137 | + "title": {"text": "Feature Value", "font": {"size": 20}, "side": "right"}, |
| 138 | + "tickfont": {"size": 16}, |
| 139 | + "tickvals": [0, 0.5, 1], |
| 140 | + "ticktext": ["Low", "Medium", "High"], |
| 141 | + "len": 0.5, |
| 142 | + "thickness": 25, |
| 143 | + "x": 1.02, |
| 144 | + "y": 0.5, |
| 145 | + }, |
| 146 | + "showscale": True, |
| 147 | + }, |
| 148 | + showlegend=False, |
| 149 | + hoverinfo="skip", |
| 150 | +) |
| 151 | +fig.add_trace(colorbar_trace) |
| 152 | + |
| 153 | +# Update layout |
| 154 | +fig.update_layout( |
| 155 | + title={ |
| 156 | + "text": "shap-summary · plotly · pyplots.ai", |
| 157 | + "font": {"size": 28, "color": "#333333"}, |
| 158 | + "x": 0.5, |
| 159 | + "xanchor": "center", |
| 160 | + }, |
| 161 | + xaxis={ |
| 162 | + "title": {"text": "SHAP Value (Impact on Model Output)", "font": {"size": 22}}, |
| 163 | + "tickfont": {"size": 18}, |
| 164 | + "zeroline": True, |
| 165 | + "zerolinewidth": 2, |
| 166 | + "zerolinecolor": "#333333", |
| 167 | + "gridcolor": "rgba(128, 128, 128, 0.2)", |
| 168 | + "showgrid": True, |
| 169 | + }, |
| 170 | + yaxis={ |
| 171 | + "title": {"text": "Feature", "font": {"size": 22}}, |
| 172 | + "tickfont": {"size": 16}, |
| 173 | + "tickmode": "array", |
| 174 | + "tickvals": list(range(top_n)), |
| 175 | + "ticktext": y_labels, |
| 176 | + "showgrid": False, |
| 177 | + }, |
| 178 | + template="plotly_white", |
| 179 | + plot_bgcolor="white", |
| 180 | + paper_bgcolor="white", |
| 181 | + margin={"l": 200, "r": 120, "t": 80, "b": 80}, |
| 182 | + showlegend=False, |
| 183 | +) |
| 184 | + |
| 185 | +# Save as PNG (4800 x 2700) |
| 186 | +fig.write_image("plot.png", width=1600, height=900, scale=3) |
| 187 | + |
| 188 | +# Save as HTML for interactivity |
| 189 | +fig.write_html("plot.html", include_plotlyjs="cdn") |
0 commit comments