2016-03-09 43 views
3

我有一个ByteTensor,并希望抓住有1的指数。在numpy的,我可以做类似等同于np.where()的Lua Torch?

a = np.array([1,0,1,0,1]) 
return np.where(a) 

这将返回(array([0, 2, 4]),)。火炬中定义了这个功能吗?

(在我的具体情况,我想用这些指标来索引到几个不同的张量的对象,但它会是不错的知道如何在一般的做到这一点。)

回答

5

您可以使用torch.nonzero,如:

> a = torch.ByteTensor{1,0,1,0,1} 
> print(torch.nonzero(a))                       
1                             
3                             
5                             
[torch.LongTensor of size 3x1] 

如果你真的需要找到1-S只有你能链中的逻辑运算符:

> a = torch.ByteTensor{1,2,1,6,1} 
> a:eq(1):nonzero()