2016-12-25

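I have already written ANN classifiers with Keras, and I am now learning to write RNNs in Keras myself for text classification and time-series prediction. After searching online for a while, I found Jason Brownlee's tutorial, which is a good starting point for RNN beginners. The original article uses the IMDb dataset for LSTM text classification, but because that dataset is large, I switched to a small SMS spam detection dataset. How do I use a Keras RNN for text classification on this dataset?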

# LSTM with dropout for sequence classification on the SMS spam dataset
import numpy
import pandas as pd
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Input, TextVectorization, Embedding, Dropout, LSTM, Dense
from keras.utils import set_random_seed

# fix random seeds (Python, NumPy and the Keras backend) for reproducibility
set_random_seed(7)

url = 'https://raw.githubusercontent.com/justmarkham/pydata-dc-2016-tutorial/master/sms.tsv'
sms = pd.read_table(url, header=None, names=['label', 'message'])

# convert label to a numerical variable
sms['label_num'] = sms.label.map({'ham': 0, 'spam': 1})
X = sms.message
y = sms.label_num
print(X.shape)
print(y.shape)

# split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
top_words = 5000

# the model needs integers, not raw strings: map each message to word indices
# (vocabulary learned from the training text only) and truncate/pad every
# sequence to max_review_length
max_review_length = 500
vectorizer = TextVectorization(max_tokens=top_words, output_sequence_length=max_review_length)
vectorizer.adapt(X_train.to_numpy(dtype=str))
X_train = numpy.asarray(vectorizer(X_train.to_numpy(dtype=str)))
X_test = numpy.asarray(vectorizer(X_test.to_numpy(dtype=str)))

# create the model
embedding_vector_length = 32
model = Sequential()
model.add(Input(shape=(max_review_length,)))
model.add(Embedding(top_words, embedding_vector_length))
model.add(Dropout(0.2))  # stand-in for the removed Keras 1 Embedding(dropout=...) argument
model.add(LSTM(100, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()  # summary() prints by itself and returns None
model.fit(X_train, y_train.to_numpy(), epochs=3, batch_size=64)

# Final evaluation of the model
scores = model.evaluate(X_test, y_test.to_numpy(), verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
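Keras's pad_sequences and the Embedding layer both expect integer word indices rather than raw strings, so each message has to be encoded as a sequence of integers before padding. The plain-Python sketch below illustrates that step under simple assumptions: lowercase whitespace tokenization, words ranked by frequency with index 0 reserved for padding, and words whose index is at least num_words dropped. build_word_index and texts_to_padded_sequences are hypothetical helper names, not Keras functions; they only loosely imitate Keras's Tokenizer and the default padding='pre', truncating='pre' behaviour of pad_sequences.

```python
from collections import Counter


def build_word_index(texts):
    """Map each word to an integer rank by frequency (1 = most frequent).

    Index 0 is left unused so it can serve as the padding value.
    """
    counts = Counter(word for text in texts for word in text.lower().split())
    # Sort by descending count, then alphabetically so ties are deterministic.
    ranked = sorted(counts, key=lambda w: (-counts[w], w))
    return {word: i + 1 for i, word in enumerate(ranked)}


def texts_to_padded_sequences(texts, word_index, num_words, maxlen):
    """Encode each text as a list of exactly `maxlen` integers.

    Unseen words and words with an index >= num_words are dropped.
    Sequences are truncated from the front and left-padded with 0.
    """
    result = []
    for text in texts:
        seq = [word_index[w] for w in text.lower().split()
               if w in word_index and word_index[w] < num_words]
        seq = seq[-maxlen:]  # keep only the last maxlen tokens
        result.append([0] * (maxlen - len(seq)) + seq)
    return result


train_texts = ["Free entry win a prize", "See you at lunch", "win win free"]
word_index = build_word_index(train_texts)
print(texts_to_padded_sequences(train_texts, word_index, num_words=5000, maxlen=6))
# → [[0, 2, 5, 1, 3, 7], [0, 0, 8, 9, 4, 6], [0, 0, 0, 1, 1, 2]]
```

A real tokenizer (Keras's Tokenizer or the TextVectorization layer) also strips punctuation, so its indices will differ from this toy version.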

I have managed to split the dataset into training and test sets, but how should I now build my RNN model for this dataset?

Answer


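Before training a neural network model, you need to represent the raw text data as numeric vectors. To do this, you can use scikit-learn's CountVectorizer or TfidfVectorizer. Once the text has been converted from raw form into a numeric vector representation, you can train an RNN, LSTM or CNN for the text classification problem. Note that both vectorizers produce one bag-of-words vector per message with word order discarded; an Embedding + LSTM model like the one in the question instead expects each message as a padded sequence of integer word indices (for example from Keras's Tokenizer followed by pad_sequences).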
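As a rough illustration of what CountVectorizer and TfidfVectorizer compute, here is a plain-Python sketch. It assumes lowercase whitespace tokenization and scikit-learn's default TF-IDF settings (smooth_idf=True, so idf(t) = ln((1 + n) / (1 + df(t))) + 1, followed by L2 normalisation of each row). The real classes use a regex token pattern that ignores one-character tokens, so their vocabularies can differ; fit_vocabulary, count_vectors and tfidf_vectors are hypothetical helper names.

```python
import math


def fit_vocabulary(texts):
    """Assign each distinct lowercase word a column index, in sorted order."""
    words = sorted({w for text in texts for w in text.lower().split()})
    return {w: i for i, w in enumerate(words)}


def count_vectors(texts, vocab):
    """One row per text: how often each vocabulary word occurs in it."""
    rows = []
    for text in texts:
        row = [0] * len(vocab)
        for w in text.lower().split():
            if w in vocab:  # words not seen while fitting are ignored
                row[vocab[w]] += 1
        rows.append(row)
    return rows


def tfidf_vectors(counts):
    """Smoothed IDF (computed from these same rows) times counts, L2-normalised."""
    n_docs = len(counts)
    n_terms = len(counts[0]) if counts else 0
    df = [sum(1 for row in counts if row[j] > 0) for j in range(n_terms)]
    idf = [math.log((1 + n_docs) / (1 + d)) + 1 for d in df]
    result = []
    for row in counts:
        weighted = [c * w for c, w in zip(row, idf)]
        norm = math.sqrt(sum(v * v for v in weighted))
        result.append([v / norm if norm else 0.0 for v in weighted])
    return result


texts = ["free prize now", "see you now", "free free prize"]
vocab = fit_vocabulary(texts)
counts = count_vectors(texts, vocab)
print(vocab)   # {'free': 0, 'now': 1, 'prize': 2, 'see': 3, 'you': 4}
print(counts)  # [[1, 1, 1, 0, 0], [0, 1, 0, 1, 1], [2, 0, 1, 0, 0]]
print(tfidf_vectors(counts))
```

Every message becomes a single fixed-length row regardless of its length, which is why these vectors plug most naturally into a dense (ANN) classifier rather than a sequence model.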
