2017-02-22 32 views
1

我一直在使用CNN进行文本分类并使用tensorflow的contrib learn。attr'TI'的DataType字符串不在允许值列表中:uint8,int32,int64

然而,当我尝试执行以下代码:

classifier = learn.Estimator(model_fn=cnn_model) 

classifier.fit(x_train, y_train, steps=10000) 
y_predicted = [ p['class'] for p in classifier.predict(x_test, as_iterable=True)] 

score = metrics.accuracy_score(y_test, y_predicted) 

print('Accuracy: {0:f}'.format(score)) 

我下面的错误运行:

ERROR:DataType string for attr 'TI' not in list of allowed values: uint8, int32, int64 on line 'classifier.fit'

+0

我已经格式化了您的代码,请检查内容是否仍然正确。而且,疯狂的猜测:可能是因为'y_train'应该将类表示为整数,但实际上包含了浮点数? – phg

+0

y_train包含0和1. – Raj

+0

而x_train包含数字阵列 – Raj

回答

0

您需要将输入转换y_train给定类型。 print(type(y_train))最有可能是一个浮点数,而不是一个整数。

+0

它们都是整数(1和0) – Raj

相关问题