|
| 1 | +""" pyplots.ai |
| 2 | +shap-summary: SHAP Summary Plot |
| 3 | +Library: altair 6.0.0 | Python 3.13.11 |
| 4 | +Quality: 92/100 | Created: 2025-12-31 |
| 5 | +""" |
| 6 | + |
| 7 | +import altair as alt |
| 8 | +import numpy as np |
| 9 | +import pandas as pd |
| 10 | + |
| 11 | + |
# Generate synthetic SHAP values for a model explanation visualization.
# The legacy global seed is kept so the random stream (and therefore the
# figure's data) stays identical from run to run.
np.random.seed(42)
n_samples = 300
n_features = 10

# Feature names representing typical ML model inputs
feature_names = [
    "Account Age (months)",
    "Transaction Count",
    "Avg Transaction ($)",
    "Credit Score",
    "Income ($K)",
    "Debt Ratio",
    "Payment History",
    "Account Balance ($)",
    "Login Frequency",
    "Support Tickets",
]

# Synthetic feature values, uniform in [0, 1) so they map directly onto the
# color scale without further normalization.
feature_values = np.random.rand(n_samples, n_features)

# Per-feature scale factors: a larger factor yields a wider spread of SHAP
# values for that feature.
feature_importances = np.array([0.25, 0.20, 0.15, 0.12, 0.10, 0.07, 0.05, 0.03, 0.02, 0.01])
shap_values = np.zeros((n_samples, n_features))

# The noise is drawn one feature at a time, in column order, so the sequence
# of random draws is fixed by the seed above.
for i, importance in enumerate(feature_importances):
    # SHAP-like values: linear dependence on the centered feature value plus
    # Gaussian noise, both scaled by the feature's importance.
    base_effect = (feature_values[:, i] - 0.5) * importance * 4
    noise = np.random.randn(n_samples) * importance * 0.5
    shap_values[:, i] = base_effect + noise

# Rank features by mean absolute SHAP value, most important first.
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
feature_order = np.argsort(mean_abs_shap)[::-1]
feature_order_names = [feature_names[i] for i in feature_order]

# Build the long-form dataframe for Altair: one row per (feature, sample),
# grouped by feature in importance order. Reordering the columns and
# flattening the transpose produces feature-major rows without a Python-level
# loop over every cell.
df = pd.DataFrame(
    {
        # dtype=object keeps the labels as plain Python strings.
        "Feature": np.repeat(np.array(feature_order_names, dtype=object), n_samples),
        "SHAP Value": shap_values[:, feature_order].T.ravel(),
        "Feature Value": feature_values[:, feature_order].T.ravel(),
        "Importance": np.repeat(mean_abs_shap[feature_order], n_samples),
    }
)
| 65 | + |
# Create the SHAP summary plot.
# Vertical jitter is drawn from NumPy's seeded generator instead of Vega's
# random() expression: the Vega expression is re-evaluated (unseeded) on every
# render, so the PNG and HTML exports would otherwise place points differently
# and the figure could not be reproduced between runs.
plot_df = df.assign(jitter=np.random.uniform(-1.0, 1.0, size=len(df)))

scatter = (
    alt.Chart(plot_df)
    .mark_circle(opacity=0.7, stroke="#333333", strokeWidth=0.3)
    .encode(
        x=alt.X(
            "SHAP Value:Q",
            title="SHAP Value (Impact on Model Output)",
            axis=alt.Axis(titleFontSize=22, labelFontSize=18, grid=True, gridOpacity=0.3),
        ),
        # Explicit sort keeps the features in importance order on the y axis.
        y=alt.Y("Feature:N", title=None, sort=feature_order_names, axis=alt.Axis(labelFontSize=18)),
        color=alt.Color(
            "Feature Value:Q",
            scale=alt.Scale(scheme="blueorange", domain=[0, 1]),
            legend=alt.Legend(
                title="Feature Value", titleFontSize=18, labelFontSize=16, orient="right", gradientLength=200
            ),
        ),
        size=alt.value(80),
        # Map jitter in [-1, 1] to a +/-15 px offset around each feature's row.
        yOffset=alt.YOffset("jitter:Q", scale=alt.Scale(domain=[-1, 1], range=[-15, 15])),
    )
)

# Add vertical line at x=0
zero_line = (
    alt.Chart(pd.DataFrame({"x": [0]})).mark_rule(color="#333333", strokeWidth=2, strokeDash=[5, 3]).encode(x="x:Q")
)

# Combine scatter and zero line; the rule is listed first so it is drawn
# beneath the points.
chart = (
    (zero_line + scatter)
    .properties(
        width=1400, height=850, title=alt.Title("shap-summary · altair · pyplots.ai", fontSize=28, anchor="middle")
    )
    .configure_axis(labelFontSize=18, titleFontSize=22)
    .configure_view(strokeWidth=0)
    .configure_legend(titleFontSize=18, labelFontSize=16)
)
| 105 | + |
# Write the chart to disk: a PNG rendered at 3x scale, then an HTML page.
export_targets = (
    ("plot.png", {"scale_factor": 3.0}),
    ("plot.html", {}),
)
for output_path, save_options in export_targets:
    chart.save(output_path, **save_options)
0 commit comments