2015-07-10 38 views
1

我想在我的数据聚类图中为每个簇添加一个“球体”。

我的数据聚类图目前是这样的,其中没有“球体”。

enter image description here

这是我的代码

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder

style.use('ggplot')

MY_FILE = 'total_watt.csv'

# Load the CSV with its first column parsed as a datetime index, then
# aggregate the readings to one summed value per calendar day.
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
# The ``how='sum'`` keyword of ``resample`` no longer exists in pandas.
# ``min_count=1`` keeps days without any reading as NaN (instead of 0) so
# that the ``dropna()`` below still discards them.
df = df.resample('1D').sum(min_count=1)
df = df.dropna()

# Encode each remaining day ('YYYY-MM-DD') as an integer so it can serve as
# the x coordinate; the encoder is reused later to label the x ticks.
date = [day.strftime('%Y-%m-%d') for day in df.index]
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values

# One row per day: (encoded date, summed value of the first data column).
X = np.column_stack((date_numeric, consumption))

kmeans = KMeans(n_clusters=3)
kmeans.fit(X)

centroids = kmeans.cluster_centers_
labels = kmeans.labels_

print(centroids)
print(labels)

fig, ax = plt.subplots(figsize=(10, 8))
rect = fig.patch
rect.set_facecolor('#2D2B2B')

# matplotlib format strings, indexed by cluster label.
colors = ["b.", "r.", "g."]

# ``inverse_transform`` expects a 1-D array, so decode every day in one call
# rather than passing a scalar per point.
day_names = encoder.inverse_transform(X[:, 0].astype(int))
for day, value, label in zip(day_names, X[:, 1], labels):
    print("coordinate:", day, value, "label:", label)

# Draw each cluster with a single plot call instead of one artist per point.
for cluster, fmt in enumerate(colors):
    members = X[labels == cluster]
    ax.plot(members[:, 0], members[:, 1], fmt, markersize=10)

# The centroids are drawn twice: a larger default-coloured cross first, then
# a smaller black cross on top of it.
ax.scatter(centroids[:, 0], centroids[:, 1], marker="x", s=150,
           linewidths=5, zorder=10)
ax.scatter(centroids[:, 0], centroids[:, 1], marker="x", s=100, c="black",
           linewidths=5, zorder=10)

# Label every fifth x position with the date it encodes.
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')

ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
# NOTE(review): x holds the encoded dates and y the summed column values, so
# these two labels look mismatched with the data -- confirm before changing.
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')

plt.show()

“球体”是指包围每个簇(聚类)中数据点的区域,如下图所示。

enter image description here

我试图谷歌它。

但是当我搜索“matplotlib 球体”时,找不到任何相关结果。

回答

1

您帖子中的示例图看起来像是高斯混合模型(Gaussian Mixture Model,GMM)的结果,其中每个“球体”对应一个二维高斯密度。

我会立即写一个示例代码来演示如何在数据集上使用GMM并进行这种绘图。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
from matplotlib.patches import Ellipse
import pandas as pd
# code changes here
# ===========================================
# ``sklearn.mixture.GMM`` has been removed from scikit-learn;
# ``GaussianMixture`` is its replacement.
from sklearn.mixture import GaussianMixture
# ===========================================
from sklearn.preprocessing import LabelEncoder

style.use('ggplot')

# replace it with your file path
MY_FILE = '/home/Jian/Downloads/total_watt.csv'

# Load the CSV with its first column parsed as a datetime index, then
# aggregate the readings to one summed value per calendar day.
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
# The ``how='sum'`` keyword of ``resample`` no longer exists in pandas.
# ``min_count=1`` keeps days without any reading as NaN (instead of 0) so
# that the ``dropna()`` below still discards them.
df = df.resample('1D').sum(min_count=1)
df = df.dropna()

# Encode each remaining day ('YYYY-MM-DD') as an integer x coordinate.
date = [day.strftime('%Y-%m-%d') for day in df.index]
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values

# One row per day: (encoded date, summed value of the first data column).
X = np.column_stack((date_numeric, consumption))


# code changes here
# ===========================================
# 'diag' matches the default covariance model of the removed ``GMM`` class;
# with it, ``covariances_`` has shape (n_components, n_features) and holds
# the per-axis variances of every component.
gmm = GaussianMixture(n_components=3, covariance_type='diag', random_state=0)
gmm.fit(X)
y_pred = gmm.predict(X)

# the center of each component is given by gmm.means_
# ===========================================

fig, ax = plt.subplots(figsize=(10, 8))

for i, color in enumerate('rgb'):
    # sphere background: an axis-aligned ellipse reaching 1.96 standard
    # deviations either side of the mean along each axis (1.96 is the
    # two-sided 95% quantile of a 1-D normal distribution)
    width, height = 2 * 1.96 * np.sqrt(gmm.covariances_[i])
    ell = Ellipse(gmm.means_[i], width, height, color=color)
    ell.set_alpha(0.1)
    ax.add_artist(ell)
    # data points assigned to this component
    X_data = X[y_pred == i]
    ax.scatter(X_data[:, 0], X_data[:, 1], color=color)
    # center
    ax.scatter(gmm.means_[i][0], gmm.means_[i][1], marker='x', s=100, c=color)


ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
# NOTE(review): x holds the encoded dates and y the summed column values, so
# these two labels look mismatched with the data -- confirm before changing.
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')
# Label every fifth x position with the date it encodes.
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')

# Needed for the figure to appear when this is run as a plain script.
plt.show()

enter image description here

+0

建勋!!!!太感谢你了!我会等你的!!! –

+0

@SuzukiSoma刚刚更新了我的文章。请看一看。 :-) –

+0

你怎么这么聪明..非常感谢你! 你是怎么研究它的? –

相关问题