-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkmeans_app.py
More file actions
38 lines (31 loc) · 1.36 KB
/
kmeans_app.py
File metadata and controls
38 lines (31 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scripts.kmeans_utils import kmeans, initialize_centroids, assign_clusters, update_centroids
# Streamlit page: upload a CSV, choose two numeric columns and a cluster
# count, run k-means on those two columns and show the result as a scatter plot.
st.title("🧠 K-Means Clustering Explorer")

# Upload CSV
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])

# file_uploader yields None until the user has supplied a file, so compare
# against None explicitly instead of relying on the object's truthiness.
if uploaded_file is not None:
    df = pd.read_csv(uploaded_file)
    st.write("Data Preview:", df.head())

    # Only numeric columns are offered as clustering features.
    numeric_cols = df.select_dtypes(include=np.number).columns.tolist()

    if len(numeric_cols) < 2:
        st.warning("Need at least 2 numeric columns to cluster.")
    else:
        x_col = st.selectbox("X-axis feature", numeric_cols)
        # index=1 preselects the second numeric column so the two axes
        # differ by default.
        y_col = st.selectbox("Y-axis feature", numeric_cols, index=1)
        k = st.slider("Number of clusters (k)", 1, 10, value=3)
        run = st.button("Run K-Means")

        if run:
            # Rows with a missing value in either selected column are dropped.
            data = df[[x_col, y_col]].dropna().values

            if len(data) < k:
                # Covers the empty case too (k >= 1). Presumably kmeans cannot
                # seed k centroids from fewer than k points -- verify against
                # scripts.kmeans_utils.
                st.warning(
                    f"Need at least {k} complete rows to form {k} clusters; "
                    f"found {len(data)}."
                )
            else:
                # kmeans comes from scripts.kmeans_utils; it is unpacked here
                # as (centroids, labels). Assumes labels is an array aligned
                # with the rows of `data` -- TODO confirm.
                centroids, labels = kmeans(data, k)

                # Draw on an explicit Figure/Axes rather than pyplot's global
                # current figure, which is shared process-wide and therefore
                # unsafe when several Streamlit sessions render concurrently.
                fig, ax = plt.subplots()
                for i in range(k):
                    cluster_data = data[labels == i]
                    ax.scatter(
                        cluster_data[:, 0],
                        cluster_data[:, 1],
                        label=f"Cluster {i+1}",
                    )
                ax.scatter(
                    centroids[:, 0],
                    centroids[:, 1],
                    c='black',
                    s=200,
                    marker='X',
                    label='Centroids',
                )
                ax.set_xlabel(x_col)
                ax.set_ylabel(y_col)
                ax.set_title("KMeans Clustering Result")
                ax.legend()

                st.pyplot(fig)
                # Release the figure so pyplot's registry does not grow on
                # every rerun of the script.
                plt.close(fig)