2016-08-12 50 views
12

Tensorflow有这个API定义:张量流中的局部变量是什么?

tf.local_variables()

返回具有collection=[LOCAL_VARIABLES]创建的所有变量。

返回:

本地变量对象的列表。

TensorFlow中的局部变量究竟是什么?有人能给我一个例子吗?

+0

看到这个[问题](https://github.com/tensorflow/tensorflow/issues/1045),存在必须在使用前初始化一个局部变量。 – suiyuan2009

回答

14

它与常规变量相同,但与默认值不同(GraphKeys.VARIABLES)。该集合被保存器用于初始化要保存的变量的默认列表,因此具有local指定具有在默认情况下不保存该变量的效果。

我看到的唯一一个使用它的代码库,这是limit_epochs

with ops.name_scope(name, "limit_epochs", [tensor]) as name: 
    zero64 = constant_op.constant(0, dtype=dtypes.int64) 
    epochs = variables.Variable(
     zero64, name="epochs", trainable=False, 
     collections=[ops.GraphKeys.LOCAL_VARIABLES]) 
+0

distributed_replicated模式下的分布式tensorflow使用局部变量。 – suiyuan2009

13

简短的回答的地方:在TF局部变量是与collections=[tf.GraphKeys.LOCAL_VARIABLES]创建的任何变量。例如:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES]) 

LOCAL_VARIABLES:是本地的每个 机变量的对象的子集。通常用于临时变量,如计数器。注意: 使用tf.contrib.framework.local_variable添加到此集合。

它们通常不会保存/恢复到检查点并用于临时或中间值。


龙答:这是混乱的根源,我也是如此。一开始我还以为是局部变量意味着同样的事情local variable in almost any programming language,但它是不一样的东西:

import tensorflow as tf 

def some_func(): 
    z = tf.Variable(1, name='var_z') 

a = tf.Variable(1, name='var_a') 
b = tf.get_variable('var_b', 2) 
with tf.name_scope('aaa'): 
    c = tf.Variable(3, name='var_c') 

with tf.variable_scope('bbb'): 
    d = tf.Variable(3, name='var_d') 

some_func() 
some_func() 

print [str(i.name) for i in tf.global_variables()] 
print [str(i.name) for i in tf.local_variables()] 

不管我试过,我一直收到全球唯一的变量:

['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0'] 
[] 

对于tf.local_variables文档没有提供很多细节:

局部变量 - 每个过程变量,通常不会被保存/恢复 通道eckpoint并用于临时或中间值。例如, 它们可以用作度量计算的计数器或本机读取数据的 时期的数量。 local_variable()自动 向GraphKeys.LOCAL_VARIABLES添加新变量。这个便利功能 函数返回该集合的内容。


但阅读文档在tf.Variable类init方法,我发现,虽然创建一个变量,你可以提供你希望它是什么样的一个变量做通过分配的collections列表。

可能的收集元素列表是here。所以要创建一个局部变量,你需要做这样的事情。你会看到它的local_variables名单:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES]) 
print [str(i.name) for i in tf.local_variables()]