2016-10-27 253 views
0

我很困惑:while_loop在tensorflow为什么下面的代码返回该错误消息返回类型错误

Traceback (most recent call last): 
    File "/Users/Desktop/TestPython/tftest.py", line 46, in <module> 
    main(sys.argv[1:]) 
    File "/Users/Desktop/TestPython/tftest.py", line 35, in main 
    result = tf.while_loop(Cond_f2, Body_f1, loop_vars=loopvars) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2518, in while_loop 
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2356, in BuildLoop 
    pred, body, original_loop_vars, loop_vars, shape_invariants) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2292, in _BuildLoop 
    c = ops.convert_to_tensor(pred(*packed_vars)) 
    File "/Users/Desktop/TestPython/tftest.py", line 18, in Cond_f2 
    boln = tf.less(tf.cast(tf.constant(ind), dtype=tf.int32), tf.cast(tf.constant(N), dtype=tf.int32)) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant 
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape)) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto 
    _AssertCompatible(values, dtype) 
    File "/Users/Desktop/HPC_LIB/TENSORFLOW/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 287, in _AssertCompatible 
    raise TypeError("List of Tensors when single Tensor expected") 
TypeError: List of Tensors when single Tensor expected 

我将不胜感激,如果有人可以帮助我解决这个错误。谢谢!

from math import * 
import numpy as np 
import sys 
import tensorflow as tf 

def Body_f1(n, ind, N, T): 
    # Compute trace 
    a = tf.trace(tf.random_normal(0.0, 1.0, (n, n))) 
    # Update trace 
    a = tf.cast(a, dtype=T.dtype) 
    T = tf.scatter_update(T, ind, a) 
    # Update index 
    ind = ind + 1 

    return n, ind, N, T 

def Cond_f2(n, ind, N, T): 
    boln = tf.less(tf.cast(tf.constant(ind), dtype=tf.int32), tf.cast(tf.constant(N), dtype=tf.int32)) 
    return boln 



def main(argv): 
    # Open tensorflow session 
    sess = tf.Session() 

    # Parameters 
    N = 10 
    T = tf.zeros((N), dtype=tf.float64) 
    n = 4 
    ind = 0 

    # While loop 
    loopvars = [n, ind, N, T] 
    result = tf.while_loop(Cond_f2, Body_f1, loop_vars=loopvars, shape_invariants=None, \ 
    parallel_iterations=1, back_prop=False, swap_memory=False, name=None) 
    trace = result[3] 
    trace = sess.run(trace) 
    print trace 
    print 'Done!' 

    # Close tensorflow session 
    if session==None: 
     sess.close() 

if __name__ == "__main__": 
    main(sys.argv[1:]) 

更新:我已添加完整的错误消息。我不知道为什么我会收到此错误消息。 loop_vars期望单张量而不是张量列表吗?我希望不是。

+0

你能分享错误的完整堆栈跟踪吗? – mrry

+0

我已更新了我的帖子,并附有完整的错误消息。非常感谢您的快速反应。 – QED

回答

2

tf.constant预计非张量的值,像Python列表或numpy的阵列。您可以通过迭代tf.constant来获得相同的错误,如tf.constant(tf.constant(5。))。删除这些电话会修复第一个错误。这是一个非常糟糕的错误信息,所以我会鼓励你到file a bug on Github

它看起来像random_normal的参数有点混淆;关键字参数可以很好地避免类似的问题:

tf.random_normal(mean=0.0, stddev=1.0, shape=(n, n)) 

最后,scatter_update需要一个变量。它看起来像一个TensorArray可能是你在这里寻找的东西(或者隐式使用TensorArray的higher level looping constructs之一)。