2017-03-01 33 views
1

我想为每个输入数据分配一个标签;例如,数据[0]是'k',数据[2]是'b',数据[5]是'j',数据[13]是另一个'k',....等等。Scikit学习为输入数据分配标签的Kmeans

在这里显示聚类后:3D plot of 2 clusters,我想检索每个群集中每个“点标签”的类型。

import numpy as np 
from sklearn.cluster import KMeans 
import matplotlib.pyplot as plt 
from mpl_toolkits.mplot3d import Axes3D 

data = np.array([[-7.82,4.58,-3.97],[-6.68,3.16,2.71], 
[4.36,-2.19,2.09],[6.72,0.88,2.80], 
[-8.64,3.06,3.50],[-6.87,0.57,-5.45], 
[4.47,-2.62,5.76],[6.73,-2.01,4.18], 
[-7.71,2.34,-6.33],[-6.91,-0.49,-5.68], 
[6.18,2.81,5.82],[6.72,-0.93,-4.04], 
[-6.25,-0.26,0.56],[-6.94,-1.22,1.13], 
[8.09,0.20,2.25],[6.81,0.17,-4.15], 
[-5.19,4.24,4.04],[-6.38,-1.74,1.43], 
[4.08,1.30,5.33],[6.27,0.93,-2.78]]) 


centers = np.array([[1,1,1],[-1,1,-1]]) 
model_ = KMeans(n_clusters=2, init = centers, n_init=1).fit(data) 
print('The labels for Part a are %s' % model_.labels_) 
fig = plt.figure() 
ax = Axes3D(fig) 
ax.scatter(data[:,0], data[:,1], data[:,2],c=model_.labels_.astype(float),s=70) 
plt.title('Visualizing Clusters') 
ax.set_xlabel('X1', fontsize = 10) 
ax.set_ylabel('X2', fontsize = 10) 
ax.set_zlabel('X3', fontsize = 10) 
plt.show() 

回答

0

您可以labels = model_.labels_检索标签。例如,您可以用sum(labels[0:9] == 0)来计算群集零上'左'的数量。

+0

感谢您的回复,我编辑了这个问题,或许现在更清楚了。 –

+0

只需创建一个包含每个数据点标签的数组,并遵循'data'数组中的顺序。 'assigned_labels = np.array(['k','b',...',k',...])''。例如,sum(labels [assigned_labels =='k'] == 0)'将返回分配给簇'0'的具有标签'k'的元素的数量。 – czr

+0

我不确定那是什么回报!我试过:labels = np.array([2,0,1])和assigned_labels = np.array(['k','b',',k']),标签[assigned_labels =='k']返回数组([2]),标签[assigned_labels =='b']返回数组([0])。 –

0

如果您想将字母分配给行,可以使用熊猫作为例子。

>>> import pandas as pd 
>>> indexes = [chr(ord('a') + i) for i in range(data.shape[0])] 
>>> indexes 
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't'] 
>>> data = pd.DataFrame(data, index=indexes) # pass your old data variable 
>>> data 
     0  1  2 
a -7.82 4.58 -3.97 
b -6.68 3.16 2.71 
c 4.36 -2.19 2.09 
d 6.72 0.88 2.80 
e -8.64 3.06 3.50 
f -6.87 0.57 -5.45 
g 4.47 -2.62 5.76 
h 6.73 -2.01 4.18 
i -7.71 2.34 -6.33 
j -6.91 -0.49 -5.68 
k 6.18 2.81 5.82 
l 6.72 -0.93 -4.04 
m -6.25 -0.26 0.56 
n -6.94 -1.22 1.13 
o 8.09 0.20 2.25 
p 6.81 0.17 -4.15 
q -5.19 4.24 4.04 
r -6.38 -1.74 1.43 
s 4.08 1.30 5.33 
t 6.27 0.93 -2.78 
>>> data.loc['a'] 
0 -7.82 
1 4.58 
2 -3.97 
Name: a, dtype: float64