第六篇 聚类:深入K-Means

聚类:深入K-Means

这篇我们探讨一下K-Means聚类,它属于一种非监督聚类技术。

先导入一些标准方法。

1
2
3
4
5
6
7
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
# use seaborn plotting defaults
import seaborn as sns; sns.set()

介绍K-Means

K-Means是一种非监督聚类算法,它不是通过标签来聚类而是数据本身所拥有的属性。

K-Means是一种相对易于理解的算法。它搜索每个簇的中心点,把离中心点最近的点归为那个簇。

让我们看下K-Means如何来操作一个简单的簇群。强调一下,这是非监督算法,我们不会给簇上色。

1
2
3
4
from sklearn.datasets.samples_generator import make_blobs
X, y = make_blobs(n_samples=300, centers=4,
random_state=0, cluster_std=0.60)
plt.scatter(X[:, 0], X[:, 1], s=50);

png

用肉眼能相对容易找出这四簇。如果你执行全面的搜索来划分数据集,搜索空间将达到指数级。幸运的是,sklearn实现了一个著名的Expectation Maximization (EM)方法来帮我们解决这个问题。

1
2
3
4
5
from sklearn.cluster import KMeans
est = KMeans(4) # 4 clusters
est.fit(X)
y_kmeans = est.predict(X)
plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='rainbow');

png

算法区分了4种我们能用肉眼看到的簇!

K-Means算法:期望最大化(EM)

K-Means使用了一种叫做期望最大化(EM)的算法来解决问题。
期望最大化(EM)有下面的两个步骤:

  1. 猜测一些簇心位置
  2. 重复直至收敛
    • 为每个点分配一个离他最近的簇心位置
    • 用平均值法来重新计算簇心位置

让我们可视化这个过程:

1
2
from fig_code import plot_kmeans_interactive
plot_kmeans_interactive();

png

算法会在合理的簇心下面得到收敛。

K-Means需要注意的地方

这个算法不能保证收敛。所以sklearn默认会使用许多随机初始化的值从而发现最佳结果。

而且簇的数量事先得确定…

K-Means在数字上的应用

一个更实际的例子,让我们再看一下数字。这里我们将使用k-means来聚类64维的数字,然后看下算法发现的簇心是什么?

1
2
from sklearn.datasets import load_digits
digits = load_digits()
1
2
3
est = KMeans(n_clusters=10)
clusters = est.fit_predict(digits.data)
est.cluster_centers_.shape
(10, 64)

我们看到10个簇。让我们可视化每一个簇心代表的数字:

1
2
3
4
fig = plt.figure(figsize=(8, 3))
for i in range(10):
ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
ax.imshow(est.cluster_centers_[i].reshape((8, 8)), cmap=plt.cm.binary)

png

K-Means能够发现这些簇,把它们的均值视作可以识别的数字!

1
2
3
4
5
6
from scipy.stats import mode
labels = np.zeros_like(clusters)
for i in range(10):
mask = (clusters == i)
labels[mask] = mode(digits.target[mask])[0]

作为对比,让我们使用PCA可视化,并看下真实的簇标签和K-means簇标签:

1
2
3
4
5
6
7
8
9
10
11
12
from sklearn.decomposition import PCA
X = PCA(2).fit_transform(digits.data)
kwargs = dict(cmap = plt.cm.get_cmap('rainbow', 10),
edgecolor='none', alpha=0.6)
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].scatter(X[:, 0], X[:, 1], c=labels, **kwargs)
ax[0].set_title('learned cluster labels')
ax[1].scatter(X[:, 0], X[:, 1], c=digits.target, **kwargs)
ax[1].set_title('true labels');

png

让我们看下K-Means分类器(无标签信息)的准确性

1
2
from sklearn.metrics import accuracy_score
accuracy_score(digits.target, labels)
0.79298831385642743

接近80%,结果还行。我们看下这个的混淆矩阵:

1
2
3
4
5
6
7
8
9
from sklearn.metrics import confusion_matrix
print(confusion_matrix(digits.target, labels))
plt.imshow(confusion_matrix(digits.target, labels),
cmap='Blues', interpolation='nearest')
plt.colorbar()
plt.grid(False)
plt.ylabel('true')
plt.xlabel('predicted');
[[177   0   0   0   1   0   0   0   0   0]
 [  0  55  24   1   0   1   2   0  99   0]
 [  1   2 148  13   0   0   0   3   8   2]
 [  0   0   0 154   0   2   0   7   7  13]
 [  0   5   0   0 164   0   0   9   3   0]
 [  0   0   0   1   2 136   1   0   0  42]
 [  1   1   0   0   0   0 177   0   2   0]
 [  0   2   0   0   0   0   0 175   2   0]
 [  0   6   3   2   0   4   2   5 100  52]
 [  0  20   0   6   0   6   0   7   2 139]]

png

这是一个没有标签的完全无监督estimator,具有80%的分类准确率。

例子:K-Means应用在图像的颜色空间压缩上

一个有趣的应用是在图像的颜色空间压缩上。例如,想象你有一张上百万颜色像素的图片。在大多数图片中,颜色中的大部分不会被使用,而大部分的像素点颜色是相似甚至一样的。

sklearn内置了一些图片,可以通过加载模块来使用。例如:

1
2
3
4
from sklearn.datasets import load_sample_image
china = load_sample_image("china.jpg")
plt.imshow(china)
plt.grid(False);

png

图片本身包含了3维数组(height, width, RGB)

1
china.shape
(427, 640, 3)

我们能想象这图片是3维颜色空间的点所组成。我们把像素点的取值范围限制在(0,1)的范围,再把位于一个平面上的像素点reshape成一条直线。

1
2
X = (china / 255.0).reshape(-1, 3)
print(X.shape)
(273280, 3)

我们现在有273280个3维的像素点。

我们的任务是使用K-Means压缩$256^3$种颜色空间到一个很小的颜色空间(64种颜色空间)。首先我们要在数据集中发现$N_{color}$ 个簇,每个簇代表一个颜色,把属于一个簇的像素点用这个簇的颜色来替代。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# reduce the size of the image for speed
image = china[::3, ::3]
n_colors = 64
X = (image / 255.0).reshape(-1, 3)
model = KMeans(n_colors)
labels = model.fit_predict(X)
colors = model.cluster_centers_
new_image = colors[labels].reshape(image.shape)
new_image = (255 * new_image).astype(np.uint8)
# create and plot the new image
with sns.axes_style('white'):
plt.figure()
plt.imshow(image)
plt.title('input')
plt.figure()
plt.imshow(new_image)
plt.title('{0} colors'.format(n_colors))

png

png

比较输入和输出结果:我们把$256^3$种的颜色空间降到了64种


Jupyter实现


知识共享许可协议
本作品采用知识共享署名-非商业性使用-禁止演绎 3.0 未本地化版本许可协议进行许可。