-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy path06-k-means-clustering.py
More file actions
32 lines (26 loc) · 905 Bytes
/
Copy path06-k-means-clustering.py
File metadata and controls
32 lines (26 loc) · 905 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#Create artificial data set
from sklearn.datasets import make_blobs
# Draw 200 two-dimensional points scattered around 4 blob centres.
# make_blobs returns a tuple: raw_data[0] holds the point coordinates and
# raw_data[1] holds the index of the blob each point was generated from.
# random_state pins the draw so the generated points are the same on every run.
raw_data = make_blobs(
    n_samples=200,
    n_features=2,
    centers=4,
    cluster_std=1.8,
    random_state=42,
)
#Data imports
import pandas as pd
import numpy as np
#Visualization imports
import seaborn
import matplotlib.pyplot as plt
# `%matplotlib inline` is IPython/Jupyter magic rather than Python syntax, so
# leaving it bare makes this .py file fail to parse. Request the inline
# backend through the IPython API when running under IPython, and do nothing
# under a plain Python interpreter.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    # get_ipython only exists inside an IPython/Jupyter session.
    pass
# Visualize the data
# Each view gets its own figure: two scatter calls on the same implicit axes
# would leave the plain points hidden underneath the coloured ones.

# Raw points, no label information.
plt.figure()
plt.scatter(raw_data[0][:, 0], raw_data[0][:, 1])

# Same points, coloured by the blob each one was generated from.
plt.figure()
plt.scatter(raw_data[0][:, 0], raw_data[0][:, 1], c=raw_data[1])
#Build and train the model
from sklearn.cluster import KMeans
# Fit k-means on the point coordinates only (raw_data[0]); the generating
# blob labels in raw_data[1] are never shown to the model.
# n_init is set explicitly because its default differs between scikit-learn
# releases, and random_state makes the centroid initialisation repeatable.
model = KMeans(n_clusters=4, n_init=10, random_state=42)
model.fit(raw_data[0])
# See the predictions
# A bare expression is only echoed by a notebook or REPL; print explicitly so
# the values also appear when the file is run as a script.
print(model.labels_)           # cluster index assigned to each point
print(model.cluster_centers_)  # coordinates of the fitted centroids
# Plot the predictions against the original data set
# Two side-by-side panels sharing the y-axis: the left one is coloured by the
# cluster k-means assigned, the right one by the blob each point was generated
# from. Cluster indices are arbitrary, so the same group may appear in a
# different colour in each panel; compare the groupings, not the colours.
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10, 6))
ax1.set_title('Our Model')
ax1.scatter(raw_data[0][:, 0], raw_data[0][:, 1], c=model.labels_)
ax2.set_title('Original Data')
ax2.scatter(raw_data[0][:, 0], raw_data[0][:, 1], c=raw_data[1])
# Without an explicit show() a plain script run exits without displaying any
# figure.
plt.show()