2016-12-10 29 views
0

我想绘制删除样本(行)的效果。有人称之为“学习曲线”。如何发送数据帧到scikit进行交叉验证?

所以我想使用熊猫来删除一些行。 How to remove, randomly, rows from a dataframe but from each label?

但是,当我想要做的交叉验证,我得到以下错误(即使使用df.values把数据框到一个数组后):

enter image description here

所以,我是什么做错了?

这里是我的代码:

import pandas as pd 
import numpy as np 
from sklearn.model_selection import StratifiedShuffleSplit 
from sklearn import neighbors 
from sklearn import cross_validation 

df = pd.DataFrame(np.random.rand(12, 5)) 
label = np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 
df['label'] = label 

df1 = pd.concat(g.sample(2) for idx, g in df.groupby('label')) 

X = df1[[0, 1, 2, 3, 4]].values 
y = df1.label.values 
print(X) 
print(y) 

clf = neighbors.KNeighborsClassifier() 
sss = StratifiedShuffleSplit(1, test_size=0.1) 
scoresSSS = cross_validation.cross_val_score(clf, X, y, cv=sss) 
print(scoresSSS) 

回答

1

马上蝙蝠,与sss = StratifiedShuffleSplit(n_splits=1, test_size=0.35)你生成一个对象,而不是一个可迭代:

>>> type(sss) 
    <class 'sklearn.model_selection._split.StratifiedShuffleSplit'> 

而不是给StratifiedShuffleSplit类你的整个对象(这显然是不可迭代的,因此错误),你需要给它的类的.split()方法(docs)的火车/测试输出。另外,StratifiedShuffleSplit类中的test_size参数太小。如果您使用0.1,则会抛出ValueError,因为您有3个独特的类,因此测试大小的0.1不会。最后,您在KNeighbors clf对象中使用默认的n_neighbors参数值。使用如此小的数据集时,此默认值太大。由于n_neighbors <= n_samples,使用你所拥有的将会抛出另一个ValueError。所以在我下面的例子我已经调升测试规模在StratifiedShuffleSplit对象,下降n_neighbors下降到2,并通过了iterables从sss.split(X, y)cross_validation.cross_val_scorecv PARAM。

因此,这里是你希望你的代码是什么样子:

import pandas as pd 
import numpy as np 
from sklearn.model_selection import StratifiedShuffleSplit 
from sklearn import neighbors 
from sklearn import cross_validation 

df = pd.DataFrame(np.random.rand(12, 5)) 
label=np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 
df['label'] = label 

df1 = pd.concat(g.sample(2) for idx, g in df.groupby('label')) 


X = df1[[0,1,2,3,4]].values 
y = df1.label.values 

clf = neighbors.KNeighborsClassifier(n_neighbors=2) 
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.35) 

scoresSSS = cross_validation.cross_val_score(clf, X, y, cv=sss.split(X, y)) 
print(scoresSSS) 

我只想说,我不知道比分你正在寻找得到的,并绝不是我在声称这将优化你的分数。但是,这将帮助您摆脱这些错误,以便您可以重新开始工作。

相关问题