2017-06-13 307 views
6

在numpy的,我可以做一个简单的矩阵乘法这样的:怎么办矩阵的点积在PyTorch

a = numpy.arange(2*3).reshape(3,2) 
b = numpy.arange(2).reshape(2,1) 
print(a) 
print(b) 
print(a.dot(b)) 

然而,当我想这跟PyTorch张量,这不起作用:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2) 
b = torch.Tensor([[2, 1]]).view(2, -1) 
print(a) 
print(a.size()) 

print(b) 
print(b.size()) 

print(torch.dot(a, b)) 

此代码引发以下错误:

RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

任何想法如何简单的点积可以P中进行yTorch?

回答

13

您正在寻找

torch.mm(a,b) 

注意torch.dot()表现不同来np.dot()。关于什么是可取的here有一些讨论。具体而言,torch.dot()ab作为1D向量(不考虑它们的原始形状)并计算它们的内积。错误被抛出,因为这种行为使得你的a长度为6的矢量,而你的b长度为2的矢量;因此不能计算其内积。对于PyTorch中的矩阵乘法,请使用torch.mm()。 Numpy的np.dot()相比之下更加灵活;它计算一维数组的内积并为二维数组执行矩阵乘法。

5

大厦mexmex回答,如果你想要做一个矩阵乘法,你能做到这一点的方法有三种:

AB = A.mm(B) # computes A.B (matrix multiplication) 
# or 
AB = torch.mm(A, B) 
# or even simpler 
AB = A @ B # Python 3.5+ 

对于逐元素相乘,你可以简单地做(如果A和B具有相同的形状)

A * B # element-wise matrix multiplication (Hadamard product)