2017-09-25 26 views
0

我有一本字典,我需要通过键从字典中获取元素。但是,关键需要设置为tf.placeholder并获取值。例如,我想获取Tensorflow中的字典元素

dict={'a':10,'b':20} 
key=tf.placeholder(tf.float32, shape=[1]). 
result=dict[key] 
sess=tf.InteractiveSession() 
sess.run(result, {key:'a'}) 

但它说:

KeyError:<tf.Tensor 'Placeholder_1:0' shape=(?, 1) dtype=float32> 

谁能帮我把从字典中的值,但提供feed_dict的关键?

回答

0

这是一个常见的与TF混淆。虽然一切看起来像Python代码,但实际发生的事情是,TF正在构建一个操作图,然后在python解释器上下文之外执行它。这意味着你不能从图中访问python对象。据我所知,TensorFlow在其张量中不支持字符串索引。所以你在做什么都行不通。你可以做的最接近的是使用数字索引。

data = tf.constant([10, 20]) 
key = tf.placeholder(tf.float32, shape=[1]). 
result = data[key] 
sess=tf.InteractiveSession() 
sess.run(result, {key: 1}) 
# result should now contain 20 
+0

感谢疯狂!如果我的字典是dict = {'a':loss1,'b':loss2}和loss1,2是张量,会怎么样?我想根据tf.placeholder评估loss1或loss2是'a'还是'b'。 –

+0

只需使用数字索引。损失= [loss1,loss2]并将0或1传递给Feed字典。另外,不要将你的变量'dict'与dict数据类型冲突,你可以通过这种方式获得一些非常不愉快的错误。 –

+0

如果你希望它是友好的,你可以将人类可读的键映射到单独的字典中的数字。 –

0

看来,这为我工作基础上疯狂的解决方案:

dict={'a':10,'b':20} 
data = tf.convert_to_tensor(dict.values()) 
key = tf.placeholder(tf.int32, shape=None) 
result = tf.gather(data, key) 
sess=tf.InteractiveSession() 
sess.run(result, {key: [0,1,0]})