TensorFlow error: Invalid argument: Shape must be a vector
I'm trying out TensorFlow with the Titanic data from Kaggle: https://www.kaggle.com/c/titanic
Here is the code I'm trying to adapt from Sentdex: https://www.youtube.com/watch?v=PwAGxqrXSCs&index=46&list=PLQVvvaa0QuDfKTOs3Keq_kaG2P55YRn5v#t=398.046664
import tensorflow as tf
import cleanData
import numpy as np
train, test = cleanData.read_and_clean()
train = train[['Pclass', 'Sex', 'Age', 'Fare', 'Child', 'Fam_size', 'Title', 'Mother', 'Survived']]
# one hot
train['Died'] = int('0')
train["Died"][train["Survived"] == 0] = 1
print(train.head())
n_nodes_hl1 = 500
n_classes = 2
batch_size = 100
# tf graph input
x = tf.placeholder("float", [None, 8])
y = tf.placeholder("float")
def neural_network_model(data):
    hidden_layer_1 = {'weights':tf.Variable(tf.random_normal([8, n_nodes_hl1])),
                      'biases':tf.Variable(tf.random_normal(n_nodes_hl1))}
    output_layer = {'weights':tf.Variable(tf.random_normal([n_nodes_hl1, n_classes])),
                    'biases':tf.Variable(tf.random_normal([n_classes]))}
    l1 = tf.add(tf.matmul(data, hidden_layer_1['weights']), hidden_layer_1['biases'])
    l1 = tf.nn.relu(l1)
    output = tf.matmul(l1, output_layer['weights']) + output_layer['biases']
    return output
def train_neural_network(x):
    prediction = neural_network_model(x)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction,y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)
    desired_epochs = 10
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch in range(desired_epochs):
            epoch_loss = 0
            for _ in range(int(train.shape[0])/batch_size):
                x_epoch, y_epoch = train.next_batch(batch_size)
                _, c = sess.run([optimizer, cost], feed_dict= {x:x, y:y})
                epoch_loss += c
            print('Epoch', epoch, 'completed out of', desired_epochs, 'loss:', epoch_loss)
        correct = tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
        print('Training accuracy:', accuracy.eval({x:x, y:y}))
train_neural_network(x)
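Apart from the reported error, the training loop has three other problems that would surface once the graph builds: a pandas DataFrame has no `next_batch` method (that helper belongs to the MNIST dataset object used in the tutorial), `int(train.shape[0])/batch_size` is a float in Python 3, which `range()` rejects, and `feed_dict={x:x, y:y}` maps each placeholder to itself instead of to a batch of data (the same applies to `accuracy.eval`). The sketch below is a minimal NumPy replacement for `next_batch`, assuming the features and one-hot labels have already been converted to NumPy arrays; `iterate_minibatches` is a hypothetical helper name, not a TensorFlow API.

```python
import numpy as np


def iterate_minibatches(features, labels, batch_size, seed=0):
    """Yield (features, labels) batches that cover every row once per epoch.

    Hypothetical stand-in for the MNIST helper's next_batch(): a pandas
    DataFrame has no such method. The same index array selects from both
    arrays, so each feature row stays paired with its label row.
    """
    n = features.shape[0]
    # Shuffle row order reproducibly with a fixed seed.
    order = np.random.default_rng(seed).permutation(n)
    # range(start, stop, step) keeps everything an int and also yields the
    # final partial batch (10 rows with batch_size=4 gives 4, 4 and 2).
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], labels[idx]


if __name__ == "__main__":
    X = np.arange(20, dtype=np.float32).reshape(10, 2)   # 10 toy rows
    Y = np.eye(2, dtype=np.float32)[np.arange(10) % 2]    # toy one-hot labels
    print([xb.shape[0] for xb, _ in iterate_minibatches(X, Y, batch_size=4)])
    # [4, 4, 2]
```

Inside the epoch loop this would be used as `for x_epoch, y_epoch in iterate_minibatches(X, Y, batch_size):` followed by `sess.run([optimizer, cost], feed_dict={x: x_epoch, y: y_epoch})`, where `X` and `Y` are hypothetical names for the feature and label arrays, so the placeholders receive actual data. Passing a different `seed` per epoch gives a new shuffle order each time.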
When I run the code, I get an error that says: "W tensorflow/core/framework/op_kernel.cc:909] Invalid argument: Shape must be a vector of {int32,int64}, got shape []"
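The message most likely comes from the `biases` entry of `hidden_layer_1`, where `tf.random_normal(n_nodes_hl1)` receives the bare integer 500. That integer is a scalar, whose own shape is `[]`, while the `shape` argument must be a 1-D vector of dimensions. Every other call in the model already passes a list (`[8, n_nodes_hl1]`, `[n_nodes_hl1, n_classes]`, `[n_classes]`), so the likely fix is `tf.random_normal([n_nodes_hl1])`. The sketch below illustrates the rank difference with NumPy rather than TensorFlow; `as_shape_vector` is a hypothetical helper that imitates the check described by the message, not TensorFlow's actual implementation.

```python
import numpy as np


def as_shape_vector(shape):
    """Hypothetical helper imitating the rule in the error message:
    a shape argument must be a 1-D (vector) array of integers."""
    arr = np.asarray(shape)
    if arr.ndim != 1 or not np.issubdtype(arr.dtype, np.integer):
        # arr.shape is () for a bare int, which prints as [] here.
        raise ValueError(
            "Shape must be a vector of {int32,int64}, got shape %s"
            % list(arr.shape)
        )
    return arr


n_nodes_hl1 = 500

as_shape_vector([n_nodes_hl1])        # OK: a vector holding one dimension
try:
    as_shape_vector(n_nodes_hl1)      # a bare int is a rank-0 scalar
except ValueError as err:
    print(err)  # Shape must be a vector of {int32,int64}, got shape []
```

A one-element list is a rank-1 vector, so it passes; the same rule is why `[n_classes]` works for the output layer's biases.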
Is there any way to fix this? I saw a post on GitHub about TensorFlow code saying that, apparently, the library doesn't accept pandas DataFrames as input.
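Whatever the GitHub post said, a simple way to avoid feeding pandas objects into the graph is to convert the DataFrame to NumPy arrays before training: an `(n, 8)` float32 feature matrix for the `[None, 8]` placeholder and an `(n, 2)` one-hot label matrix. The sketch below does that, assuming the eight feature columns selected in the question are already numeric after `cleanData.read_and_clean()` (that module is not shown). It also builds the `Died` column in one vectorised step instead of the chained assignment `train["Died"][train["Survived"] == 0] = 1`, which pandas may apply to a temporary copy rather than to `train`. `to_arrays` and `FEATURES` are hypothetical names, and the demo rows are made-up values.

```python
import numpy as np
import pandas as pd

# The eight input columns from the question, excluding the label.
FEATURES = ['Pclass', 'Sex', 'Age', 'Fare', 'Child', 'Fam_size', 'Title', 'Mother']


def to_arrays(df):
    """Split a cleaned DataFrame into float32 features and one-hot labels.

    Assumes every FEATURES column is already numerically encoded.
    Label columns are ordered [Survived, Died].
    """
    # Vectorised replacement for the chained assignment on "Died".
    died = (df['Survived'] == 0).astype(np.float32).to_numpy()
    survived = df['Survived'].to_numpy(dtype=np.float32)
    x = df[FEATURES].to_numpy(dtype=np.float32)   # shape (n, 8)
    y = np.column_stack([survived, died])         # shape (n, 2)
    return x, y


if __name__ == "__main__":
    demo = pd.DataFrame({  # two made-up rows
        'Pclass': [3, 1], 'Sex': [0, 1], 'Age': [20.0, 40.0], 'Fare': [8.0, 70.0],
        'Child': [0, 0], 'Fam_size': [1, 2], 'Title': [1, 2], 'Mother': [0, 0],
        'Survived': [0, 1],
    })
    X, Y = to_arrays(demo)
    print(X.shape, Y.tolist())  # (2, 8) [[0.0, 1.0], [1.0, 0.0]]
```

The resulting arrays can go straight into `feed_dict`, and their second dimensions line up with the `[None, 8]` input placeholder and `n_classes = 2`.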
Please give us a minimal example program and state clearly on which line (or lines) you get the error message.