2015-12-06 24 views
1

我训练了一个scikit-learn的实例TfidfVectorizer,我想将其保存到磁盘。我将IDF矩阵(idf_属性)作为一个numpy数组保存到磁盘,并将词汇表(vocabulary_)作为JSON对象保存到磁盘(为了安全和其他reasons,我避免了pickle)。我试图做到这一点:向TfidfVectorizer提供预先计算的估计值

import json 
from idf import idf # numpy array with the pre-computed IDFs 
from sklearn.feature_extraction.text import TfidfVectorizer 

# dirty trick so I can plug my pre-computed IDFs 
# necessary because "vectorizer.idf_ = idf" doesn't work, 
# it returns "AttributeError: can't set attribute." 
class MyVectorizer(TfidfVectorizer): 
    TfidfVectorizer.idf_ = idf 

# instantiate vectorizer 
vectorizer = MyVectorizer(lowercase = False, 
          min_df = 2, 
          norm = 'l2', 
          smooth_idf = True) 

# plug vocabulary 
vocabulary = json.load(open('vocabulary.json', mode = 'rb')) 
vectorizer.vocabulary_ = vocabulary 

# test it 
vectorizer.transform(['foo bar']) 
Traceback (most recent call last): 
    File "<stdin>", line 2, in <module> 
    File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/feature_extraction/text.py", line 1314, in transform 
    return self._tfidf.transform(X, copy=False) 
    File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/feature_extraction/text.py", line 1014, in transform 
    check_is_fitted(self, '_idf_diag', 'idf vector is not fitted') 
    File "/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/sklearn/utils/validation.py", line 627, in check_is_fitted 
    raise NotFittedError(msg % {'name': type(estimator).__name__}) 
sklearn.utils.validation.NotFittedError: idf vector is not fitted 

那么,我在做什么错了?我无法欺骗矢量化对象:它知道我在作弊(即将预先计算的数据传递给它,而不是用实际的文本进行训练)。我检查了矢量化器对象的属性,但我找不到像'istrained','isfitted'等等。那么,我该如何欺骗矢量化器?

回答

1

好吧,我想我明白了:矢量化器实例有一个属性_tfidf,而这个属性又必须有一个属性_idf_diagtransform方法调用check_is_fitted函数,该函数检查_idf_diag是否存在。 (我错过了它,因为它是一个属性的属性。)所以,我检查了TfidfVectorizer source code以查看_idf_diag是如何创建的。然后,我只是把它添加到_tfidf属性:

import scipy.sparse as sp 

# ... code ... 

vectorizer._tfidf._idf_diag = sp.spdiags(idf, 
             diags = 0, 
             m = len(idf), 
             n = len(idf)) 

而现在的矢量作品。