2017-05-04 52 views
6

我正在使用Python API for Tensorflow。我想实现以下给出不使用Python的循环Rosenbrock function了解Tensorflow中的while循环

Rosenbrock function

我当前实现如下:

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    summation = 0 
    for i in range(1, len(columns) - 1): 
     first_term = tf.square(tf.subtract(columns[i + 1], tf.square(columns[i]))) 
     second_term = tf.square(tf.subtract(columns[i], 1.0)) 
     summation += tf.add(tf.multiply(100.0, first_term), second_term) 

    return summation 

我试图实现在tf.while_loop()的总和;然而,我发现这个API在使用索引整数时有点不直观,因为索引整数意味着与数据保持独立。在documentation给出的示例使用数据作为索引(或反之亦然):

i = tf.constant(0) 
c = lambda i: tf.less(i, 10) 
b = lambda i: tf.add(i, 1) 
r = tf.while_loop(c, b, [i]) 
+0

仅仅使用for循环是否合适?使用while_loop有什么好处?还是有必要? – lerner

+0

在上面的代码中,for循环将执行python代码。如果我们调用他的for循环“f”的主体,那么你可以将python代码视为执行f,f,f,f,f,... f。所以它会将这个“主体”函数调用N次,并且函数的图形将因此具有该函数N次。如果你使用tf.while_loop,那么你只会在图中看到该函数一次。 –

+0

tf.while_loop的优点是:1)可以并行运行迭代,2)可以在条件语句中使用运行时常量。例如,如果你想运行优化器直到满足一定的容差,那么你必须使用tf.while_loop变体,因为python最初不能评估条件 –

回答

10

这可以通过使用可实现的tf.while_loop()和标准tuples按照在documentation第二示例。

def rosenbrock(data_tensor): 
    columns = tf.unstack(data_tensor) 

    # Track both the loop index and summation in a tuple in the form (index, summation) 
    index_summation = (tf.constant(1), tf.constant(0.0)) 

    # The loop condition, note the loop condition is 'i < n-1' 
    def condition(index, summation): 
     return tf.less(index, tf.subtract(tf.shape(columns)[0], 1)) 

    # The loop body, this will return a result tuple in the same form (index, summation) 
    def body(index, summation): 
     x_i = tf.gather(columns, index) 
     x_ip1 = tf.gather(columns, tf.add(index, 1)) 

     first_term = tf.square(tf.subtract(x_ip1, tf.square(x_i))) 
     second_term = tf.square(tf.subtract(x_i, 1.0)) 
     summand = tf.add(tf.multiply(100.0, first_term), second_term) 

     return tf.add(index, 1), tf.add(summation, summand) 

    # We do not care about the index value here, return only the summation 
    return tf.while_loop(condition, body, index_summation)[1] 

重要的是要注意,索引增量应该出现在循环体中,类似于标准while循环。在给出的解决方案中,它是由body()函数返回的元组中的第一项。

此外,循环条件函数必须为总和分配一个参数,虽然它在此特定示例中未使用。