|
| 1 | +""" pyplots.ai |
| 2 | +silhouette-basic: Silhouette Plot |
| 3 | +Library: letsplot 4.8.2 | Python 3.13.11 |
| 4 | +Quality: 91/100 | Created: 2025-12-26 |
| 5 | +""" |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +from lets_plot import * |
| 10 | +from sklearn.cluster import KMeans |
| 11 | +from sklearn.datasets import load_iris |
| 12 | +from sklearn.metrics import silhouette_samples, silhouette_score |
| 13 | + |
| 14 | + |
# Configure lets-plot for HTML output before any plot is built.
LetsPlot.setup_html()

# Data: iris feature matrix, to be partitioned into three clusters.
np.random.seed(42)
iris = load_iris()
X = iris.data
n_clusters = 3

# K-means with a fixed random_state so the cluster assignment is reproducible.
kmeans = KMeans(n_init=10, random_state=42, n_clusters=n_clusters)
cluster_labels = kmeans.fit_predict(X)

# Overall mean silhouette score, and the per-sample coefficients drawn as bars.
avg_silhouette = silhouette_score(X, cluster_labels)
silhouette_vals = silhouette_samples(X, cluster_labels)
| 30 | + |
# Plot rows: one record per sample, grouped by cluster and ordered by ascending
# silhouette coefficient inside each group. `y_position` is the running
# vertical slot; a 5-slot gap is left after every cluster.
colors = ["#306998", "#FFD43B", "#DC2626"]  # Python Blue, Python Yellow, Red

data_rows = []
cluster_centers = []
cluster_avg_scores = []
y_position = 0

for label in range(n_clusters):
    # Silhouette coefficients of the samples assigned to this cluster.
    member_scores = silhouette_vals[cluster_labels == label]

    # Per-cluster mean, taken over the unsorted values.
    cluster_avg_scores.append(member_scores.mean())

    # One row per sample, occupying consecutive slots starting at first_row.
    first_row = y_position
    data_rows.extend(
        {"y": row, "silhouette": score, "cluster": f"Cluster {label}", "cluster_idx": label}
        for row, score in enumerate(np.sort(member_scores), start=first_row)
    )
    y_position = first_row + member_scores.size

    # Vertical midpoint of the cluster's band, used to place its text label.
    last_row = y_position - 1
    cluster_centers.append((first_row + last_row) / 2)

    # Small gap before the next cluster.
    y_position += 5
| 64 | + |
# Sample-level frame; every horizontal bar starts at x = 0.
df = pd.DataFrame(data_rows).assign(x_start=0)

# One text label per cluster, placed left of the bars at the cluster's midpoint.
cluster_captions = [f"Cluster {i}\n(avg: {avg:.2f})" for i, avg in enumerate(cluster_avg_scores)]
annotation_df = pd.DataFrame(
    {
        "y": cluster_centers,
        "x": [-0.12] * n_clusters,
        "label": cluster_captions,
    }
)
| 76 | + |
# Silhouette plot: each sample is a horizontal segment from x_start to its
# silhouette coefficient, coloured by cluster.
segment_layer = geom_segment(
    aes(x="x_start", xend="silhouette", y="y", yend="y", color="cluster"),
    data=df,
    size=1.5,
)

# Dashed vertical reference line at the overall average silhouette score.
mean_line_layer = geom_vline(xintercept=avg_silhouette, color="#333333", linetype="dashed", size=1)

# Per-cluster text labels from annotation_df.
cluster_label_layer = geom_text(aes(x="x", y="y", label="label"), data=annotation_df, size=14, hjust=1)

axis_labels = labs(
    x="Silhouette Coefficient",
    y="Sample Index (sorted within cluster)",
    title="silhouette-basic · letsplot · pyplots.ai",
)

# Font sizes, legend placement, and removal of y-axis text, ticks and grid.
theme_overrides = theme(
    axis_title=element_text(size=20),
    axis_text=element_text(size=16),
    plot_title=element_text(size=24),
    legend_text=element_text(size=16),
    legend_title=element_text(size=18),
    legend_position="right",
    axis_text_y=element_blank(),
    axis_ticks_y=element_blank(),
    panel_grid_major_y=element_blank(),
    panel_grid_minor_y=element_blank(),
)

plot = (
    ggplot()
    + segment_layer
    + mean_line_layer
    + cluster_label_layer
    + scale_color_manual(values=colors)
    + axis_labels
    + xlim(-0.3, 1.0)
    + theme_minimal()
    + theme_overrides
    + ggsize(1600, 900)
)
| 105 | + |
# Text label for the average-silhouette line: slightly right of the line and
# at 95% of the largest y slot.
avg_label_x = avg_silhouette + 0.03
avg_label_y = max(df["y"]) * 0.95
avg_label_df = pd.DataFrame(
    {
        "x": [avg_label_x],
        "y": [avg_label_y],
        "label": [f"Avg: {avg_silhouette:.2f}"],
    }
)
avg_label_layer = geom_text(aes(x="x", y="y", label="label"), data=avg_label_df, size=14, hjust=0)
plot = plot + avg_label_layer
| 111 | + |
# Both outputs are written to the current working directory.
output_dir = "."

# Static PNG, scaled 3x (1600 x 900 -> 4800 x 2700 px).
ggsave(plot, "plot.png", path=output_dir, scale=3)

# Interactive HTML version.
ggsave(plot, "plot.html", path=output_dir)
0 commit comments