1
我一直在使用CNN进行文本分类并使用tensorflow的contrib learn。attr'TI'的DataType字符串不在允许值列表中:uint8,int32,int64
然而,当我尝试执行以下代码:
classifier = learn.Estimator(model_fn=cnn_model)
classifier.fit(x_train, y_train, steps=10000)
y_predicted = [ p['class'] for p in classifier.predict(x_test, as_iterable=True)]
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy: {0:f}'.format(score))
我下面的错误运行:
ERROR:DataType string for attr 'TI' not in list of allowed values: uint8, int32, int64 on line 'classifier.fit'
我已经格式化了您的代码,请检查内容是否仍然正确。而且,疯狂的猜测:可能是因为'y_train'应该将类表示为整数,但实际上包含了浮点数? – phg
y_train包含0和1. – Raj
而x_train包含数字阵列 – Raj