2014-02-13 44 views
3

我在python中实现了kmeans算法,代码如下。我测试代码使用一些简单的数据。正如以下,其存储在名为data.txt中
-1
-3
4 -4
3 -7
3.5文件-2 -3 -9运行主函数时AssertionError

我的问题是,在迭代过程中,某些集群似乎变空了,也就是(集群数量)< k,经过我的分析,似乎会出现这种情况,但是在搜索完网页后,我发现没有身体在kmeans算法中处理这个问题。

所以我不知道故障在哪里?是因为我的测试数据如此简单

import sys 
import numpy as np 
from math import sqrt 

""" 
useage: python mykmeans.py mydata.txt k 

""" 

GAP = 2 
MIN_VAL = 1000000 

def get_distance(point1, point2): 
    dis = sqrt(pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2)) 

    return dis 


def cluster_dis(centroid, cluster): 
    dis = 0.0 
    for point in cluster: 
     dis += get_distance(centroid, point) 

    return dis 

def update_centroids(centroids, cluster_id, cluster): 
    x, y = 0.0, 0.0 
    length = len(cluster) 
    if length == 0: # TODO: this is my question? do we need to examine this? 
     return 

    for item in cluster: 
     x += item[0] 
     y += item[1] 
    centroids[cluster_id] = (x/length, y/length) 


def kmeans(data, k): 
    assert k <= len(data) 

    seed_ids = np.random.randint(0, len(data), k) 
    centroids = [data[idx] for idx in seed_ids] 
    clusters = [[] for _ in xrange(k)] 
    cluster_idx = [-1] * len(data) 

    pre_dis = 0 
    while True: 
     for point_id, point in enumerate(data): 
      min_distance, tmp_id = MIN_VAL, -1 
      for seed_id, seed in enumerate(centroids): 
       distance = get_distance(seed, point) 
       if distance < min_distance: 
        min_distance = distance 
        tmp_id = seed_id 
      if cluster_idx[point_id] != -1: 
       dex = clusters[cluster_idx[point_id]].index(point) 
       del clusters[cluster_idx[point_id]][dex] 
      clusters[tmp_id].append(point) 
      cluster_idx[point_id] = tmp_id 

     now_dis = 0.0 
     for cluster_id, cluster in enumerate(clusters): 
      now_dis += cluster_dis(centroids[cluster_id], cluster) 
      update_centroids(centroids, cluster_id, cluster) 

     delta_dis = now_dis - pre_dis 
     pre_dis = now_dis 

     if delta_dis < GAP: 
      break 

    print(centroids) 
    print(clusters) 

    return centroids, clusters 

def get_data(file_name): 
    try: 
     fr = open(file_name) 
     lines = fr.read().splitlines() 
    except IOError, e: 
     pass 
    finally: 
     fr.close() 

    data = [] 
    for line in lines: 
     tmp = line.split() 
     x, y = float(tmp[0]), float(tmp[1]) 
     data.append([x, y]) 

    return data 

def main(): 
    args = sys.argv[1:] 
    assert len(args) > 1 
    file_name, k = args[0], int(args[1]) 

    data = get_data(file_name) 
    kmeans(data, k) 


if __name__ == '__main__': 
    main() 

回答

5

k-means可能会诱发空集群。这里是图中所示的one example。我也复制下面的数字,以防链接可能在某天终止。

下面的第一张图显示了7分的分布。最初选择3,5和6作为聚类中心。

enter image description here

的“+”表示以下第1次迭代后的聚类中心变化,并且相同颜色表示的对应点在相同的簇。

enter image description here

从下图中可以经过2次迭代看到,蓝色的集群变空,的确是有2群,而不是初始值3

enter image description here

所以空簇可能是由于初始化和'不正确的'簇号。您可以在您的代码中尝试不同的k,并多次运行该程序以观察聚类结果,使其更加健壮。

+0

非常感谢。这是一个美妙的答案! 我真正感到困惑的是,我从网上看到的代码并不考虑这种情况,所以我不确定我的理解是否正确。 – Djvu