2017-08-09

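I see that the tensorflow contrib library has an implementation of KMeans clustering. However, I cannot get it to do the simple task of estimating cluster centers for 2D points. How does KMeans clustering work in tensorflow?

Code: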

import numpy as np
import tensorflow as tf

## Generate synthetic data
N, D = 1000, 2  # number of points per cluster and dimensionality

means = np.array([[0.5, 0.0], 
        [0, 0], 
        [-0.5, -0.5], 
        [-0.8, 0.3]]) 
covs = np.array([np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01])]) 
n_clusters = means.shape[0] 

points = [] 
for i in range(n_clusters): 
    x = np.random.multivariate_normal(means[i], covs[i], N) 
    points.append(x) 
points = np.concatenate(points) 

## construct model 
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters) 
kmeans.fit(points.astype(np.float32)) 

I get the following error:

InvalidArgumentError (see above for traceback): Shape [-1,2] has negative dimensions 
    [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

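I think I am doing something wrong, but I could not figure out what from the documentation.

Edit

I got it working by using input_fn, but it is extremely slow (I had to reduce the number of points in each cluster to 10 to see any result). Why is that, and how can I make it faster?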

def input_fn(): 
    return tf.constant(points, dtype=tf.float32), None 

## construct model 
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001) 
kmeans.fit(input_fn=input_fn) 
centers = kmeans.clusters() 
print(centers) 

Solved:

It seems that relative_tolerance needs to be set. I changed only one line, and it now works fine:

kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)
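To illustrate why setting relative_tolerance matters, here is a minimal NumPy sketch of k-means (Lloyd's algorithm) that stops once the relative change in the loss between two iterations drops below a threshold. It is not the tf.contrib.learn implementation: the function name kmeans_numpy, the max_iters cap, the random initialization, and the exact form of the stopping test are assumptions made for illustration. Without such a criterion (or a step limit), an iterative fit has no obvious point at which to stop, which may be why the first input_fn attempt seemed so slow.

```python
import numpy as np


def kmeans_numpy(points, num_clusters, relative_tolerance=1e-4, max_iters=100, seed=0):
    """Illustrative Lloyd's algorithm with a relative-change stopping rule.

    Hypothetical helper, not part of TensorFlow or any other library.
    """
    rng = np.random.default_rng(seed)
    # Initialize the centers with distinct, randomly chosen data points.
    init_idx = rng.choice(len(points), size=num_clusters, replace=False)
    centers = points[init_idx].astype(np.float64)

    prev_loss = None
    for iteration in range(1, max_iters + 1):
        # Squared distance from every point to every center, shape (n, k).
        sq_dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = sq_dists.argmin(axis=1)
        # Loss: sum of squared distances to the assigned centers.
        loss = float(sq_dists[np.arange(len(points)), labels].sum())

        # Stop when the loss changed by less than the relative tolerance.
        if prev_loss is not None and abs(prev_loss - loss) <= relative_tolerance * prev_loss:
            break
        prev_loss = loss

        # Move each center to the mean of its assigned points.
        for k in range(num_clusters):
            members = points[labels == k]
            if len(members) > 0:
                centers[k] = members.mean(axis=0)

    return centers, loss, iteration


# Synthetic data shaped like the question's: 4 clusters of 1000 points each.
rng = np.random.default_rng(0)
means = np.array([[0.5, 0.0], [0.0, 0.0], [-0.5, -0.5], [-0.8, 0.3]])
points = np.concatenate(
    [rng.multivariate_normal(m, np.diag([0.01, 0.01]), 1000) for m in means]
)

centers, loss, n_iter = kmeans_numpy(points, num_clusters=4, relative_tolerance=1e-4)
print("stopped after", n_iter, "iterations")
print(centers)
```

Because the initial centers are drawn at random, a run can settle in a local minimum; trying several seeds and keeping the result with the lowest loss is a common remedy.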

What version of TF are you running?

Answers

Your original code produces the following warning with Tensorflow 1.2:

WARNING:tensorflow:From <stdin>:1: calling BaseEstimator.fit (from   
    tensorflow.contrib.learn.python.learn.estimators.estimator) with x 
    is deprecated and will be removed after 2016-12-01. 
    Instructions for updating: 
    Estimator is decoupled from Scikit Learn interface by moving into 
    separate class SKCompat. Arguments x, y and batch_size are only 
    available in the SKCompat class, Estimator will only accept input_fn. 

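Based on your edit, it looks like you figured out that input_fn is the only accepted input. If you really want to use TF, I would upgrade to r1.2 and wrap the Estimator in the SKCompat class, as the message above suggests. Otherwise, I would just use the SKLearn package. You can also implement your own clustering algorithm manually in TF, as shown in this blog.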
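As a concrete version of the scikit-learn route mentioned above, the following sketch clusters data generated like the question's with sklearn.cluster.KMeans. The seed, n_init=10, and tol=1e-4 are illustrative choices, not values from this thread. Note that scikit-learn's tol is defined differently from relative_tolerance in tf.contrib.learn (in recent versions it is based on how far the centers move), so the two settings are not interchangeable.

```python
import numpy as np
from sklearn.cluster import KMeans

# Synthetic data shaped like the question's: 4 clusters of 1000 points each.
rng = np.random.default_rng(0)
means = np.array([[0.5, 0.0], [0.0, 0.0], [-0.5, -0.5], [-0.8, 0.3]])
points = np.concatenate(
    [rng.multivariate_normal(m, np.diag([0.01, 0.01]), 1000) for m in means]
)

# n_init=10 restarts the algorithm from 10 initializations and keeps the best run.
model = KMeans(n_clusters=len(means), n_init=10, tol=1e-4, random_state=0)
model.fit(points.astype(np.float32))

# Estimated cluster centers, one row per cluster (row order is arbitrary).
print(model.cluster_centers_)
```

The fitted centers come back as a plain NumPy array, so no session or input_fn is involved.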

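Thanks, I figured it out. One question, though: what if my points are in a Variable? Does it work the same way, or do I need to do something different (such as evaluating it before feeding it to the KMeans clustering)?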

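The wrapper does not accept TF tensors as input to the estimator, which rules out placeholders and variables. So evaluating it before passing it in should work!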