|
| 1 | +""" pyplots.ai |
| 2 | +silhouette-basic: Silhouette Plot |
| 3 | +Library: pygal 3.1.0 | Python 3.13.11 |
| 4 | +Quality: 88/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pygal |
| 9 | +from pygal.style import Style |
| 10 | +from sklearn.cluster import KMeans |
| 11 | +from sklearn.metrics import silhouette_samples, silhouette_score |
| 12 | + |
| 13 | + |
| 14 | +# Data - Synthetic clustering data designed to show both positive and negative silhouettes |
| 15 | +np.random.seed(42) |
| 16 | + |
| 17 | +# Create 3 clusters with deliberate overlap to generate negative silhouette values |
| 18 | +# Cluster 0: tight cluster at origin |
| 19 | +# Cluster 1: well-separated cluster |
| 20 | +# Cluster 2: overlaps significantly with cluster 0 to create misclassified samples |
| 21 | +n_samples_per_cluster = 50 |
| 22 | +cluster0 = np.random.randn(n_samples_per_cluster, 2) * 0.6 + np.array([0, 0]) |
| 23 | +cluster1 = np.random.randn(n_samples_per_cluster, 2) * 0.7 + np.array([4, 4]) |
| 24 | +cluster2 = np.random.randn(n_samples_per_cluster, 2) * 1.0 + np.array([0.8, 0.3]) # Heavy overlap with cluster0 |
| 25 | +X = np.vstack([cluster0, cluster1, cluster2]) |
| 26 | + |
| 27 | +# Cluster the data |
| 28 | +kmeans = KMeans(n_clusters=3, random_state=42, n_init=10) |
| 29 | +cluster_labels = kmeans.fit_predict(X) |
| 30 | + |
| 31 | +# Compute silhouette scores |
| 32 | +silhouette_vals = silhouette_samples(X, cluster_labels) |
| 33 | +avg_silhouette = silhouette_score(X, cluster_labels) |
| 34 | +n_clusters = 3 |
| 35 | + |
| 36 | +# Colors for each cluster (Python Blue, Python Yellow, Complementary Red) |
| 37 | +cluster_colors = ["#306998", "#FFD43B", "#E74C3C"] |
| 38 | + |
| 39 | +# Custom style for pyplots with prominent grid and reference lines |
| 40 | +custom_style = Style( |
| 41 | + background="white", |
| 42 | + plot_background="white", |
| 43 | + foreground="#333333", |
| 44 | + foreground_strong="#000000", # Bold black for major labels (avg reference) |
| 45 | + foreground_subtle="#666666", # More visible grid lines |
| 46 | + guide_stroke_color="#999999", # Darker grid color for visibility |
| 47 | + major_guide_stroke_dasharray="8,4", # More prominent dashed pattern for avg line |
| 48 | + colors=tuple(cluster_colors), |
| 49 | + title_font_size=48, |
| 50 | + label_font_size=32, |
| 51 | + major_label_font_size=36, # Larger major labels for avg reference line |
| 52 | + legend_font_size=28, |
| 53 | + value_font_size=24, |
| 54 | + stroke_width=0, |
| 55 | +) |
| 56 | + |
| 57 | +# Process and sort silhouette values within each cluster |
| 58 | +# Store original averages before any sample reduction |
| 59 | +original_cluster_avgs = {} |
| 60 | +for i in range(n_clusters): |
| 61 | + cluster_silhouette_vals = silhouette_vals[cluster_labels == i] |
| 62 | + original_cluster_avgs[i] = np.mean(cluster_silhouette_vals) |
| 63 | + |
| 64 | +# Build cluster data with sorted values (descending for visual appeal) |
| 65 | +cluster_data = {} |
| 66 | +sample_idx = 0 |
| 67 | +for i in range(n_clusters): |
| 68 | + cluster_silhouette_vals = silhouette_vals[cluster_labels == i] |
| 69 | + cluster_silhouette_vals = np.sort(cluster_silhouette_vals)[::-1] # Descending |
| 70 | + # Subsample for thicker bars while maintaining pattern |
| 71 | + reduced_vals = cluster_silhouette_vals[::2] if len(cluster_silhouette_vals) > 30 else cluster_silhouette_vals |
| 72 | + cluster_data[i] = { |
| 73 | + "values": reduced_vals, |
| 74 | + "avg": original_cluster_avgs[i], # Use original average |
| 75 | + "start_idx": sample_idx, |
| 76 | + "size": len(reduced_vals), |
| 77 | + } |
| 78 | + sample_idx += len(reduced_vals) |
| 79 | + |
| 80 | +total_samples = sample_idx |
| 81 | + |
| 82 | +# Build all bars list for chart data with separator gaps between clusters |
| 83 | +all_bars = [] |
| 84 | +separator_count = 3 # Number of empty bars between clusters for visual separation |
| 85 | +for i in range(n_clusters): |
| 86 | + for val in cluster_data[i]["values"]: |
| 87 | + all_bars.append((i, val)) |
| 88 | + # Add separator gaps after each cluster except the last |
| 89 | + if i < n_clusters - 1: |
| 90 | + for _ in range(separator_count): |
| 91 | + all_bars.append((-1, None)) # -1 indicates separator |
| 92 | + |
| 93 | +chart = pygal.HorizontalBar( |
| 94 | + width=4800, |
| 95 | + height=2700, |
| 96 | + style=custom_style, |
| 97 | + title="silhouette-basic · pygal · pyplots.ai", |
| 98 | + x_title=f"Silhouette Coefficient (avg: {avg_silhouette:.3f})", |
| 99 | + y_title="Samples (grouped by cluster)", |
| 100 | + show_legend=True, |
| 101 | + legend_at_bottom=True, |
| 102 | + legend_at_bottom_columns=3, |
| 103 | + show_y_guides=False, |
| 104 | + show_x_guides=True, |
| 105 | + print_values=False, |
| 106 | + range=(-0.2, 1.0), |
| 107 | + spacing=4, # Increased spacing between bars |
| 108 | + margin=50, |
| 109 | + margin_bottom=150, |
| 110 | + show_y_labels=False, |
| 111 | + x_labels=[-0.2, 0.0, 0.2, round(avg_silhouette, 2), 0.4, 0.6, 0.8, 1.0], |
| 112 | + x_labels_major=[round(avg_silhouette, 2)], # Highlight average silhouette value as major |
| 113 | +) |
| 114 | + |
| 115 | +# Build data series for each cluster |
| 116 | +# Track positions for cluster midpoints (excluding separators) |
| 117 | +cluster_positions = {} |
| 118 | +pos = 0 |
| 119 | +for i in range(n_clusters): |
| 120 | + cluster_positions[i] = {"start": pos, "size": cluster_data[i]["size"]} |
| 121 | + pos += cluster_data[i]["size"] |
| 122 | + if i < n_clusters - 1: |
| 123 | + pos += separator_count |
| 124 | + |
| 125 | +for cluster_idx in range(n_clusters): |
| 126 | + cluster_avg = cluster_data[cluster_idx]["avg"] |
| 127 | + cluster_size = cluster_positions[cluster_idx]["size"] |
| 128 | + start_pos = cluster_positions[cluster_idx]["start"] |
| 129 | + mid_point = start_pos + cluster_size // 2 |
| 130 | + |
| 131 | + series_data = [] |
| 132 | + for bar_idx, (c, val) in enumerate(all_bars): |
| 133 | + if c == cluster_idx: |
| 134 | + # Annotate at cluster midpoint with cluster average |
| 135 | + if bar_idx == mid_point: |
| 136 | + series_data.append({"value": val, "label": f"Cluster {cluster_idx} avg: {cluster_avg:.3f}"}) |
| 137 | + else: |
| 138 | + series_data.append(val) |
| 139 | + else: |
| 140 | + series_data.append(None) |
| 141 | + |
| 142 | + chart.add(f"Cluster {cluster_idx} (avg: {cluster_avg:.3f})", series_data) |
| 143 | + |
| 144 | +# Save outputs |
| 145 | +chart.render_to_file("plot.html") |
| 146 | +chart.render_to_png("plot.png") |
0 commit comments