2017-06-30 32 views
1

PyTorch的torch.transpose函数仅转换2D输入。文档是herePyTorch中没有N维转录

另一方面,Tensorflow的tf.transpose函数允许您转置张量N任意尺寸。

有人可以请解释为什么PyTorch不能/不能有N维转置功能?这是由于PyTorch中计算图构造的动态特性与Tensorflow的Define-then-Run范式相对应吗?

回答

3

它在pytorch中简单地被称为不同。 torch.Tensor.permute将允许您在pytorch中交换尺寸,例如TensorFlow中的tf.transpose。

作为如何你一个4D图像张量从NHWC转换为NCHW一个例子(未测试,因此可能包含bug):

>>> img_nhwc = torch.randn(10, 480, 640, 3) 
>>> img_nhwc.size() 
torch.Size([10, 480, 640, 3]) 
>>> img_nchw = img_nhwc.permute(0, 3, 1, 2) 
>>> img_nchw.size() 
torch.Size([10, 3, 480, 640])