2017-11-18 146 views
0

让我们一起走,目前,我想索引一个张量,并将非零项更改为-1,将零项更改为1.但我不知道如何在TensorFlow中执行此操作。TensorFlow布尔索引

这里是我的代码:

y_[y_ != 0].assign(-1) 
y_[y_ == 0].assign(1) 

原因是TensorFlow似乎不支持布尔索引。我该如何解决它?

顺便说一句,它似乎boolean_mask不适合我,因为我不想返回一片y_,我只是想要y_改变它的值。

谢谢!

回答

0

您可以使用tf.cond()作为条件赋值。我给出了一个示例代码如下,

import tensorflow as tf 

x_= tf.Variable(5) #non-zero variable 
y_= tf.Variable(0) #variable euqals to 0 

y_ =tf.cond(tf.equal(y_,0),lambda :y_.assign(1),lambda :y_.assign(-1)) #assign 1 if variable equals to zero else -1 
x_ =tf.cond(tf.equal(x_,0),lambda :x_.assign(1),lambda :x_.assign(-1)) #assign 1 if variable equals to zero else -1 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

with sess.as_default(): 
    print(y_.eval()) #prints 1 
    print(x_.eval()) #prints -1 

希望这有助于。

+0

谢谢!这可以为0级的情况做出杰出的工作!然而,就我而言,由于y_处于CNN流程中,因此在我真正运行图之前,其形状实际上是rank1(在我的情况中,它是(?,))。现在输出为'ValueError:形状必须为0,但'cond/Switch'(op:'Switch')的等级为1,输入形状为[?],[?]。',我该怎么办才能修复它? – Andre

+0

嗨,尼朋!我想我通过使用'tf.where'解决了这个问题,但是,谢谢你的努力! – Andre

+0

是的,我很抱歉,我没有说清楚!仍然感谢您的帮助! – Andre