|
| 1 | +""" pyplots.ai |
| 2 | +silhouette-basic: Silhouette Plot |
| 3 | +Library: bokeh 3.8.1 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from bokeh.io import export_png, output_file, save |
| 9 | +from bokeh.models import ColumnDataSource, Label, Span |
| 10 | +from bokeh.plotting import figure |
| 11 | + |
| 12 | + |
| 13 | +# Data - simulating silhouette analysis of customer segmentation (3 clusters) |
| 14 | +# Realistic scenario: clustering customers by purchase behavior |
| 15 | +np.random.seed(42) |
| 16 | + |
| 17 | +n_clusters = 3 |
| 18 | +cluster_sizes = [50, 55, 45] # Different sized clusters |
| 19 | + |
| 20 | +# Generate realistic silhouette values for each cluster |
| 21 | +# Cluster 0: Well-separated cluster (high silhouette scores) |
| 22 | +cluster0_vals = np.clip(np.random.beta(8, 2, cluster_sizes[0]) * 0.6 + 0.35, 0.1, 0.95) |
| 23 | +# Cluster 1: Good cluster with some overlap (medium-high scores) |
| 24 | +cluster1_vals = np.clip(np.random.beta(5, 2, cluster_sizes[1]) * 0.5 + 0.25, 0.0, 0.85) |
| 25 | +# Cluster 2: Some ambiguous samples (includes negative values) |
| 26 | +cluster2_vals = np.clip(np.random.beta(4, 3, cluster_sizes[2]) * 0.8 - 0.1, -0.15, 0.75) |
| 27 | + |
| 28 | +# Combine all values |
| 29 | +silhouette_vals = np.concatenate([cluster0_vals, cluster1_vals, cluster2_vals]) |
| 30 | +cluster_labels = np.concatenate( |
| 31 | + [ |
| 32 | + np.zeros(cluster_sizes[0], dtype=int), |
| 33 | + np.ones(cluster_sizes[1], dtype=int), |
| 34 | + np.full(cluster_sizes[2], 2, dtype=int), |
| 35 | + ] |
| 36 | +) |
| 37 | + |
| 38 | +# Calculate average silhouette score |
| 39 | +avg_silhouette = float(np.mean(silhouette_vals)) |
| 40 | + |
| 41 | +# Prepare data for plotting - sorted silhouette values within each cluster |
| 42 | +y_lower = 15 |
| 43 | +bar_data = {"x": [], "y": [], "width": [], "height": [], "color": []} |
| 44 | +cluster_info = [] # For labels |
| 45 | + |
| 46 | +# Colors for clusters - Python blue, yellow, and a colorblind-safe third color |
| 47 | +colors = ["#306998", "#FFD43B", "#2ECC71"] |
| 48 | + |
| 49 | +for i in range(n_clusters): |
| 50 | + # Get silhouette values for this cluster |
| 51 | + cluster_mask = cluster_labels == i |
| 52 | + cluster_silhouette_vals = silhouette_vals[cluster_mask] |
| 53 | + cluster_silhouette_vals.sort() |
| 54 | + |
| 55 | + cluster_size = len(cluster_silhouette_vals) |
| 56 | + y_upper = y_lower + cluster_size |
| 57 | + |
| 58 | + # Store center position for cluster label |
| 59 | + cluster_center = (y_lower + y_upper) / 2 |
| 60 | + cluster_avg = float(np.mean(cluster_silhouette_vals)) |
| 61 | + cluster_info.append((cluster_center, cluster_avg, cluster_size, i)) |
| 62 | + |
| 63 | + # Add bars for each sample in cluster |
| 64 | + for j, val in enumerate(cluster_silhouette_vals): |
| 65 | + bar_data["x"].append(val / 2) # Center of bar |
| 66 | + bar_data["y"].append(y_lower + j + 0.5) # Y position |
| 67 | + bar_data["width"].append(abs(val)) # Width = silhouette value |
| 68 | + bar_data["height"].append(0.85) # Slightly less than 1 for gap |
| 69 | + bar_data["color"].append(colors[i]) |
| 70 | + |
| 71 | + y_lower = y_upper + 15 # Gap between clusters |
| 72 | + |
| 73 | +# Create figure |
| 74 | +p = figure( |
| 75 | + width=4800, |
| 76 | + height=2700, |
| 77 | + title="silhouette-basic · bokeh · pyplots.ai", |
| 78 | + x_axis_label="Silhouette Coefficient", |
| 79 | + y_axis_label="Cluster (samples sorted by silhouette score)", |
| 80 | + x_range=(-0.3, 1.25), |
| 81 | + y_range=(0, y_lower + 5), |
| 82 | + tools="", |
| 83 | +) |
| 84 | + |
| 85 | +# Style the figure - larger text for 4800x2700 canvas |
| 86 | +p.title.text_font_size = "42pt" |
| 87 | +p.title.text_font_style = "bold" |
| 88 | +p.xaxis.axis_label_text_font_size = "32pt" |
| 89 | +p.yaxis.axis_label_text_font_size = "32pt" |
| 90 | +p.xaxis.major_label_text_font_size = "24pt" |
| 91 | +p.yaxis.major_label_text_font_size = "24pt" |
| 92 | +p.xaxis.axis_label_standoff = 25 |
| 93 | +p.yaxis.axis_label_standoff = 25 |
| 94 | + |
| 95 | +# Create data source for bars |
| 96 | +source = ColumnDataSource(data=bar_data) |
| 97 | + |
| 98 | +# Draw horizontal bars using hbar for proper horizontal bar rendering |
| 99 | +p.rect(x="x", y="y", width="width", height="height", color="color", source=source, line_color=None, alpha=0.85) |
| 100 | + |
| 101 | +# Add vertical line for average silhouette score |
| 102 | +avg_line = Span(location=avg_silhouette, dimension="height", line_color="#E74C3C", line_width=5, line_dash="dashed") |
| 103 | +p.add_layout(avg_line) |
| 104 | + |
| 105 | +# Add average silhouette score label at top |
| 106 | +avg_label = Label( |
| 107 | + x=avg_silhouette + 0.03, |
| 108 | + y=y_lower - 5, |
| 109 | + text=f"Average: {avg_silhouette:.3f}", |
| 110 | + text_font_size="28pt", |
| 111 | + text_color="#E74C3C", |
| 112 | + text_font_style="bold", |
| 113 | +) |
| 114 | +p.add_layout(avg_label) |
| 115 | + |
| 116 | +# Add cluster labels with their average silhouette scores - positioned to avoid overlap |
| 117 | +for center_y, cluster_avg, size, cluster_idx in cluster_info: |
| 118 | + # Position label to the left side, outside the bars |
| 119 | + cluster_label = Label( |
| 120 | + x=-0.22, |
| 121 | + y=center_y, |
| 122 | + text=f"Cluster {cluster_idx}", |
| 123 | + text_font_size="26pt", |
| 124 | + text_color=colors[cluster_idx], |
| 125 | + text_font_style="bold", |
| 126 | + text_align="left", |
| 127 | + text_baseline="middle", |
| 128 | + ) |
| 129 | + p.add_layout(cluster_label) |
| 130 | + |
| 131 | + # Add cluster stats on the right side |
| 132 | + stats_label = Label( |
| 133 | + x=1.01, |
| 134 | + y=center_y, |
| 135 | + text=f"n={size}, avg={cluster_avg:.2f}", |
| 136 | + text_font_size="22pt", |
| 137 | + text_color=colors[cluster_idx], |
| 138 | + text_font_style="normal", |
| 139 | + text_align="left", |
| 140 | + text_baseline="middle", |
| 141 | + ) |
| 142 | + p.add_layout(stats_label) |
| 143 | + |
| 144 | +# Style grid |
| 145 | +p.xgrid.grid_line_alpha = 0.3 |
| 146 | +p.ygrid.grid_line_alpha = 0.0 # No horizontal grid |
| 147 | +p.xgrid.grid_line_dash = [6, 4] |
| 148 | + |
| 149 | +# Remove y-axis ticks (sample indices are not meaningful) |
| 150 | +p.yaxis.major_tick_line_color = None |
| 151 | +p.yaxis.minor_tick_line_color = None |
| 152 | +p.yaxis.major_label_text_font_size = "0pt" |
| 153 | + |
| 154 | +# Add vertical line at x=0 for reference |
| 155 | +zero_line = Span(location=0, dimension="height", line_color="#666666", line_width=2, line_alpha=0.5) |
| 156 | +p.add_layout(zero_line) |
| 157 | + |
| 158 | +# Background |
| 159 | +p.background_fill_color = "#fafafa" |
| 160 | +p.border_fill_color = "#ffffff" |
| 161 | +p.outline_line_color = None |
| 162 | + |
| 163 | +# Save as PNG and HTML |
| 164 | +export_png(p, filename="plot.png") |
| 165 | + |
| 166 | +output_file("plot.html") |
| 167 | +save(p) |
0 commit comments