2017-11-17 30 views
0

在下面的代码中,数据是一个句子列表,而“y”列(data.metagroup)是一个类列表 - 一个直接的分类问题。MultisomialNB分类器中partial_fit的错误

问题:

我想用partial_fit与MultinomialNB分类。

基础上的文档,我传递一个稀疏向量X(称为xtrain),为y简单的串联(称为ytrain),以及用于np.arrayclasses是所有可能的类的列表。

目标是最终使用xtrainytrain的子集,但我必须先让它工作。

相关的文档可以在这里找到: http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html#sklearn.naive_bayes.MultinomialNB.partial_fit

我得到的错误是:

ValueError: operands could not be broadcast together with shapes 
(42633,3809) (800,3809) (42633,3809) 

希望得到任何见解。

def make_xy(data): 
    vectorizer = CountVectorizer(ngram_range = (1,3), min_df = 3, stop_words='english') 
    X = vectorizer.fit_transform(data.sentences) 
    y = data.metagroup 
    return X, y, vectorizer 

x, y, vv = make_xy(data) 

xtrain, xtest, ytrain, ytest = train_test_split(x, y, test_size=0.30) 

clf = MultinomialNB(alpha=1) 
clf.partial_fit(xtrain, ytrain, classes=np.array(y), sample_weight=None) 

predictions = clf.predict(xtest) 


--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-25-cc08c1d170fd> in <module>() 
    48 clf = MultinomialNB(alpha=1) 
---> 50 clf.partial_fit(xtrain, ytrain, classes=np.array(y), sample_weight=None) 

/usr/local/lib/python2.7/site-packages/sklearn/naive_bayes.pyc in partial_fit(self, X, y, classes, sample_weight) 
     530   # Count raw events from data before updating the class log prior 
     531   # and feature log probas 
    --> 532   self._count(X, Y) 
     533 
     534   # XXX: OPTIM: we could introduce a public finalization method to 

    /usr/local/lib/python2.7/site-packages/sklearn/naive_bayes.pyc in _count(self, X, Y) 
     689   if np.any((X.data if issparse(X) else X) < 0): 
     690    raise ValueError("Input X must be non-negative") 
    --> 691   self.feature_count_ += safe_sparse_dot(Y.T, X) 
     692   self.class_count_ += Y.sum(axis=0) 
     693 

    ValueError: operands could not be broadcast together with shapes 
+0

添加完整的代码和数据 – sera

+0

老实说,它只是在“元组”列中的数值类句子列表 - 一个两列熊猫数据框。代码中没有其他内容,但是从csv中提取数据等。这是一切相关的。谢谢。 – paszoon

+1

重现错误的唯一方法是获取数据。如果你不能修改数据,请添加一些导致相同错误的人为数据。 – sera

回答

0

我解决了它。结果问题是我从我的数据中传递了字面Y列,当它真正需要的是指示可能分类的唯一值列表。感谢任何看着这个的人。