2017-01-31 41 views
3

我试图用scikit学习LDA分类器来分类一些数据。我不完全确定要从中“期望”什么,但是我所得到的很奇怪。这似乎是一个很好的机会来了解技术的缺点,或者我错误地应用它的方式。我知道没有一行可以完全分离这些数据,但似乎有比它发现的更好的行。我只是使用默认选项。有关如何做得更好的想法?我正在使用我的数据集大小的LDA because it is linear。尽管我认为线性SVM具有类似的复杂性。也许这样的数据会更好?当我测试其他可能性时我会更新。scikit学习LDA给出意想不到的结果

图片:(淡蓝色是我的LDA分类预测将是深蓝色)

LDA

代码:

import numpy as np 
from numpy import array 
import matplotlib.pyplot as plt 
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA 
import itertools 

X = array([[ 0.23125754, 0.79170351], 
     [ 0.78021491, -0.24999486], 
     [ 0.00856446, 0.41452734], 
     [ 0.66381753, -0.09872504], 
     [-0.03178685, 0.04876317], 
     [ 0.65574645, -0.68214948], 
     [ 0.14290684, 0.38256002], 
     [ 0.05156987, 0.11094875], 
     [ 0.06843403, 0.19110019], 
     [ 0.24070898, -0.07403764], 
     [ 0.03184353, 0.4411446 ], 
     [ 0.58708124, -0.38838008], 
     [-0.00700369, 0.07540799], 
     [-0.01907816, 0.07641038], 
     [ 0.30778608, 0.30317186], 
     [ 0.55774143, -0.38017325], 
     [-0.00957214, -0.03303287], 
     [ 0.8410637 , 0.158594 ], 
     [-0.00294113, -0.00380608], 
     [ 0.26577841, 0.07833684], 
     [-0.32249375, 0.49290502], 
     [ 0.11313078, 0.35697211], 
     [ 0.41153679, -0.4471876 ], 
     [-0.00313315, 0.30065913], 
     [ 0.14344143, -0.19127107], 
     [ 0.04857767, 0.01339191], 
     [ 0.5865007 , 0.71209886], 
     [ 0.08157439, 0.40909955], 
     [ 0.72495202, 0.29583866], 
     [-0.09391461, 0.17976605], 
     [ 0.06149141, 0.79323099], 
     [ 0.52208024, -0.2877661 ], 
     [ 0.01992141, -0.00435266], 
     [ 0.68492617, -0.46981335], 
     [-0.00641231, 0.29699622], 
     [ 0.2369677 , 0.140319 ], 
     [ 0.6602586 , 0.11200433], 
     [ 0.25311836, -0.03085372], 
     [-0.0895014 , 0.45147252], 
     [-0.18485667, 0.43744524], 
     [ 0.94636701, 0.16534406], 
     [ 0.01887734, -0.07702135], 
     [ 0.91586801, 0.17693792], 
     [-0.18834833, 0.31944796], 
     [ 0.20468328, 0.07099982], 
     [-0.15506378, 0.94527383], 
     [-0.14560083, 0.72027034], 
     [-0.31037647, 0.81962815], 
     [ 0.01719756, -0.01802322], 
     [-0.08495304, 0.28148978], 
     [ 0.01487427, 0.07632112], 
     [ 0.65414479, 0.17391618], 
     [ 0.00626276, 0.01200355], 
     [ 0.43328095, -0.34016614], 
     [ 0.05728525, -0.05233956], 
     [ 0.61218382, 0.20922571], 
     [-0.69803697, 2.16018536], 
     [ 1.38616732, -1.86041621], 
     [-1.21724616, 2.72682759], 
     [-1.26584365, 1.80585403], 
     [ 1.67900048, -2.36561699], 
     [ 1.35537903, -1.60023078], 
     [-0.77289615, 2.67040114], 
     [ 1.62928969, -1.20851808], 
     [-0.95174264, 2.51515935], 
     [-1.61953649, 2.34420531], 
     [ 1.38580104, -1.9908369 ], 
     [ 1.53224512, -1.96537012]]) 

y = array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 
     1., 1., 1.]) 

classifier = LDA() 
classifier.fit(X,y) 

xx = np.array(list(itertools.product(np.linspace(-4,4,300), np.linspace(-4,4,300)))) 
yy = classifier.predict(xx) 
b_colors = ['salmon' if yyy==0 else 'deepskyblue' for yyy in yy] 
p_colors = ['r' if yyy==0 else 'b' for yyy in y] 
plt.scatter(xx[:,0],xx[:,1],s=1,marker='o',edgecolor=b_colors,c=b_colors) 
plt.scatter(X[:,0], X[:,1], marker='o', s=5, c=p_colors, edgecolor=p_colors) 
plt.show() 

UPDATE:使用sklearn.discriminant_analysis.LinearDiscriminantAnalysissklearn.svm.LinearSVC更改还使用默认选项给出以下图片:

LinearSVM

我认为使用零一损失而不是铰链损失将有所帮助,但sklearn.svm.LinearSVC似乎不允许自定义损失函数。

UPDATE:损失函数sklearn.svm.LinearSVC接近零酮损失作为参数C趋于无穷。设置C = 1000给了我最初的希望。不要发布这个答案,因为原来的问题是关于LDA的。

图片:

enter image description here

回答

1

LDA模型每个类为高斯,因此对于每一个类由类所确定的模型估计的均值向量和协方差矩阵。 仅用眼睛来判断,你的蓝色和红色类别具有大约相同的平均值和相同的协方差,这意味着两个高斯将“坐在”彼此的顶部,并且歧视将会很差。实际上,这也意味着分隔符(蓝粉红色边框)会很嘈杂,也就是说,随机样本数据之间会发生很大变化。

Btw您的数据显然不是线性可分的,所以每个线性模型都会很难区分数据。

如果您必须使用线性模型,请尝试使用包含3个组件的LDA,例如左上角蓝色块被分类为'0',右下角蓝色块为'1',红色为' 2' 。这样你会得到一个更好的线性模型。你可以用K = 2类的聚类算法预处理蓝色类。

相关问题