2013-12-21 75 views
7

我有两个numpy数组,X_train和Y_train,其中第一个维(700,1000)由值0,1,2,3,4和10填充。因为我正在使用Rotten Tomatoes的API,所以维度(700,)的第二个由值'fresh'或'rotten'填充。出于某种原因,当我执行:MultinomialNB错误:“未知标签类型”

nb = MultinomialNB() 
nb.fit(X_train, Y_train) 

我得到:

ValueError: Unknown label type 

我试图建立一个更小的对阵列:

print xs, '\n', ys 

[[0 0 0 0 1] 
[1 0 0 2 5] 
[3 2 5 5 0] 
[3 2 0 0 1] 
[1 5 1 0 0]] 

['rotten' 'fresh' 'fresh' 'rotten' 'fresh'] 

和多项式NB拟合不会给出未知标签错误。任何想法为什么发生这种情况?

我还用numpy.unique检查了X_train,Y_train中的唯一值,它似乎没有任何奇怪或错误的标签 - 它们都是'新鲜'或'烂的'。

我的用于产生X_train和Y_train代码:

def make_xy(critics, vectorizer=None): 
    stext = critics['quote'].tolist() # need to have a list 
    if vectorizer == None: 
     vectorizer = CountVectorizer(min_df=0) 
    vectorizer.fit(stext) 
    X = vectorizer.transform(stext).toarray() # this is X 
    Y = np.asarray(critics['fresh']) 
    return X[0:1000,0:1000], Y[0:1000] # this is X_train, Y_train 

其中 '评论家' 是从CSV文件(https://www.dropbox.com/s/0lu5oujfm483wtr/critics.csv)导入的熊猫数据帧,任何丢失的数据的清洁:

critics = pd.read_csv('critics.csv') 
critics = critics[~critics.quote.isnull()] 
critics = critics[critics.fresh != 'none'] 
critics = critics[critics.quote.str.len() > 0] 

回答

14

问题似乎是y的dtype。看起来像numpy没有设法弄清楚这是一个字符串。所以它被设置为一个通用对象。如果你改变:
Y = np.asarray(critics['fresh'])Y = np.asarray(critics['fresh'], dtype="|S6")我认为它应该工作。

+0

啊,现在完美!感谢您的帮助。 – covariance

0

我也遇到了同样的问题。 Numpy有时会检测不到数组的数据类型。所以,我们明确地给它。 here is the documentation由numpy所有类型。根据您的要求选择数据类型,并将其提供为“dtype =”属性。